未验证 提交 4eaae120 编写于 作者: C ceci3 提交者: GitHub

fix train config (#1226)

* fix train config

* fix
上级 81fa1823
......@@ -137,14 +137,13 @@ class AutoCompression:
# load config
if isinstance(config, str):
config = load_config(config)
self.strategy_config = extract_strategy_config(config)
self.train_config = extract_train_config(config)
elif isinstance(config, dict):
if 'TrainConfig' in config:
self.train_config = config.pop('TrainConfig')
self.train_config = TrainConfig(**config.pop('TrainConfig'))
else:
self.train_config = None
self.strategy_config = config
self.strategy_config = extract_strategy_config(config)
# prepare dataloader
self.feed_vars = get_feed_vars(self.model_dir, model_filename,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册