提交 d37dc3fa 编写于 作者: D dengkaipeng

fix load_params in stnet.py

上级 5e682b3c
...@@ -241,14 +241,3 @@ def regist_model(name, model): ...@@ -241,14 +241,3 @@ def regist_model(name, model):
def get_model(name, cfg, mode='train', args=None): def get_model(name, cfg, mode='train', args=None):
return model_zoo.get(name, cfg, mode, args) return model_zoo.get(name, cfg, mode, args)
if __name__ == "__main__":
class TestModel(ModelBase):
pass
model_zoo.regist('test', TestModel)
m = model_zoo.get('test', './config.txt')
print(m.get_train_config('batch_size'))
m.build_model()
m = model_zoo.get('test2', './config.txt')
...@@ -154,7 +154,13 @@ class STNET(ModelBase): ...@@ -154,7 +154,13 @@ class STNET(ModelBase):
return {} return {}
def load_pretrain_params(self, exe, pretrain, prog): def load_pretrain_params(self, exe, pretrain, prog):
fluid.io.load_params(exe, pretrain, main_program=prog) def is_parameter(var):
if isinstance(var, fluid.framework.Parameter):
return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
vars = filter(is_parameter, prog.list_vars())
fluid.io.load_vars(exe, pretrain, vars=vars)
param_tensor = fluid.global_scope().find_var( param_tensor = fluid.global_scope().find_var(
"conv1_weights").get_tensor() "conv1_weights").get_tensor()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册