未验证 提交 0fff9306 编写于 作者: L liym27 提交者: GitHub

Fix bug for set_value op when input dtype is not float32 (#31411)

上级 c40b98e0
...@@ -57,8 +57,7 @@ class SetValue : public framework::OperatorWithKernel { ...@@ -57,8 +57,7 @@ class SetValue : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")), OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -174,14 +174,13 @@ class SetValueKernel : public framework::OpKernel<T> { ...@@ -174,14 +174,13 @@ class SetValueKernel : public framework::OpKernel<T> {
auto steps_tensor_list = auto steps_tensor_list =
ctx.MultiInput<framework::Tensor>("StepsTensorList"); ctx.MultiInput<framework::Tensor>("StepsTensorList");
auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto axes = ctx.Attr<std::vector<int64_t>>("axes"); auto axes = ctx.Attr<std::vector<int64_t>>("axes");
auto starts = ctx.Attr<std::vector<int64_t>>("starts"); auto starts = ctx.Attr<std::vector<int64_t>>("starts");
auto ends = ctx.Attr<std::vector<int64_t>>("ends"); auto ends = ctx.Attr<std::vector<int64_t>>("ends");
auto steps = ctx.Attr<std::vector<int64_t>>("steps"); auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape"); auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto dtype = in->type();
if (!starts_tensor_list.empty()) { if (!starts_tensor_list.empty()) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list); starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
} }
......
...@@ -631,10 +631,14 @@ class TestVarBase(unittest.TestCase): ...@@ -631,10 +631,14 @@ class TestVarBase(unittest.TestCase):
class TestVarBaseSetitem(unittest.TestCase): class TestVarBaseSetitem(unittest.TestCase):
def setUp(self): def setUp(self):
paddle.disable_static() paddle.disable_static()
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32)) self.set_dtype()
self.np_value = np.random.random((2, 3)).astype(np.float32) self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype)
self.tensor_value = paddle.to_tensor(self.np_value) self.tensor_value = paddle.to_tensor(self.np_value)
def set_dtype(self):
self.dtype = "int32"
def _test(self, value): def _test(self, value):
paddle.disable_static() paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0) self.assertEqual(self.tensor_x.inplace_version, 0)
...@@ -644,7 +648,7 @@ class TestVarBaseSetitem(unittest.TestCase): ...@@ -644,7 +648,7 @@ class TestVarBaseSetitem(unittest.TestCase):
self.assertEqual(self.tensor_x.inplace_version, 1) self.assertEqual(self.tensor_x.inplace_version, 1)
if isinstance(value, (six.integer_types, float)): if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(np.float32) + value result = np.zeros((2, 3)).astype(self.dtype) + value
else: else:
result = self.np_value result = self.np_value
...@@ -674,11 +678,26 @@ class TestVarBaseSetitem(unittest.TestCase): ...@@ -674,11 +678,26 @@ class TestVarBaseSetitem(unittest.TestCase):
paddle.disable_static() paddle.disable_static()
self._test(10) self._test(10)
class TestVarBaseSetitemInt64(TestVarBaseSetitem):
def set_dtype(self):
self.dtype = "int64"
class TestVarBaseSetitemFp32(TestVarBaseSetitem):
def set_dtype(self):
self.dtype = "float32"
def test_value_float(self): def test_value_float(self):
paddle.disable_static() paddle.disable_static()
self._test(3.3) self._test(3.3)
class TestVarBaseSetitemFp64(TestVarBaseSetitem):
def set_dtype(self):
self.dtype = "float64"
class TestVarBaseInplaceVersion(unittest.TestCase): class TestVarBaseInplaceVersion(unittest.TestCase):
def test_setitem(self): def test_setitem(self):
paddle.disable_static() paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册