From 0fff9306676ca2256de8cdd60eb0d30878521b95 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Thu, 4 Mar 2021 14:51:30 +0800 Subject: [PATCH] Fix bug for set_value op when input dtype is not float32 (#31411) --- paddle/fluid/operators/set_value_op.cc | 3 +-- paddle/fluid/operators/set_value_op.h | 3 +-- .../fluid/tests/unittests/test_var_base.py | 25 ++++++++++++++++--- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index a18238adca..94d34c648d 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -57,8 +57,7 @@ class SetValue : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::proto::VarType::Type(ctx.Attr("dtype")), - ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index 6347bcd247..325a2b0b86 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -174,14 +174,13 @@ class SetValueKernel : public framework::OpKernel { auto steps_tensor_list = ctx.MultiInput("StepsTensorList"); - auto dtype = - static_cast(ctx.Attr("dtype")); auto axes = ctx.Attr>("axes"); auto starts = ctx.Attr>("starts"); auto ends = ctx.Attr>("ends"); auto steps = ctx.Attr>("steps"); auto shape = ctx.Attr>("shape"); + auto dtype = in->type(); if (!starts_tensor_list.empty()) { starts = GetDataFromTensorList(starts_tensor_list); } diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 6c5458c1a2..b0c9dda7a3 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -631,10 +631,14 @@ class TestVarBase(unittest.TestCase): class TestVarBaseSetitem(unittest.TestCase): def setUp(self): paddle.disable_static() - self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32)) - self.np_value = np.random.random((2, 3)).astype(np.float32) + self.set_dtype() + self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype)) + self.np_value = np.random.random((2, 3)).astype(self.dtype) self.tensor_value = paddle.to_tensor(self.np_value) + def set_dtype(self): + self.dtype = "int32" + def _test(self, value): paddle.disable_static() self.assertEqual(self.tensor_x.inplace_version, 0) @@ -644,7 +648,7 @@ class TestVarBaseSetitem(unittest.TestCase): self.assertEqual(self.tensor_x.inplace_version, 1) if isinstance(value, (six.integer_types, float)): - result = np.zeros((2, 3)).astype(np.float32) + value + result = np.zeros((2, 3)).astype(self.dtype) + value else: result = self.np_value @@ -674,11 +678,26 @@ class TestVarBaseSetitem(unittest.TestCase): paddle.disable_static() self._test(10) + +class TestVarBaseSetitemInt64(TestVarBaseSetitem): + def set_dtype(self): + self.dtype = "int64" + + +class TestVarBaseSetitemFp32(TestVarBaseSetitem): + def set_dtype(self): + self.dtype = "float32" + def test_value_float(self): paddle.disable_static() self._test(3.3) +class TestVarBaseSetitemFp64(TestVarBaseSetitem): + def set_dtype(self): + self.dtype = "float64" + + class TestVarBaseInplaceVersion(unittest.TestCase): def test_setitem(self): paddle.disable_static() -- GitLab