From 4b472656010ca933f9f5e2bf4c77a9da8cd8558e Mon Sep 17 00:00:00 2001 From: Aurelius84 Date: Fri, 14 Oct 2022 15:52:10 +0800 Subject: [PATCH] [BUG]Fix expand_as_v2 bug while X and Y with different dtype (#46950) (#46999) * [BUG]Fix expand_as_v2 bug while X and Y with different dtype * fix commit --- paddle/fluid/operators/expand_as_v2_op.cc | 7 +++++++ .../tests/unittests/test_expand_as_v2_op.py | 18 ++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 6fcf301897f..7263192e139 100644 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -26,6 +26,13 @@ using framework::Tensor; class ExpandAsV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } }; class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py index f107fec1c4e..f2791e55d51 100755 --- a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py @@ -83,6 +83,24 @@ class TestExpandAsOpRank4(TestExpandAsBasic): self.outputs = {'Out': output} +class TestExpandAsOpRank5(TestExpandAsBasic): + no_need_check_grad = True + + def setUp(self): + self.op_type = "expand_as_v2" + self.python_api = paddle.expand_as + x = np.random.rand(1, 1, 7, 16).astype("int64") + target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") + self.inputs = {'X': x, "Y": target_tensor} + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [4, 6, 1, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_grad(self): + pass + + class TestExpandAsV2Error(unittest.TestCase): def test_errors(self): -- GitLab