未验证 提交 0052ed40 编写于 作者: X xiongkun 提交者: GitHub

[Yaml] Modify api and add unittests for full api final state. (#41437) (#41618)

* full api fix

* when out is None, go old dygraph mode

* fix

* add name for buffer

* fix by code review

* fix

* by static check
上级 99d32eb5
......@@ -1169,6 +1169,8 @@ class Layer(object):
# add a persistable buffer.
if name not in self._buffers:
self._non_persistable_buffer_names_set.add(name)
if not value.name:
value.name = unique_name.generate('_buffers_' + name)
_buffers[name] = value
elif _buffers is not None and name in _buffers:
# Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
......
......@@ -21,7 +21,7 @@ import warnings
from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr
from ..initializer import Initializer
from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode
from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode, _get_paddle_place
from ..framework import Variable
from ..initializer import Constant
from ..core import VarDesc
......@@ -751,22 +751,36 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs['value'] = float(value)
if _non_static_mode():
shape = utils.convert_shape_to_list(shape)
if out is None:
out = _varbase_creator(dtype=dtype)
if isinstance(value, Variable):
if dtype in ['uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value.numpy().item(0)))
else:
attrs['str_value'] = str(float(value.numpy().item(0)))
_C_ops.fill_constant(out, 'value',
float(value), 'force_cpu', force_cpu, 'dtype',
out.dtype, 'str_value', attrs['str_value'],
'shape', shape)
out.stop_gradient = True
return out
if out is None and in_dygraph_mode():
#Currently, final state mode don't support out is None.
place = _current_expected_place()
if force_cpu:
place = core.CPUPlace()
shape = utils.convert_shape_to_list(shape)
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
out = _C_ops.final_state_full(shape, float(value), dtype, place)
out.stop_gradient = True
return out
else:
shape = utils.convert_shape_to_list(shape)
if out is None:
out = _varbase_creator(dtype=dtype)
if isinstance(value, Variable):
if dtype in ['uint8', 'int16', 'int32', 'int64']:
attrs['str_value'] = str(int(value.numpy().item(0)))
else:
attrs['str_value'] = str(float(value.numpy().item(0)))
_C_ops.fill_constant(out, 'value',
float(value), 'force_cpu', force_cpu, 'dtype',
out.dtype, 'str_value', attrs['str_value'],
'shape', shape)
out.stop_gradient = True
return out
helper = LayerHelper("fill_constant", **locals())
inputs = {}
......
......@@ -23,6 +23,7 @@ from paddle.fluid.op import Operator
import paddle.fluid as fluid
import paddle
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
# Test python API
......@@ -75,6 +76,61 @@ class TestFullAPI(unittest.TestCase):
assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_7, np.full([1, 2], 1.1, dtype="float32"))
def test_api_eager(self):
with fluid.dygraph.base.guard():
with _test_eager_guard():
positive_2_int32 = fluid.layers.fill_constant([1], "int32", 2)
positive_2_int64 = fluid.layers.fill_constant([1], "int64", 2)
out_1 = paddle.full(
shape=[1, 2], dtype="float32", fill_value=1.1)
out_2 = paddle.full(
shape=[1, positive_2_int32.item()],
dtype="float32",
fill_value=1.1)
out_3 = paddle.full(
shape=[1, positive_2_int64.item()],
dtype="float32",
fill_value=1.1)
out_4 = paddle.full(
shape=[1, 2], dtype="float32", fill_value=1.2)
out_5 = paddle.full(
shape=[1, 2], dtype="float32", fill_value=1.1)
out_6 = paddle.full(
shape=[1, 2], dtype=np.float32, fill_value=1.1)
val = fluid.layers.fill_constant(
shape=[1], dtype=np.float32, value=1.1)
out_7 = paddle.full(
shape=[1, 2], dtype=np.float32, fill_value=val)
assert np.array_equal(
out_1, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_2, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_3, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_4, np.full(
[1, 2], 1.2, dtype="float32"))
assert np.array_equal(
out_5, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_6, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_7, np.full(
[1, 2], 1.1, dtype="float32"))
class TestFullOpError(unittest.TestCase):
def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册