Unverified Commit 9a9953d9 authored by zhangbo9674, committed by GitHub

[AMP] add attr is_distributed for layer.to (#36221)

* add attr is_distributed

* refine code

* refine black/white list for pure fp16
Parent 3a869cc5
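A hedged sketch of the intended effect of the Layer.to change (the paddle.nn.Linear layer and the manually set is_distributed flag below are illustrative assumptions, not part of this commit): a parameter marked as distributed should keep that flag after the layer is converted to another dtype or device.

import paddle

linear = paddle.nn.Linear(4, 4)
# In a real model-parallel setup the sharded weight would already carry this
# flag; it is set by hand here purely to illustrate the behaviour.
linear.weight.is_distributed = True

# Layer.to rebuilds each parameter; with this patch the rebuilt parameter
# (and its gradient, if any) keeps is_distributed instead of dropping it.
linear.to(dtype='float16')
print(linear.weight.dtype, linear.weight.is_distributed)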
@@ -70,8 +70,8 @@ AMP_RELATED_FLAGS_SETTING = {
     'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
 }
-PURE_FP16_BLACK_LIST = {' '}
-PURE_FP16_WHITE_LIST = {'lookup_table', 'lookup_table_v2'}
+PURE_FP16_WHITE_LIST = {' '}
+PURE_FP16_BLACK_LIST = {'lookup_table', 'lookup_table_v2'}
 #NOTE(zhiqiu): similar as paddle.fluid.contrib.mixed_precision.fp16_lists.AutoMixedPrecisionLists._update_list
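For context, these lists feed pure-fp16 (O2) training; after this change embedding lookups are kept in fp32 rather than being force-cast to fp16. A minimal sketch, assuming the paddle.amp O2-level API of this release and a CUDA device:

import paddle

model = paddle.nn.Sequential(
    paddle.nn.Embedding(100, 16), paddle.nn.Linear(16, 4))
opt = paddle.optimizer.SGD(parameters=model.parameters())
# Pure fp16: parameters are cast to float16, but ops on the black list
# (now lookup_table / lookup_table_v2) still execute in fp32.
model, opt = paddle.amp.decorate(models=model, optimizers=opt, level='O2')

ids = paddle.randint(0, 100, shape=[8, 10])
with paddle.amp.auto_cast(level='O2'):
    out = model(ids)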
@@ -1466,6 +1466,8 @@ class Layer(core.Layer):
                     param_applied = func(param, device, dtype, blocking)
                     assert param.is_leaf
                     param_applied.stop_gradient = param.stop_gradient
+                    if hasattr(param_applied, 'is_distributed'):
+                        param_applied.is_distributed = param.is_distributed
                     self._parameters[key] = param_applied
                 if param.grad is not None:
@@ -1475,6 +1477,9 @@ class Layer(core.Layer):
                         grad_applied.stop_gradient = param._grad_ivar(
                         ).stop_gradient
+                        if hasattr(param._grad_ivar(), 'is_distributed'):
+                            grad_applied.is_distributed = param._grad_ivar(
+                            ).is_distributed
                         self._parameters[key]._set_grad_ivar(grad_applied)
                 self._parameters_transform_map[id(param)] = [param_applied, key]
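The hasattr guard above keeps the copy optional, so tensor types that do not expose is_distributed are simply left alone. The same pattern in isolation (an illustrative helper, not part of the patch):

def _copy_optional_attr(src, dst, name='is_distributed'):
    # Copy the attribute only when the source tensor actually exposes it,
    # mirroring the guard added in Layer._apply above.
    if hasattr(src, name):
        setattr(dst, name, getattr(src, name))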
@@ -6153,7 +6153,6 @@ class ParamBase(core.VarBase):
         return new_param
     def _copy_to(self, device, blocking):
-        print("in ParamBase copy_to func")
         state = copy.deepcopy(self.__dict__)
         new_param = ParamBase(self.shape, self.dtype, **state)
         core.varbase_copy(self, new_param, device, blocking)