未验证 提交 97a68ad2 编写于 作者: A Aurelius84 提交者: GitHub

[BUG]Fix expand_as_v2 bug while X and Y with different dtype (#46950)

* [BUG]Fix expand_as_v2 bug while X and Y with different dtype

* fix commit
上级 a95b6f33
......@@ -24,6 +24,13 @@ namespace operators {
class ExpandAsV2Op : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class ExpandAsV2OpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -81,6 +81,24 @@ class TestExpandAsOpRank4(TestExpandAsBasic):
self.outputs = {'Out': output}
class TestExpandAsOpRank5(TestExpandAsBasic):
no_need_check_grad = True
def setUp(self):
self.op_type = "expand_as_v2"
self.python_api = paddle.expand_as
x = np.random.rand(1, 1, 7, 16).astype("int64")
target_tensor = np.random.rand(4, 6, 7, 16).astype("float64")
self.inputs = {'X': x, "Y": target_tensor}
self.attrs = {'target_shape': target_tensor.shape}
bcast_dims = [4, 6, 1, 1]
output = np.tile(self.inputs['X'], bcast_dims)
self.outputs = {'Out': output}
def test_check_grad(self):
pass
class TestExpandAsV2Error(unittest.TestCase):
def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册