未验证 提交 9a9953d9 编写于 作者: Z zhangbo9674 提交者: GitHub

[AMP] add attr is_distributed for layer.to (#36221)

* add attr is_distributed

* refine code

* refine black/white list for pure fp16
上级 3a869cc5
...@@ -70,8 +70,8 @@ AMP_RELATED_FLAGS_SETTING = { ...@@ -70,8 +70,8 @@ AMP_RELATED_FLAGS_SETTING = {
'FLAGS_cudnn_batchnorm_spatial_persistent': 1, 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
} }
PURE_FP16_BLACK_LIST = {' '} PURE_FP16_WHITE_LIST = {' '}
PURE_FP16_WHITE_LIST = {'lookup_table', 'lookup_table_v2'} PURE_FP16_BLACK_LIST = {'lookup_table', 'lookup_table_v2'}
#NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
......
...@@ -1466,6 +1466,8 @@ class Layer(core.Layer): ...@@ -1466,6 +1466,8 @@ class Layer(core.Layer):
param_applied = func(param, device, dtype, blocking) param_applied = func(param, device, dtype, blocking)
assert param.is_leaf assert param.is_leaf
param_applied.stop_gradient = param.stop_gradient param_applied.stop_gradient = param.stop_gradient
if hasattr(param_applied, 'is_distributed'):
param_applied.is_distributed = param.is_distributed
self._parameters[key] = param_applied self._parameters[key] = param_applied
if param.grad is not None: if param.grad is not None:
...@@ -1475,6 +1477,9 @@ class Layer(core.Layer): ...@@ -1475,6 +1477,9 @@ class Layer(core.Layer):
grad_applied.stop_gradient = param._grad_ivar( grad_applied.stop_gradient = param._grad_ivar(
).stop_gradient ).stop_gradient
if hasattr(param._grad_ivar(), 'is_distributed'):
grad_applied.is_distributed = param._grad_ivar(
).is_distributed
self._parameters[key]._set_grad_ivar(grad_applied) self._parameters[key]._set_grad_ivar(grad_applied)
self._parameters_transform_map[id(param)] = [param_applied, key] self._parameters_transform_map[id(param)] = [param_applied, key]
......
...@@ -6153,7 +6153,6 @@ class ParamBase(core.VarBase): ...@@ -6153,7 +6153,6 @@ class ParamBase(core.VarBase):
return new_param return new_param
def _copy_to(self, device, blocking): def _copy_to(self, device, blocking):
print("in ParamBase copy_to func")
state = copy.deepcopy(self.__dict__) state = copy.deepcopy(self.__dict__)
new_param = ParamBase(self.shape, self.dtype, **state) new_param = ParamBase(self.shape, self.dtype, **state)
core.varbase_copy(self, new_param, device, blocking) core.varbase_copy(self, new_param, device, blocking)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册