
In this blog post I want to share a trick I use quite often when a single machine learning pipeline has to instantiate different models, each with its own constructor parameters.

I use this trick to keep my training code generic and model-agnostic.

We start from a simple model interface that declares a forward pass:

```python
from abc import ABC, abstractmethod


class ModelInterface(ABC):
    @abstractmethod
    def forward(self):
        """Run a forward pass; every concrete model must implement this."""
        raise NotImplementedError
```
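A minimal sketch of what the interface buys us, assuming `forward` is declared with `@abstractmethod` on an `ABC` subclass: Python then refuses to instantiate any subclass that forgets to implement `forward`, so an interface violation surfaces when the model is constructed rather than at the first forward pass. `IncompleteModel` is a hypothetical class made up for this illustration, and this `ModelA` returns its message instead of printing it so the result is easy to check.

```python
from abc import ABC, abstractmethod


class ModelInterface(ABC):
    @abstractmethod
    def forward(self):
        """Run a forward pass; every concrete model must implement this."""


class ModelA(ModelInterface):
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim

    def forward(self):
        return f"fwd pass with latent_dim={self.latent_dim}"


class IncompleteModel(ModelInterface):
    # Hypothetical subclass that forgets to implement forward().
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim


print(ModelA(latent_dim=10).forward())  # fwd pass with latent_dim=10

try:
    IncompleteModel(latent_dim=10)
except TypeError as err:
    # The ABC machinery rejects the class because forward() is still abstract.
    print(f"refused: {err}")
```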

Our first model, model A, takes a constructor argument specifying the latent dimension:

```python
class ModelA(ModelInterface):
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim

    def forward(self):
        print(f"fwd pass with latent_dim={self.latent_dim}")
```

Our second model, model B, takes a different constructor argument specifying the number of embeddings:

```python
class ModelB(ModelInterface):
    def __init__(self, num_embeddings):
        self.num_embeddings = num_embeddings

    def forward(self):
        print(f"fwd pass with num_embeddings={self.num_embeddings}")
```

How do we now write generic code that can handle both these models?

Let me show you.

First (this step is optional), we create an identifier enum with one member per model:

```python
from enum import Enum


class ModelIdentifier(Enum):
    ModelA = "ModelA"
    ModelB = "ModelB"
```
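One reason the optional enum can pay off, sketched here as an assumption about how the pipeline is driven: if the model name comes from a config file as a plain string, calling the enum with that value turns it into a validated identifier, and a typo fails immediately with a `ValueError` instead of deep inside the pipeline. The `"ModelC"` value is a hypothetical bad input.

```python
from enum import Enum


class ModelIdentifier(Enum):
    ModelA = "ModelA"
    ModelB = "ModelB"


# Calling the Enum with a value looks up the matching member,
# e.g. for a string read from a YAML or JSON config file.
identifier = ModelIdentifier("ModelA")
print(identifier)  # ModelIdentifier.ModelA

# An unknown name is rejected immediately.
try:
    ModelIdentifier("ModelC")
except ValueError as err:
    print(f"rejected: {err}")
```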

Then we create the model map, mapping each identifier to its model class:

```python
model_name_to_model = {
    ModelIdentifier.ModelA: ModelA,
    ModelIdentifier.ModelB: ModelB,
}
```

Now, using dictionary unpacking, we can combine this model map with a model identifier to pick the right model class and instantiate it with its own config!
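In isolation, dictionary unpacking simply turns each key-value pair of a dict into a keyword argument. The sketch below uses a hypothetical function `describe` (not part of the models above) to show that `describe(**config)` and the explicit keyword call are equivalent.

```python
def describe(latent_dim, num_layers=2):
    # Hypothetical function used only to illustrate keyword unpacking.
    return f"latent_dim={latent_dim}, num_layers={num_layers}"


config = {"latent_dim": 10, "num_layers": 4}

# **config expands the dict into keyword arguments,
# so both calls pass exactly the same arguments.
unpacked = describe(**config)
explicit = describe(latent_dim=10, num_layers=4)
print(unpacked)  # latent_dim=10, num_layers=4
```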

Consider model A, for example:

```python
config_a = {
    "latent_dim": 10,
}
model = model_name_to_model[ModelIdentifier.ModelA](**config_a)
model.forward()  # Results in: fwd pass with latent_dim=10
```

Now consider model B, with a different config:

```python
config_b = {
    "num_embeddings": 1,
}
model = model_name_to_model[ModelIdentifier.ModelB](**config_b)
model.forward()  # Results in: fwd pass with num_embeddings=1
```
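Putting the pieces together, here is a minimal sketch of the generic code this pattern enables: a single loop that builds and runs every model from a list of configs without mentioning any model class by name. `build_model` and the `experiments` list are hypothetical, the config layout (a `"model"` name plus a `"params"` dict) is an assumption, and `forward` returns its message instead of printing it so the output is easy to check.

```python
from abc import ABC, abstractmethod
from enum import Enum


class ModelInterface(ABC):
    @abstractmethod
    def forward(self):
        """Run a forward pass; every concrete model must implement this."""


class ModelA(ModelInterface):
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim

    def forward(self):
        return f"fwd pass with latent_dim={self.latent_dim}"


class ModelB(ModelInterface):
    def __init__(self, num_embeddings):
        self.num_embeddings = num_embeddings

    def forward(self):
        return f"fwd pass with num_embeddings={self.num_embeddings}"


class ModelIdentifier(Enum):
    ModelA = "ModelA"
    ModelB = "ModelB"


model_name_to_model = {
    ModelIdentifier.ModelA: ModelA,
    ModelIdentifier.ModelB: ModelB,
}


def build_model(name: str, params: dict) -> ModelInterface:
    """Hypothetical helper: resolve a model class by name, pass params as kwargs."""
    model_cls = model_name_to_model[ModelIdentifier(name)]
    return model_cls(**params)


# Hypothetical experiment configs, e.g. as loaded from a JSON or YAML file.
experiments = [
    {"model": "ModelA", "params": {"latent_dim": 10}},
    {"model": "ModelB", "params": {"num_embeddings": 1}},
]

for experiment in experiments:
    model = build_model(experiment["model"], experiment["params"])
    print(model.forward())
```

The loop only relies on the shared `forward` method, so adding a new model means adding one enum member and one map entry; the training code itself stays untouched.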

Of course, this only works if every model adheres to the interface and each config's keys match the parameters of the corresponding constructor. As long as you stay true to the interface, though, this can come in handy. That's it! Until next time!
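To see the failure mode concretely: if a config's keys do not match the constructor, plain unpacking raises a `TypeError` about an unexpected keyword argument. The sketch below uses a stripped-down `ModelA` and a hypothetical `check_params` helper built on `inspect.signature` to reject mismatched configs with a more descriptive error. It is only an illustration under simple assumptions: it does not handle constructors that accept `**kwargs`, nor does it check for missing required arguments.

```python
import inspect


class ModelA:
    # Stripped-down version of ModelA; only the constructor matters here.
    def __init__(self, latent_dim):
        self.latent_dim = latent_dim


def check_params(model_cls, params):
    """Hypothetical helper: reject config keys the constructor does not accept."""
    accepted = set(inspect.signature(model_cls).parameters)
    unknown = set(params) - accepted
    if unknown:
        raise ValueError(
            f"{model_cls.__name__} does not accept {sorted(unknown)}; "
            f"expected a subset of {sorted(accepted)}"
        )


bad_config = {"num_embeddings": 1}  # meant for ModelB, passed to ModelA

try:
    ModelA(**bad_config)
except TypeError as err:
    print(f"plain unpacking: {err}")

try:
    check_params(ModelA, bad_config)
except ValueError as err:
    print(f"with a check: {err}")
```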

Remember: you can use dictionary unpacking (**my_dict) in combination with an interface to instantiate models with different parameters!