From b9fad5dad3d547f2516e61bfa9fdc9aa5c408f9c Mon Sep 17 00:00:00 2001 From: JZ-LIANG Date: Wed, 14 Dec 2022 14:41:11 +0800 Subject: [PATCH] [Bugfix] recompute dep filter param (#49010) * recompute dep filter param * recompute dep for reshard --- python/paddle/distributed/auto_parallel/reshard.py | 2 ++ python/paddle/distributed/auto_parallel/utils.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index 06231d9b59..62e383c72e 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -2497,6 +2497,8 @@ class Resharder: "read", "write_to_array", "read_from_array", + "nop", + "depend", ] global _g_special_ops skip_ops += _g_special_ops diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 62004269e9..0883417fc9 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -2168,6 +2168,9 @@ def insert_dependencies_for_two_ops( def _select_best_depend_var(vars): + # parameter should not be dep var since it maybe partition in sharding pass + vars = [var for var in vars if not var.is_parameter] + assert len(vars) > 0 vars_with_numels = [(var, get_var_numel(var)) for var in vars] vars_with_numels.sort(key=lambda x: x[1]) -- GitLab