未验证 提交 cc511f24 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

support ONEDNN 0D for full_kernel (#51265)

上级 c7251b96
...@@ -69,6 +69,9 @@ inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size, ...@@ -69,6 +69,9 @@ inline OneDNNMemoryFormat OneDNNFormatForSize(size_t dims_size,
inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) { inline dnnl::memory::format_tag GetPlainOneDNNFormat(int tensor_rank) {
switch (tensor_rank) { switch (tensor_rank) {
case 0:
// use 1D to represent 0D
return dnnl::memory::format_tag::a;
case 1: case 1:
return dnnl::memory::format_tag::a; return dnnl::memory::format_tag::a;
case 2: case 2:
......
...@@ -30,7 +30,7 @@ void TransposeGradKernel(const Context& dev_ctx, ...@@ -30,7 +30,7 @@ void TransposeGradKernel(const Context& dev_ctx,
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
if (axis.size() == 1) { if (axis.size() == 1 || axis.size() == 0) {
Copy<Context>(dev_ctx, out_grad, out_grad.place(), false, x_grad); Copy<Context>(dev_ctx, out_grad, out_grad.place(), false, x_grad);
x_grad->set_mem_desc(out_grad.mem_desc()); x_grad->set_mem_desc(out_grad.mem_desc());
return; return;
......
...@@ -112,7 +112,7 @@ void TransposeKernel(const Context& dev_ctx, ...@@ -112,7 +112,7 @@ void TransposeKernel(const Context& dev_ctx,
SetInMemDescWithLogicalLayoutFusesSupport( SetInMemDescWithLogicalLayoutFusesSupport(
dev_ctx, const_cast<DenseTensor*>(&x), x.mem_desc()); dev_ctx, const_cast<DenseTensor*>(&x), x.mem_desc());
if (axis.size() == 1) { if (axis.size() == 1 || axis.size() == 0) {
Copy<Context>(dev_ctx, x, x.place(), false, out); Copy<Context>(dev_ctx, x, x.place(), false, out);
out->set_mem_desc(x.mem_desc()); out->set_mem_desc(x.mem_desc());
return; return;
......
...@@ -87,5 +87,11 @@ class TestCase4(TestTransposeMKLDNN): ...@@ -87,5 +87,11 @@ class TestCase4(TestTransposeMKLDNN):
self.axis = (4, 2, 3, 1, 0, 5) self.axis = (4, 2, 3, 1, 0, 5)
class TestCase_ZeroDim(TestTransposeMKLDNN):
def initTestCase(self):
self.shape = ()
self.axis = ()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册