This issue isn't specific to `guild`, but I was hoping someone had some advice on how to fix this problem.
I want to use `guild` to parameterize which model to test. The models I want to test all have different interfaces exposing different hyperparameters. Here is a minimal example:
```python
import argparse


class ModelA:
    def __init__(self, filter_size):
        self.name = "ModelA"
        self.filter_size = filter_size


class ModelB:
    def __init__(self, num_layers):
        self.name = "ModelB"
        self.num_layers = num_layers


MODELS = {"ModelA": ModelA, "ModelB": ModelB}

if __name__ == "__main__":
    pa = argparse.ArgumentParser()
    pa.add_argument("--model", choices=list(MODELS))
    pa.add_argument("--filter_size", type=int)  # only used by ModelA
    pa.add_argument("--num_layers", type=int)   # only used by ModelB
    args = pa.parse_args()
    model = MODELS[args.model]
    # Init model: which arguments should be passed here?
    model(...)
```
As you can see, it is not clear how to initialize the chosen model. I have thought of several solutions myself:
- A separate script for each model. This quickly becomes infeasible with many models, so I opted not to go this way.
- Some magic using Python's `inspect` module: infer which parameters the model constructor requires and then extract them from the `Namespace` object. This doesn't work well with `guild` flags, as unused flags will still be captured in `guild compare`.
- Only pass the required parameters for the chosen model, e.g. `python test_model.py --model ModelA --filter_size 3`, which is the solution I am leaning towards now. I'm unsure how that would work with `guild`, though, as I don't think you can make flags optional.
- Require all models to accept `*args, **kwargs` in the constructor, but again, I think this is a bit hacky.
- A factory method for each model, such as `ModelA.from_namespace`, but this is also a little hacky, I think.
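For reference, the `inspect`-based idea from the second bullet can be sketched roughly like this. It is a minimal sketch, assuming every constructor parameter has the same name as an argparse destination; `make_parser` and `build_model` are hypothetical helper names, not part of Guild.

```python
import argparse
import inspect


class ModelA:
    def __init__(self, filter_size):
        self.name = "ModelA"
        self.filter_size = filter_size


class ModelB:
    def __init__(self, num_layers):
        self.name = "ModelB"
        self.num_layers = num_layers


MODELS = {"ModelA": ModelA, "ModelB": ModelB}


def make_parser():
    # One flat parser that declares every hyperparameter of every model
    pa = argparse.ArgumentParser()
    pa.add_argument("--model", choices=list(MODELS), required=True)
    pa.add_argument("--filter_size", type=int)
    pa.add_argument("--num_layers", type=int)
    return pa


def build_model(args):
    """Instantiate the chosen model with only the arguments its __init__ accepts."""
    cls = MODELS[args.model]
    # inspect.signature(cls) describes __init__ without the `self` parameter
    params = inspect.signature(cls).parameters
    values = vars(args)
    kwargs = {name: values[name] for name in params if name in values}
    return cls(**kwargs)


# Example: ModelA receives filter_size; num_layers is parsed but ignored
model = build_model(make_parser().parse_args(["--model", "ModelA", "--filter_size", "3"]))
print(model.name, model.filter_size)  # ModelA 3
```

This keeps model construction clean, but as noted in the list, every flag is still declared on the parser, so unused ones (here `num_layers`) would still be recorded as flags.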
I'm open to any suggestions!
EDIT:
I just remembered seeing how `pytorch-lightning` handles this. I like their approach, so I might just use it. Will `guild` be able to capture the flags with it?
EDIT2:
Looks like `guild` doesn't handle this pattern well:
```python
if __name__ == "__main__":
    pa = argparse.ArgumentParser()
    pa.add_argument("--model")
    # First pass: read only --model and ignore the model-specific flags
    temp_args = pa.parse_known_args()[0]
    model = MODELS[temp_args.model]
    # Let the chosen model register its own hyperparameter flags
    pa = model.add_model_specific_args(pa)
    args = pa.parse_args()
    model = model.from_args(args)
```
`guild` throws this warning and error:

```
WARNING: cannot import flags from test.py: .venv/bin/python3: Error while finding module specification for 'guild.plugins.import_argparse_flags_main' (ModuleNotFoundError: No module named 'guild.plugins'; 'guild' is not a package)
```

and

```
.venv/bin/python: Error while finding module specification for 'guild.op_main' (AttributeError: module 'guild' has no attribute '__path__')
```
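For completeness, a self-contained sketch of the two-pass pattern from EDIT2 might look like this. The method names `add_model_specific_args` and `from_args` are taken from the snippet; their bodies, the `parse_model` helper, and the default values (3 and 2) are hypothetical placeholders, not code from Guild or PyTorch Lightning.

```python
import argparse


class ModelA:
    def __init__(self, filter_size):
        self.name = "ModelA"
        self.filter_size = filter_size

    @staticmethod
    def add_model_specific_args(parser):
        # Register only the hyperparameters this model understands (default is a placeholder)
        parser.add_argument("--filter_size", type=int, default=3)
        return parser

    @classmethod
    def from_args(cls, args):
        return cls(filter_size=args.filter_size)


class ModelB:
    def __init__(self, num_layers):
        self.name = "ModelB"
        self.num_layers = num_layers

    @staticmethod
    def add_model_specific_args(parser):
        parser.add_argument("--num_layers", type=int, default=2)
        return parser

    @classmethod
    def from_args(cls, args):
        return cls(num_layers=args.num_layers)


MODELS = {"ModelA": ModelA, "ModelB": ModelB}


def parse_model(argv=None):
    """Two-pass parse: pick the model class, then let it add its own flags."""
    pa = argparse.ArgumentParser()
    pa.add_argument("--model", choices=list(MODELS), required=True)
    # First pass ignores flags the base parser doesn't know yet
    temp_args, _ = pa.parse_known_args(argv)
    model_cls = MODELS[temp_args.model]
    pa = model_cls.add_model_specific_args(pa)
    # Second pass errors out on any flag the chosen model did not register
    args = pa.parse_args(argv)
    return model_cls.from_args(args)


model = parse_model(["--model", "ModelB", "--num_layers", "4"])
print(model.name, model.num_layers)  # ModelB 4
```

Because each model registers its flags only after the first pass, the full set of flags is not known until `--model` has been parsed.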