未验证 提交 0052ed40 编写于 作者: X xiongkun 提交者: GitHub

[Yaml] Modify api and add unittests for full api final state. (#41437) (#41618)

* full api fix

* when out is None, go old dygraph mode

* fix

* add name for buffer

* fix by code review

* fix

* by static check
上级 99d32eb5
...@@ -1169,6 +1169,8 @@ class Layer(object): ...@@ -1169,6 +1169,8 @@ class Layer(object):
# add a persistable buffer. # add a persistable buffer.
if name not in self._buffers: if name not in self._buffers:
self._non_persistable_buffer_names_set.add(name) self._non_persistable_buffer_names_set.add(name)
if not value.name:
value.name = unique_name.generate('_buffers_' + name)
_buffers[name] = value _buffers[name] = value
elif _buffers is not None and name in _buffers: elif _buffers is not None and name in _buffers:
# Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in # Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
......
...@@ -21,7 +21,7 @@ import warnings ...@@ -21,7 +21,7 @@ import warnings
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..param_attr import ParamAttr from ..param_attr import ParamAttr
from ..initializer import Initializer from ..initializer import Initializer
from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode from ..framework import _current_expected_place, convert_np_dtype_to_dtype_, _non_static_mode, _varbase_creator, device_guard, _in_legacy_dygraph, in_dygraph_mode, _get_paddle_place
from ..framework import Variable from ..framework import Variable
from ..initializer import Constant from ..initializer import Constant
from ..core import VarDesc from ..core import VarDesc
...@@ -751,6 +751,20 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None): ...@@ -751,6 +751,20 @@ def fill_constant(shape, dtype, value, force_cpu=False, out=None, name=None):
attrs['value'] = float(value) attrs['value'] = float(value)
if _non_static_mode(): if _non_static_mode():
if out is None and in_dygraph_mode():
#Currently, final state mode don't support out is None.
place = _current_expected_place()
if force_cpu:
place = core.CPUPlace()
shape = utils.convert_shape_to_list(shape)
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
out = _C_ops.final_state_full(shape, float(value), dtype, place)
out.stop_gradient = True
return out
else:
shape = utils.convert_shape_to_list(shape) shape = utils.convert_shape_to_list(shape)
if out is None: if out is None:
out = _varbase_creator(dtype=dtype) out = _varbase_creator(dtype=dtype)
......
...@@ -23,6 +23,7 @@ from paddle.fluid.op import Operator ...@@ -23,6 +23,7 @@ from paddle.fluid.op import Operator
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.framework import _test_eager_guard
# Test python API # Test python API
...@@ -75,6 +76,61 @@ class TestFullAPI(unittest.TestCase): ...@@ -75,6 +76,61 @@ class TestFullAPI(unittest.TestCase):
assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_6, np.full([1, 2], 1.1, dtype="float32"))
assert np.array_equal(res_7, np.full([1, 2], 1.1, dtype="float32")) assert np.array_equal(res_7, np.full([1, 2], 1.1, dtype="float32"))
def test_api_eager(self):
with fluid.dygraph.base.guard():
with _test_eager_guard():
positive_2_int32 = fluid.layers.fill_constant([1], "int32", 2)
positive_2_int64 = fluid.layers.fill_constant([1], "int64", 2)
out_1 = paddle.full(
shape=[1, 2], dtype="float32", fill_value=1.1)
out_2 = paddle.full(
shape=[1, positive_2_int32.item()],
dtype="float32",
fill_value=1.1)
out_3 = paddle.full(
shape=[1, positive_2_int64.item()],
dtype="float32",
fill_value=1.1)
out_4 = paddle.full(
shape=[1, 2], dtype="float32", fill_value=1.2)
out_5 = paddle.full(
shape=[1, 2], dtype="float32", fill_value=1.1)
out_6 = paddle.full(
shape=[1, 2], dtype=np.float32, fill_value=1.1)
val = fluid.layers.fill_constant(
shape=[1], dtype=np.float32, value=1.1)
out_7 = paddle.full(
shape=[1, 2], dtype=np.float32, fill_value=val)
assert np.array_equal(
out_1, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_2, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_3, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_4, np.full(
[1, 2], 1.2, dtype="float32"))
assert np.array_equal(
out_5, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_6, np.full(
[1, 2], 1.1, dtype="float32"))
assert np.array_equal(
out_7, np.full(
[1, 2], 1.1, dtype="float32"))
class TestFullOpError(unittest.TestCase): class TestFullOpError(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册