未验证 提交 49411a20 编写于 作者: L liym27 提交者: GitHub

In creation.assgin, reuse implamention code of layers.tensor.assign to avoid...

In creation.assgin, reuse implamention code of layers.tensor.assign to avoid maintain two code (#30227)
上级 e03171b7
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
from __future__ import print_function from __future__ import print_function
import numpy as np import numpy as np
from ..fluid.layers import tensor
from ..fluid.framework import Variable from ..fluid.framework import Variable
from ..fluid.framework import unique_name from ..fluid.framework import unique_name
from ..fluid.framework import _current_expected_place, _get_paddle_place from ..fluid.framework import _current_expected_place, _get_paddle_place
...@@ -1057,46 +1058,5 @@ def assign(x, output=None): ...@@ -1057,46 +1058,5 @@ def assign(x, output=None):
result2 = paddle.assign(data) # result2 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] result2 = paddle.assign(data) # result2 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]] result3 = paddle.assign(np.array([[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]], dtype='float32')) # result3 = [[2.5, 2.5], [2.5, 2.5], [2.5, 2.5]]
""" """
helper = LayerHelper('assign', **locals())
check_type(x, 'x', (Variable, numpy.ndarray), 'assign') check_type(x, 'x', (Variable, numpy.ndarray), 'assign')
if isinstance(x, Variable): return tensor.assign(x, output)
check_dtype(
x.dtype, 'x',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'assign', '(When the type of input in assign is Variable.)')
if output is None:
output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='assign', inputs={'X': [x]}, outputs={'Out': [output]})
elif isinstance(x, numpy.ndarray):
dtype = convert_np_dtype_to_dtype_(x.dtype)
if dtype == VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [bool(v) for v in x.flat]
elif dtype == VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in x.flat]
elif dtype == VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in x.flat]
elif dtype == VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in x.flat]
else:
raise TypeError(
"When the type of 'x' in assign is numpy.ndarray, "
"the data type of 'x' must be bool, float32, int32 or int64, but "
"received %s." % convert_dtype(dtype))
if x.size > 1024 * 1024:
raise ValueError("The size of input is too big. Please consider "
"saving it to file and 'load_op' to load it")
if output is None:
output = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='assign_value',
outputs={'Out': [output]},
attrs={'dtype': dtype,
'shape': list(x.shape),
value_name: values})
return output
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册