未验证 提交 e9eb5db3 编写于 作者: Z zhaoyingli 提交者: GitHub

[AutoParallel] keep input order (#48963)

* [AutoParallel] keep input order

* rm annotation
上级 1acddc34
......@@ -195,17 +195,16 @@ class DistributedSaver:
used_inputs += op.input_arg_names
used_outputs += op.output_arg_names
dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))
for idx, var_name in enumerate(feed_vars_names):
if var_name not in used_inputs:
feed_vars_names.pop(idx)
for idx, var_name in enumerate(fetch_vars_names):
if var_name not in used_outputs:
fetch_vars_names.pop(idx)
dist_feed_vars = [
global_block.vars[name] for name in dist_feed_vars_names
]
dist_fetch_vars = [
global_block.vars[name] for name in dist_fetch_vars_names
]
dist_feed_vars = [global_block.vars[name] for name in feed_vars_names]
dist_fetch_vars = [global_block.vars[name] for name in fetch_vars_names]
# NOTE: `paddle.static.save_inference_model` does not support subblock.
dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename)
paddle.static.save_inference_model(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册