未验证 提交 0fff9306 编写于 作者: L liym27 提交者: GitHub

Fix bug for set_value op when input dtype is not float32 (#31411)

上级 c40b98e0
......@@ -57,8 +57,7 @@ class SetValue : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -174,14 +174,13 @@ class SetValueKernel : public framework::OpKernel<T> {
auto steps_tensor_list =
ctx.MultiInput<framework::Tensor>("StepsTensorList");
auto dtype =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
auto axes = ctx.Attr<std::vector<int64_t>>("axes");
auto starts = ctx.Attr<std::vector<int64_t>>("starts");
auto ends = ctx.Attr<std::vector<int64_t>>("ends");
auto steps = ctx.Attr<std::vector<int64_t>>("steps");
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
auto dtype = in->type();
if (!starts_tensor_list.empty()) {
starts = GetDataFromTensorList<int64_t>(starts_tensor_list);
}
......
......@@ -631,10 +631,14 @@ class TestVarBase(unittest.TestCase):
class TestVarBaseSetitem(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(np.float32))
self.np_value = np.random.random((2, 3)).astype(np.float32)
self.set_dtype()
self.tensor_x = paddle.to_tensor(np.ones((4, 2, 3)).astype(self.dtype))
self.np_value = np.random.random((2, 3)).astype(self.dtype)
self.tensor_value = paddle.to_tensor(self.np_value)
def set_dtype(self):
self.dtype = "int32"
def _test(self, value):
paddle.disable_static()
self.assertEqual(self.tensor_x.inplace_version, 0)
......@@ -644,7 +648,7 @@ class TestVarBaseSetitem(unittest.TestCase):
self.assertEqual(self.tensor_x.inplace_version, 1)
if isinstance(value, (six.integer_types, float)):
result = np.zeros((2, 3)).astype(np.float32) + value
result = np.zeros((2, 3)).astype(self.dtype) + value
else:
result = self.np_value
......@@ -674,11 +678,26 @@ class TestVarBaseSetitem(unittest.TestCase):
paddle.disable_static()
self._test(10)
class TestVarBaseSetitemInt64(TestVarBaseSetitem):
def set_dtype(self):
self.dtype = "int64"
class TestVarBaseSetitemFp32(TestVarBaseSetitem):
def set_dtype(self):
self.dtype = "float32"
def test_value_float(self):
paddle.disable_static()
self._test(3.3)
class TestVarBaseSetitemFp64(TestVarBaseSetitem):
def set_dtype(self):
self.dtype = "float64"
class TestVarBaseInplaceVersion(unittest.TestCase):
def test_setitem(self):
paddle.disable_static()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册