From 9a9953d9b0b32456fdb35e2bdb94679375b694dd Mon Sep 17 00:00:00 2001 From: zhangbo9674 <82555433+zhangbo9674@users.noreply.github.com> Date: Wed, 13 Oct 2021 16:01:20 +0800 Subject: [PATCH] [AMP] add attr is_distributed for layer.to (#36221) * add attr is_distributed * refine code * refine black/white list for pure fp16 --- python/paddle/fluid/dygraph/amp/auto_cast.py | 4 ++-- python/paddle/fluid/dygraph/layers.py | 5 +++++ python/paddle/fluid/framework.py | 1 - 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index d218e6b749..c807303621 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -70,8 +70,8 @@ AMP_RELATED_FLAGS_SETTING = { 'FLAGS_cudnn_batchnorm_spatial_persistent': 1, } -PURE_FP16_BLACK_LIST = {' '} -PURE_FP16_WHITE_LIST = {'lookup_table', 'lookup_table_v2'} +PURE_FP16_WHITE_LIST = {' '} +PURE_FP16_BLACK_LIST = {'lookup_table', 'lookup_table_v2'} #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 30d5ee4417..e4b6bc0103 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -1466,6 +1466,8 @@ class Layer(core.Layer): param_applied = func(param, device, dtype, blocking) assert param.is_leaf param_applied.stop_gradient = param.stop_gradient + if hasattr(param_applied, 'is_distributed'): + param_applied.is_distributed = param.is_distributed self._parameters[key] = param_applied if param.grad is not None: @@ -1475,6 +1477,9 @@ class Layer(core.Layer): grad_applied.stop_gradient = param._grad_ivar( ).stop_gradient + if hasattr(param._grad_ivar(), 'is_distributed'): + grad_applied.is_distributed = param._grad_ivar( + ).is_distributed self._parameters[key]._set_grad_ivar(grad_applied) self._parameters_transform_map[id(param)] = [param_applied, key] diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index 4d90b91594..c6367911b8 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -6153,7 +6153,6 @@ class ParamBase(core.VarBase): return new_param def _copy_to(self, device, blocking): - print("in ParamBase copy_to func") state = copy.deepcopy(self.__dict__) new_param = ParamBase(self.shape, self.dtype, **state) core.varbase_copy(self, new_param, device, blocking) -- GitLab