未验证 提交 c0a604e7 编写于 作者: Y YangQun 提交者: GitHub

[Zero-Dim] support 0d tensor for shape and squeeze onednn kernel (#52832)

* support 0d tensor for shape and squeeze onednn kernel

* set python api for shape op ut
上级 a0aff194
...@@ -89,8 +89,9 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> { ...@@ -89,8 +89,9 @@ class ReshapeMKLDNNKernel : public framework::OpKernel<T> {
astream.wait(); astream.wait();
out->Resize(out_dims); out->Resize(out_dims);
out->set_mem_desc( auto reshape_dims = out_dims.size() != 0 ? phi::vectorize(out_dims)
reorder_dst_memory_p->get_desc().reshape(phi::vectorize(out_dims))); : std::vector<int64_t>{1};
out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(reshape_dims));
} }
void InferInOutShape(const framework::ExecutionContext& ctx, void InferInOutShape(const framework::ExecutionContext& ctx,
......
...@@ -25,7 +25,9 @@ void SqueezeGradKernel(const Context& dev_ctx, ...@@ -25,7 +25,9 @@ void SqueezeGradKernel(const Context& dev_ctx,
const DenseTensor& dout, const DenseTensor& dout,
const IntArray& axes, const IntArray& axes,
DenseTensor* dx) { DenseTensor* dx) {
auto dout_vec_dims = vectorize(dout.dims()); auto dout_vec_dims = dout.dims().size() != 0 ? vectorize(dout.dims())
: std::vector<int64_t>{1};
auto dout_type = funcs::ToOneDNNDataType(dout.dtype()); auto dout_type = funcs::ToOneDNNDataType(dout.dtype());
funcs::ReorderOneDNNHandler reorder_handler( funcs::ReorderOneDNNHandler reorder_handler(
......
...@@ -47,8 +47,10 @@ void ExecuteSqueeze(const Context& dev_ctx, ...@@ -47,8 +47,10 @@ void ExecuteSqueeze(const Context& dev_ctx,
astream.wait(); astream.wait();
out->Resize(out_dims); out->Resize(out_dims);
out->set_mem_desc(
reorder_dst_memory_p->get_desc().reshape(vectorize(out_dims))); auto reshape_dims =
out_dims.size() != 0 ? vectorize(out_dims) : std::vector<int64_t>{1};
out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(reshape_dims));
} }
template <typename T, typename Context> template <typename T, typename Context>
......
...@@ -31,4 +31,6 @@ NEED_TO_FIX_OP_LIST = [ ...@@ -31,4 +31,6 @@ NEED_TO_FIX_OP_LIST = [
'multi_dot', 'multi_dot',
'index_add', 'index_add',
'reshape2', 'reshape2',
'squeeze',
'squeeze2',
] ]
...@@ -21,10 +21,10 @@ from paddle.fluid import core ...@@ -21,10 +21,10 @@ from paddle.fluid import core
from paddle.fluid.tests.unittests.eager_op_test import OpTest, OpTestTool from paddle.fluid.tests.unittests.eager_op_test import OpTest, OpTestTool
@OpTestTool.skip_if_not_cpu_bf16()
class TestShape3DFP32OneDNNOp(OpTest): class TestShape3DFP32OneDNNOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "shape" self.op_type = "shape"
self.python_api = paddle.tensor.shape
self.config() self.config()
self.attrs = {'use_mkldnn': True} self.attrs = {'use_mkldnn': True}
self.inputs = {'Input': np.zeros(self.shape).astype(self.dtype)} self.inputs = {'Input': np.zeros(self.shape).astype(self.dtype)}
...@@ -38,6 +38,13 @@ class TestShape3DFP32OneDNNOp(OpTest): ...@@ -38,6 +38,13 @@ class TestShape3DFP32OneDNNOp(OpTest):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
class TestShape0DFP32OneDNNOp(TestShape3DFP32OneDNNOp):
def config(self):
self.shape = []
self.dtype = np.float32
@OpTestTool.skip_if_not_cpu_bf16()
class TestShape6DBF16OneDNNOp(TestShape3DFP32OneDNNOp): class TestShape6DBF16OneDNNOp(TestShape3DFP32OneDNNOp):
def config(self): def config(self):
self.shape = [10, 2, 3, 4, 5, 2] self.shape = [10, 2, 3, 4, 5, 2]
......
...@@ -76,6 +76,20 @@ class TestSqueezeOneDNNOp(TestSqueeze2OneDNNOp): ...@@ -76,6 +76,20 @@ class TestSqueezeOneDNNOp(TestSqueeze2OneDNNOp):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
class TestSqueeze2OneDNNOp_ZeroDim(TestSqueeze2OneDNNOp):
def init_test_case(self):
self.ori_shape = [1]
self.axes = ()
self.new_shape = ()
class TestSqueezeOneDNNOp_ZeroDim(TestSqueezeOneDNNOp):
def init_test_case(self):
self.ori_shape = [1]
self.axes = ()
self.new_shape = ()
class TestSqueeze2OneDNNOp1(TestSqueeze2OneDNNOp): class TestSqueeze2OneDNNOp1(TestSqueeze2OneDNNOp):
def init_test_case(self): def init_test_case(self):
self.ori_shape = (1, 20, 1, 5) self.ori_shape = (1, 20, 1, 5)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册