
In this blog post I want to share a trick I use quite often to instantiate different machine learning models, each with its own parameters, in the same machine learning pipeline.

It keeps my training code generic and model-agnostic.

We start from a simple model interface that declares a forward pass:

from abc import ABC, abstractmethod

class ModelInterface(ABC):
    @abstractmethod  # subclasses must override forward() to be instantiable
    def forward(self):
        raise NotImplementedError

Our first model, model A, is instantiated with a variable specifying the latent dimensions:

class ModelA(ModelInterface):
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim

    def forward(self):
        print(f"fwd pass with latent_dim={self.latent_dim}")

Our second model, model B, is instantiated with a different variable specifying the number of embeddings:

class ModelB(ModelInterface):
    def __init__(self, num_embeddings):
        self.num_embeddings = num_embeddings

    def forward(self):
        print(f"fwd pass with num_embeddings={self.num_embeddings}")

How do we now write generic code that can handle both these models?

Let me show you.

First, and this step is optional, we create an enum with one identifier per model:

from enum import Enum

class ModelIdentifier(Enum):
    ModelA = "ModelA"
    ModelB = "ModelB"

Then we create the model map, which maps each identifier to its model class:

model_name_to_model = {
    ModelIdentifier.ModelA: ModelA,
    ModelIdentifier.ModelB: ModelB,
}

Now, using dictionary unpacking we can use this model map in combination with the model identifier to choose the right model and instantiate it!
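As a minimal sketch of what the unpacking itself does: `**` turns each key of the dictionary into a keyword argument, so the two calls below are equivalent. The stand-alone `ModelA` here is a trimmed copy used only for illustration.

```python
class ModelA:
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim


config = {"latent_dim": 10}

# ** expands the dict into keyword arguments: ModelA(latent_dim=10)
unpacked = ModelA(**config)
explicit = ModelA(latent_dim=10)

print(unpacked.latent_dim == explicit.latent_dim)
```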

Consider model A for example:

config_a = {
    "latent_dim": 10,
}
model = model_name_to_model[ModelIdentifier.ModelA](**config_a)
model.forward() # Results in: fwd pass with latent_dim=10

Now consider model B with a different config:

config_b = {
    "num_embeddings": 1,
}
model = model_name_to_model[ModelIdentifier.ModelB](**config_b)
model.forward() # Results in: fwd pass with num_embeddings=1
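In a pipeline, the lookup and the unpacking could be wrapped in one small helper, so the training code only ever sees a model name and a config. The sketch below is one possible way to do that, not the only one. The helper name `build_model` and its error handling are my own assumptions for illustration, and the model classes are trimmed copies so the snippet runs on its own.

```python
from enum import Enum


class ModelA:
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim


class ModelB:
    def __init__(self, num_embeddings):
        self.num_embeddings = num_embeddings


class ModelIdentifier(Enum):
    ModelA = "ModelA"
    ModelB = "ModelB"


model_name_to_model = {
    ModelIdentifier.ModelA: ModelA,
    ModelIdentifier.ModelB: ModelB,
}


def build_model(name, config):
    """Hypothetical helper: pick the model class by name and instantiate it."""
    identifier = ModelIdentifier(name)  # raises ValueError for an unknown name
    model_cls = model_name_to_model[identifier]
    try:
        return model_cls(**config)
    except TypeError as err:
        # Typically raised when the config keys do not match the
        # constructor's parameters, e.g. num_embeddings passed to ModelA.
        raise ValueError(f"Bad config for {identifier.value}: {err}") from err


model = build_model("ModelB", {"num_embeddings": 1})
```

With this in place, the name and the config can both come from a single config file, and a mismatched config fails early with a message that names the model.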

Of course, this only works as long as every model adheres to the interface; if you stay true to it, the trick can come in handy. That’s it! Until next time!

Remember: you can use dictionary unpacking (**my_dict) in combination with an interface to instantiate models with different parameters!