Reducing our stored model state by 80% using bit manipulation magic in Python
This week was pretty cool: together with a bearded grey wizard from our data engineering team, we managed to reduce our JSON payload size by roughly 80% (from about 18 KB to less than 4 KB per request) using some Python bit manipulation magic.
This was pretty wild, but also necessary, because this is part of an API that will be called around 20 million times a day.
For some context, I’m working in the machine learning infrastructure team at my company, and we are currently revamping our entire architecture for how we host and productise models.
One of our core tasks as a machine learning platform is storing and retrieving model state. Because this happens so often (with every model call), we need to do it quickly and efficiently.
We (the platform team) care about how things run in production; the data science team… not so much. To our horror, we realized that the data science team had added a field called `answer_type_counts` to the state they wanted to store.
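```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class ModelState(PupilState):  # PupilState: our existing base state class (not shown)
    answer_type_counts: Dict[str, List[int]]  # <== new
    f_embedding: float
    s_embedding: List[float]
```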
At first we thought this was fine. Then we looked at some production JSON payloads with the envisioned embedding sizes and realized it would be a problem, because `answer_type_counts` held a vector of 3000 integers! Storing this extra data with every API call would be a disaster and blow up our costs.
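To get a feel for why 3000 integers hurt, here is a rough sizing sketch. It assumes a hypothetical worst case where every count sits at the 101 maximum. It only covers `answer_type_counts`, not the embeddings or the rest of the state, so it is not our actual production payload.

```python
import json

n_counts = 3000
# Hypothetical worst case: every count is 101, i.e. 3 decimal digits each.
worst_case = [101] * n_counts

# As JSON text the list looks like "[101, 101, ..., 101]".
json_size = len(json.dumps(worst_case))
# 3000 numbers * 3 digits + 2999 separators ", " * 2 chars + 2 brackets
assert json_size == n_counts * 3 + (n_counts - 1) * 2 + 2

# If every count fit into a single byte instead:
one_byte_each = n_counts * 1
print(json_size, one_byte_each)  # 15000 3000
```

The gap between the JSON text and one byte per count is exactly what the rest of this story exploits.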
I was in a meeting with our PO (product owner) and the wise greybeard from the data engineering team, and the core message of that meeting was: we have to get this payload down. The greybeard said, “No problem, let’s get to work.”
The first thing the greybeard quizzed me on was the numbers: “What’s the current payload size? What kind of numbers do we have? How many do we need to store? Do we need to store all of them? Can we store a subset of them?”
I answered in rapid succession: “The current payload size is 18 KB. We need to store ints, around 3000 of them. The maximum int we need to store is 101, because that’s the last embedding that the model uses…”
When I said this, his eyes lit up. He continued: “That’s great news, because every extra bit doubles how many values we can count: 1 bit gives us 2, 2 bits give us 4, 3 bits give us 8, … and 8 bits give us 256, so we can count from 0 to 255. If we only need to go up to 101, 7 bits (0 to 127) are enough, but for simplicity let’s just use 8 bits, a single byte.”
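The greybeard’s back-of-the-envelope arithmetic can be checked in a few lines of Python. `int.bit_length()` is a built-in method. The variable names below are just for illustration.

```python
# Each extra bit doubles the number of distinct values:
# n bits can represent the integers 0 .. 2**n - 1.
for n_bits in range(1, 9):
    print(f"{n_bits} bit(s): 0 .. {2**n_bits - 1}")

max_count = 101
# int.bit_length() returns the minimum number of bits needed to write a
# non-negative integer in binary (ignoring leading zeros).
bits_needed = max_count.bit_length()  # 7, because 64 <= 101 < 128
byte_max = 2**8 - 1                   # 255, the largest value one byte can hold
print(bits_needed, byte_max)
```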
Now look, here I was, kind of like a fish out of water. This bit manipulation felt like black magic to me, and I was stoked and eager to help out wherever possible.
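I was used to thinking of integers as 32-bit values. Strictly speaking, Python’s own int is an arbitrary-precision object rather than a fixed 32-bit number, and in our JSON payload every count is stored as decimal text anyway, but a 32-bit integer is a handy mental model. So I said: “So do I understand correctly that instead of using all 4 bytes of a 32-bit integer, we can drop the 3 high-order bytes and keep only the lowest byte? Because we only need to count to 101?”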
“Yup,” the greybeard replied.
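What “keeping only the lowest byte” means can be illustrated with a short sketch. Python ints have no fixed width, so the 4-byte (32-bit) layout is explicitly requested via the built-in `int.to_bytes`. Big-endian order is chosen here only because it prints the high-order bytes first.

```python
value = 101

# As a fixed-width 32-bit (4-byte) big-endian integer, the three high-order bytes are zero.
four_bytes = value.to_bytes(4, byteorder="big")
print(four_bytes)  # b'\x00\x00\x00e'  (0x65 == 101 prints as 'e')

# Masking with 0xFF keeps only the lowest 8 bits; nothing is lost while value <= 255.
low_byte = value & 0xFF
one_byte = value.to_bytes(1, byteorder="big")
print(low_byte, one_byte)  # 101 b'e'

# Reading the single byte back gives the original number.
print(int.from_bytes(one_byte, byteorder="big"))  # 101
```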
“How do we do this?” I asked.
The greybeard explained that we could take the list of ints and pack its contents together as bytes, converting each value to an unsigned char using Python’s `struct` module.
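A small sketch of what `struct.pack` does with the `"B"` (unsigned char) format character. The sample values are made up for illustration.

```python
import struct

values = [1, 2, 3, 101]

# "<" = little-endian, standard sizes, no padding; "B" = unsigned char (1 byte, 0..255).
fmt = "<" + "B" * len(values)  # "<BBBB", equivalent to "<4B"
packed = struct.pack(fmt, *values)
print(packed, len(packed))  # b'\x01\x02\x03e' 4  (0x65 == 101 shows as 'e')

# Values that do not fit in one byte are rejected rather than silently truncated;
# capping (non-negative) counts at 101 keeps every value inside that range.
try:
    struct.pack("<B", 300)
    overflow_rejected = False
except struct.error:
    overflow_rejected = True
print("300 rejected:", overflow_rejected)
```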
Compressing this list of integers then looks something like this:
```python
>>> import base64, struct, zlib
>>> lst = [1, 2, 3]
>>> capped = [min(x, 101) for x in lst]
>>> compressed = base64.b64encode(zlib.compress(struct.pack("<" + "B" * len(capped), *capped))).decode("utf-8")
>>> compressed
'eJxjZGIGAAANAAc='
```
So we convert the list `[1, 2, 3]` to the string `"eJxjZGIGAAANAAc="`.
Let’s run through this code step by step:
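- `[min(x, 101) for x in lst]` caps the elements at a maximum of 101, the largest value the model needs.
- `struct.pack()` converts the elements in the list to unsigned chars (`"B"`), so each integer takes only 8 bits, or 1 byte. The `"<"` prefix selects little-endian byte order with standard sizes and no alignment padding; for single-byte values the byte order makes no difference.
- `zlib.compress()` compresses these bytes even further, which is especially powerful when the list contains a lot of repetition.
- `base64.b64encode()` then encodes the bytes as ASCII text that we can send over the web (i.e. inside a JSON string), and `.decode("utf-8")` turns that result into a regular Python string.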
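The same pipeline can be unrolled into one variable per step, which makes it easy to inspect each intermediate value. For a tiny input like `[1, 2, 3]`, the zlib header (2 bytes) and checksum (4 bytes) make the compressed stream larger than the 3 packed bytes. Compression only pays off on longer, repetitive lists. The `f"<{n}B"` format used here is equivalent to `"<" + "B" * n`.

```python
import base64
import struct
import zlib

lst = [1, 2, 3]

capped = [min(x, 101) for x in lst]                # step 1: cap at 101 -> [1, 2, 3]
packed = struct.pack(f"<{len(capped)}B", *capped)  # step 2: one byte per int -> b'\x01\x02\x03'
deflated = zlib.compress(packed)                   # step 3: 2-byte header + deflate data + 4-byte checksum
encoded = base64.b64encode(deflated)               # step 4: ASCII-safe bytes, 4 chars per 3 input bytes
text = encoded.decode("utf-8")                     # plain str, ready to drop into JSON

for name, stage in [("packed", packed), ("deflated", deflated), ("encoded", encoded)]:
    print(f"{name:>8}: {len(stage)} bytes")
print(text)
```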
What’s really cool is that the more repetition we have in the list, the better the compression.
```python
import base64
import json
import struct
import zlib

lst = 10_000 * [0, 1]  # 20,000 elements with lots of repetition
capped = [min(x, 101) for x in lst]
compressed = base64.b64encode(zlib.compress(struct.pack("<" + "B" * len(capped), *capped))).decode("utf-8")
size_with = len(compressed)  # 60
size_without = len(json.dumps(lst))  # 60_000
```
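To see that repetition, not just the one-byte packing, drives the zlib gains, the sketch below compares the repetitive list above with a list of the same length filled with random values from 0 to 101. The helper name `packed_and_compressed_size` is hypothetical, and the random seed is arbitrary, chosen only to make runs reproducible.

```python
import random
import struct
import zlib


def packed_and_compressed_size(values):
    """Return (packed size, zlib-compressed size) in bytes for ints in 0..255."""
    packed = struct.pack(f"<{len(values)}B", *values)
    return len(packed), len(zlib.compress(packed))


n = 20_000
repetitive = [0, 1] * (n // 2)                    # same pattern as above
rng = random.Random(42)                           # seeded so the run is reproducible
noisy = [rng.randint(0, 101) for _ in range(n)]   # little structure for zlib to exploit

print("repetitive:", packed_and_compressed_size(repetitive))
print("random:    ", packed_and_compressed_size(noisy))
```

Both lists pack to the same 20,000 bytes, but only the repetitive one shrinks dramatically under zlib.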
Decompressing is just doing everything the other way around.
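```python
>>> import base64, struct, zlib
>>> compressed = "eJxjZGIGAAANAAc="  # the string produced from [1, 2, 3] above
>>> bs: bytes = zlib.decompress(base64.b64decode(compressed))
>>> lst = list(struct.unpack("<" + "B" * len(bs), bs))
>>> lst
[1, 2, 3]
```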
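Both directions can be wrapped into a pair of helpers and applied to every value of a `Dict[str, List[int]]` field such as `answer_type_counts`. This is a minimal sketch, not our production code. The function names, the `MAX_VALUE` constant, and the example keys `"multiple_choice"` and `"open"` are hypothetical, and it assumes all counts are non-negative, since `struct` rejects negative values for `"B"`.

```python
import base64
import json
import struct
import zlib
from typing import Dict, List

MAX_VALUE = 101  # largest count the model uses; anything above is capped (lossy by design)


def encode_counts(counts: List[int]) -> str:
    """Cap, pack as unsigned bytes, zlib-compress and base64-encode a list of ints.

    Assumes non-negative counts: struct raises struct.error for negatives with "B".
    """
    capped = [min(x, MAX_VALUE) for x in counts]
    packed = struct.pack(f"<{len(capped)}B", *capped)
    return base64.b64encode(zlib.compress(packed)).decode("ascii")


def decode_counts(encoded: str) -> List[int]:
    """Reverse encode_counts: base64-decode, decompress, unpack one byte per int."""
    raw = zlib.decompress(base64.b64decode(encoded))
    return list(struct.unpack(f"<{len(raw)}B", raw))


def encode_answer_type_counts(counts_by_type: Dict[str, List[int]]) -> Dict[str, str]:
    """Apply encode_counts to every value of a Dict[str, List[int]] field."""
    return {key: encode_counts(values) for key, values in counts_by_type.items()}


def decode_answer_type_counts(encoded: Dict[str, str]) -> Dict[str, List[int]]:
    """Apply decode_counts to every value of an encoded dict."""
    return {key: decode_counts(value) for key, value in encoded.items()}


if __name__ == "__main__":
    original = {"multiple_choice": [0, 1, 2, 150] * 3, "open": [5] * 10}
    payload = json.dumps(encode_answer_type_counts(original))  # what would go over the wire
    restored = decode_answer_type_counts(json.loads(payload))
    print(restored["multiple_choice"][:4])  # [0, 1, 2, 101] (150 was capped)
```

Note that capping is the only lossy step: any count above 101 comes back as 101, which is acceptable here only because the model never uses anything beyond that.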
Applying this “simple” compression algorithm is what pushed the state down from 18 KB to 4 KB, a reduction of roughly 80%. Amazing, really cool stuff.