From 97a68ad20effe65cd7f9a3323913dfdda2f7396b Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Thu, 13 Oct 2022 13:54:15 +0800 Subject: [PATCH] [BUG]Fix expand_as_v2 bug while X and Y with different dtype (#46950) * [BUG]Fix expand_as_v2 bug while X and Y with different dtype * fix commit --- paddle/fluid/operators/expand_as_v2_op.cc | 7 +++++++ .../tests/unittests/test_expand_as_v2_op.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 772ef09219..09dc0f68cc 100644 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -24,6 +24,13 @@ namespace operators { class ExpandAsV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } }; class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py index 3c06f0ca47..e0506f8eb5 100755 --- a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py @@ -81,6 +81,24 @@ class TestExpandAsOpRank4(TestExpandAsBasic): self.outputs = {'Out': output} +class TestExpandAsOpRank5(TestExpandAsBasic): + no_need_check_grad = True + + def setUp(self): + self.op_type = "expand_as_v2" + self.python_api = paddle.expand_as + x = np.random.rand(1, 1, 7, 16).astype("int64") + target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") + self.inputs = {'X': x, "Y": target_tensor} + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [4, 6, 1, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_grad(self): + pass + + class TestExpandAsV2Error(unittest.TestCase): def test_errors(self): -- GitLab