提交 c0ca8fee 编写于 作者: R rensilin

dump params

Change-Id: I3a1a1c67500614b946cec92794eefe4a7f20f11a
上级 825e947c
...@@ -114,12 +114,14 @@ class ModelBuilder: ...@@ -114,12 +114,14 @@ class ModelBuilder:
with open(os.path.join(self._save_path, name), 'w') as f: with open(os.path.join(self._save_path, name), 'w') as f:
f.write(program.desc.serialize_to_string()) f.write(program.desc.serialize_to_string())
params = filter(fluid.io.is_parameter, main_program.list_vars())
model_desc_path = os.path.join(self._save_path, 'model.yaml') model_desc_path = os.path.join(self._save_path, 'model.yaml')
model_desc = { model_desc = {
'inputs': [{"name": var.name, "shape": var.shape} for var in inputs], 'inputs': [{"name": var.name, "shape": var.shape} for var in inputs],
'outputs': [{"name": var.name, "shape": var.shape} for var in outputs], 'outputs': [{"name": var.name, "shape": var.shape} for var in outputs],
'labels': [{"name": var.name, "shape": var.shape} for var in labels], 'labels': [{"name": var.name, "shape": var.shape} for var in labels],
'vars': [{"name": var.name, "shape": var.shape} for var in main_program.list_vars() if fluid.io.is_parameter(var)], 'vars': [{"name": var.name, "shape": var.shape} for var in params],
'loss': loss.name, 'loss': loss.name,
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册