未验证 提交 4b472656 编写于 作者: A Aurelius84 提交者: GitHub

[BUG]Fix expand_as_v2 bug while X and Y with different dtype (#46950) (#46999)

* [BUG]Fix expand_as_v2 bug while X and Y with different dtype

* fix commit
上级 535d7574
...@@ -26,6 +26,13 @@ using framework::Tensor; ...@@ -26,6 +26,13 @@ using framework::Tensor;
class ExpandAsV2Op : public framework::OperatorWithKernel { class ExpandAsV2Op : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
}; };
class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker { class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -83,6 +83,24 @@ class TestExpandAsOpRank4(TestExpandAsBasic): ...@@ -83,6 +83,24 @@ class TestExpandAsOpRank4(TestExpandAsBasic):
self.outputs = {'Out': output} self.outputs = {'Out': output}
class TestExpandAsOpRank5(TestExpandAsBasic):
no_need_check_grad = True
def setUp(self):
self.op_type = "expand_as_v2"
self.python_api = paddle.expand_as
x = np.random.rand(1, 1, 7, 16).astype("int64")
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
self.inputs = {'X': x, "Y": target_tensor}
self.attrs = {'target_shape': target_tensor.shape}
bcast_dims = [4, 6, 1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_grad(self):
pass
class TestExpandAsV2Error(unittest.TestCase): class TestExpandAsV2Error(unittest.TestCase):
def test_errors(self): def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册