From c0ca8fee09708a7ddc6989e5f0225fa82913ad57 Mon Sep 17 00:00:00 2001 From: rensilin Date: Tue, 20 Aug 2019 14:28:35 +0800 Subject: [PATCH] dump params Change-Id: I3a1a1c67500614b946cec92794eefe4a7f20f11a --- .../train/custom_trainer/feed/scripts/create_programs.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py b/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py index c9b9ca94..e2b1cb7c 100644 --- a/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py +++ b/paddle/fluid/train/custom_trainer/feed/scripts/create_programs.py @@ -114,12 +114,14 @@ class ModelBuilder: with open(os.path.join(self._save_path, name), 'w') as f: f.write(program.desc.serialize_to_string()) + params = filter(fluid.io.is_parameter, main_program.list_vars()) + model_desc_path = os.path.join(self._save_path, 'model.yaml') model_desc = { 'inputs': [{"name": var.name, "shape": var.shape} for var in inputs], 'outputs': [{"name": var.name, "shape": var.shape} for var in outputs], 'labels': [{"name": var.name, "shape": var.shape} for var in labels], - 'vars': [{"name": var.name, "shape": var.shape} for var in main_program.list_vars() if fluid.io.is_parameter(var)], + 'vars': [{"name": var.name, "shape": var.shape} for var in params], 'loss': loss.name, } -- GitLab