未验证 提交 cc511f24 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

support ONEDNN 0D for full_kernel (#51265)

上级 c7251b96
......@@ -69,6 +69,9 @@ inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size,
inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) {
switch (tensor_rank) {
case 0:
// use 1D to represent 0D
return dnnl::memory::format_tag::a;
case 1:
return dnnl::memory::format_tag::a;
case 2:
......
......@@ -30,7 +30,7 @@ void TransposeGradKernel(const Context& dev_ctx,
const auto& onednn_engine = dev_ctx.GetEngine();
if (axis.size() == 1) {
if (axis.size() == 1 || axis.size() == 0) {
Copy<Context>(dev_ctx, out_grad, out_grad.place(), false, x_grad);
x_grad->set_mem_desc(out_grad.mem_desc());
return;
......
......@@ -112,7 +112,7 @@ void TransposeKernel(const Context& dev_ctx,
SetInMemDescWithLogicalLayoutFusesSupport(
dev_ctx, const_cast<DenseTensor*>(&x), x.mem_desc());
if (axis.size() == 1) {
if (axis.size() == 1 || axis.size() == 0) {
Copy<Context>(dev_ctx, x, x.place(), false, out);
out->set_mem_desc(x.mem_desc());
return;
......
......@@ -87,5 +87,11 @@ class TestCase4(TestTransposeMKLDNN):
self.axis = (4, 2, 3, 1, 0, 5)
class TestCase_ZeroDim(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = ()
self.axis = ()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册