未验证 提交 b9fad5da 编写于 作者: J JZ-LIANG 提交者: GitHub

[Bugfix] recompute dep filter param (#49010)

* recompute dep filter param

* recompute dep for reshard
上级 a8d139a4
...@@ -2497,6 +2497,8 @@ class Resharder: ...@@ -2497,6 +2497,8 @@ class Resharder:
"read", "read",
"write_to_array", "write_to_array",
"read_from_array", "read_from_array",
"nop",
"depend",
] ]
global _g_special_ops global _g_special_ops
skip_ops += _g_special_ops skip_ops += _g_special_ops
......
...@@ -2168,6 +2168,9 @@ def insert_dependencies_for_two_ops( ...@@ -2168,6 +2168,9 @@ def insert_dependencies_for_two_ops(
def _select_best_depend_var(vars): def _select_best_depend_var(vars):
# parameter should not be dep var since it maybe partition in sharding pass
vars = [var for var in vars if not var.is_parameter]
assert len(vars) > 0
vars_with_numels = [(var, get_var_numel(var)) for var in vars] vars_with_numels = [(var, get_var_numel(var)) for var in vars]
vars_with_numels.sort(key=lambda x: x[1]) vars_with_numels.sort(key=lambda x: x[1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册