diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 06231d9b59a267514a75aa0285aab6b9b6e52bff..62e383c72e9f30e2d81f4d881783f9a7b7434a25 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -2497,6 +2497,8 @@ class Resharder: "read", "write_to_array", "read_from_array", + "nop", + "depend", ] global _g_special_ops skip_ops += _g_special_ops diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 62004269e9b5681326608950e328dc247a640b7e..0883417fc9e82ca2fce373b831311747773ab811 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -2168,6 +2168,9 @@ def insert_dependencies_for_two_ops( def _select_best_depend_var(vars): + # parameter should not be dep var since it maybe partition in sharding pass + vars = [var for var in vars if not var.is_parameter] + assert len(vars) > 0 vars_with_numels = [(var, get_var_numel(var)) for var in vars] vars_with_numels.sort(key=lambda x: x[1])