diff --git a/paddle/fluid/operators/expand_as_v2_op.cc b/paddle/fluid/operators/expand_as_v2_op.cc index 6fcf301897f29520d4e6e506ce19bb0c41098416..7263192e13924747d13a3d91327a653b0308fd10 100644 --- a/paddle/fluid/operators/expand_as_v2_op.cc +++ b/paddle/fluid/operators/expand_as_v2_op.cc @@ -26,6 +26,13 @@ using framework::Tensor; class ExpandAsV2Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } }; class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py index f107fec1c4e4eb038fb5b33ac07076ef521ad333..f2791e55d5188f3aab06e049508d1727e6cee09e 100755 --- a/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py +++ b/python/paddle/fluid/tests/unittests/test_expand_as_v2_op.py @@ -83,6 +83,24 @@ class TestExpandAsOpRank4(TestExpandAsBasic): self.outputs = {'Out': output} +class TestExpandAsOpRank5(TestExpandAsBasic): + no_need_check_grad = True + + def setUp(self): + self.op_type = "expand_as_v2" + self.python_api = paddle.expand_as + x = np.random.rand(1, 1, 7, 16).astype("int64") + target_tensor = np.random.rand(4, 6, 7, 16).astype("float64") + self.inputs = {'X': x, "Y": target_tensor} + self.attrs = {'target_shape': target_tensor.shape} + bcast_dims = [4, 6, 1, 1] + output = np.tile(self.inputs['X'], bcast_dims) + self.outputs = {'Out': output} + + def test_check_grad(self): + pass + + class TestExpandAsV2Error(unittest.TestCase): def test_errors(self):