未验证 提交 c56d9026 编写于 作者: F flame 提交者: GitHub

return saved targets' name list (#16240)

上级 996a7473
...@@ -895,7 +895,7 @@ def save_inference_model(dirname, ...@@ -895,7 +895,7 @@ def save_inference_model(dirname,
True is supported. True is supported.
Returns: Returns:
None target_var_name_list(list): The fetch variables' name list
Raises: Raises:
ValueError: If `feed_var_names` is not a list of basestring. ValueError: If `feed_var_names` is not a list of basestring.
...@@ -954,6 +954,7 @@ def save_inference_model(dirname, ...@@ -954,6 +954,7 @@ def save_inference_model(dirname,
var, 1., name="save_infer_model/scale_{}".format(i)) var, 1., name="save_infer_model/scale_{}".format(i))
uniq_target_vars.append(var) uniq_target_vars.append(var)
target_vars = uniq_target_vars target_vars = uniq_target_vars
target_var_name_list = [var.name for var in target_vars]
# when a pserver and a trainer running on the same machine, mkdir may conflict # when a pserver and a trainer running on the same machine, mkdir may conflict
try: try:
...@@ -1010,6 +1011,7 @@ def save_inference_model(dirname, ...@@ -1010,6 +1011,7 @@ def save_inference_model(dirname,
params_filename = os.path.basename(params_filename) params_filename = os.path.basename(params_filename)
save_persistables(executor, dirname, main_program, params_filename) save_persistables(executor, dirname, main_program, params_filename)
return target_var_name_list
def load_inference_model(dirname, def load_inference_model(dirname,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册