未验证 提交 cef601bf 编写于 作者: Y yeliang2258 提交者: GitHub

Fix ONEDNN squeeze kernel when Xshape does not exist (#49071)

* fix onednn squeeze bug

* add test

* update kernel
上级 3d5d1173
......@@ -68,9 +68,13 @@ void SqueezeWithXShapeKernel(const Context& dev_ctx,
const IntArray& axes,
DenseTensor* out,
DenseTensor* xshape) {
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
auto out_dims = out->dims();
ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
if (xshape == nullptr) {
SqueezeKernel<T, Context>(dev_ctx, x, axes, out);
} else {
auto x_dims = slice_ddim(xshape->dims(), 1, xshape->dims().size());
auto out_dims = out->dims();
ExecuteSqueeze<T, Context>(dev_ctx, x, x_dims, out_dims, out);
}
}
} // namespace phi
......
......@@ -111,6 +111,16 @@ class TestSqueeze2OneDNNOp3(TestSqueeze2OneDNNOp):
self.new_shape = (25, 1, 4)
class TestSqueeze2OneDNNOp4(TestSqueeze2OneDNNOp):
def set_outputs(self):
self.outputs = {"Out": self.x.reshape(self.new_shape)}
def init_test_case(self):
self.ori_shape = (25, 1, 1, 4, 1)
self.axes = (1, -1)
self.new_shape = (25, 1, 4)
class TestSqueezeOneDNNOp3(TestSqueezeOneDNNOp):
def init_test_case(self):
self.ori_shape = (25, 1, 1, 4, 1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册