diff --git a/paddle/fluid/operators/set_value_op.cc b/paddle/fluid/operators/set_value_op.cc index a18238adcae192c44d14e68c0bb53a3f534534a4..94d34c648d17482627bec08d1a9038a8600106a8 100644 --- a/paddle/fluid/operators/set_value_op.cc +++ b/paddle/fluid/operators/set_value_op.cc @@ -57,8 +57,7 @@ class SetValue : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::proto::VarType::Type(ctx.Attr("dtype")), - ctx.GetPlace()); + OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace()); } framework::OpKernelType GetKernelTypeForVar( diff --git a/paddle/fluid/operators/set_value_op.h b/paddle/fluid/operators/set_value_op.h index 6347bcd24791aa77cf40f9081a4a3518c242d005..325a2b0b865e9d5d6e433da1cd0c767402603239 100644 --- a/paddle/fluid/operators/set_value_op.h +++ b/paddle/fluid/operators/set_value_op.h @@ -174,14 +174,13 @@ class SetValueKernel : public framework::OpKernel { auto steps_tensor_list = ctx.MultiInput("StepsTensorList"); - auto dtype = - static_cast(ctx.Attr("dtype")); auto axes = ctx.Attr>("axes"); auto starts = ctx.Attr>("starts"); auto ends = ctx.Attr>("ends"); auto steps = ctx.Attr>("steps"); auto shape = ctx.Attr>("shape"); + auto dtype = in->type(); if (!starts_tensor_list.empty()) { starts = GetDataFromTensorList(starts_tensor_list); } diff --git a/python/paddle/fluid/tests/unittests/test_var_base.py b/python/paddle/fluid/tests/unittests/test_var_base.py index 6c5458c1a2cb9fff356bf76941cdf27147f1722c..b0c9dda7a30987e71a648a691441f238572e2873 100644 --- a/python/paddle/fluid/tests/unittests/test_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_var_base.py @@ -631,10 +631,14 @@ class TestVarBase(unittest.TestCase): class TestVarBaseSetitem(unittest.TestCase): def setUp(self): paddle.disable_static() - self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32)) - self.np_value = np.random.random((2, 3)).astype(np.float32) + self.set_dtype() + self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype)) + self.np_value = np.random.random((2, 3)).astype(self.dtype) self.tensor_value = paddle.to_tensor(self.np_value) + def set_dtype(self): + self.dtype = "int32" + def _test(self, value): paddle.disable_static() self.assertEqual(self.tensor_x.inplace_version, 0) @@ -644,7 +648,7 @@ class TestVarBaseSetitem(unittest.TestCase): self.assertEqual(self.tensor_x.inplace_version, 1) if isinstance(value, (six.integer_types, float)): - result = np.zeros((2, 3)).astype(np.float32) + value + result = np.zeros((2, 3)).astype(self.dtype) + value else: result = self.np_value @@ -674,11 +678,26 @@ class TestVarBaseSetitem(unittest.TestCase): paddle.disable_static() self._test(10) + +class TestVarBaseSetitemInt64(TestVarBaseSetitem): + def set_dtype(self): + self.dtype = "int64" + + +class TestVarBaseSetitemFp32(TestVarBaseSetitem): + def set_dtype(self): + self.dtype = "float32" + def test_value_float(self): paddle.disable_static() self._test(3.3) +class TestVarBaseSetitemFp64(TestVarBaseSetitem): + def set_dtype(self): + self.dtype = "float64" + + class TestVarBaseInplaceVersion(unittest.TestCase): def test_setitem(self): paddle.disable_static()