未验证 提交 d93ee063 编写于 作者: S Sławomir Siwek 提交者: GitHub

Add new unittests for gIOHW format in conv_transpose_mkldnn_op (#37344)

* Add new unittests

* Replace I with O channel for filter groups

* Undo changes affecting other operators

* Fix oneDNN namespace typo

* Fix code format error
上级 c8ffdecb
...@@ -26,14 +26,12 @@ using Tensor = framework::Tensor; ...@@ -26,14 +26,12 @@ using Tensor = framework::Tensor;
using framework::DataLayout; using framework::DataLayout;
inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) { inline dnnl::memory::dims GetWeightsTz(const Tensor* filter, const int groups) {
auto iohw_weights_tz = framework::vectorize(filter->dims()); auto weights_tz = framework::vectorize(filter->dims());
auto weights_tz = iohw_weights_tz;
// IOHW -> OIHW
weights_tz[0] = iohw_weights_tz[1];
weights_tz[1] = iohw_weights_tz[0];
int g = std::max(groups, 1); int g = std::max(groups, 1);
int g_dim = (g > 1) ? 1 : 0;
platform::GetGroupConvWeightsTz(weights_tz, g); platform::GetGroupConvWeightsTz(weights_tz, g);
// gIOHW -> gOIHW || IOHW -> OIHW
std::swap(weights_tz[g_dim + 0], weights_tz[g_dim + 1]);
return weights_tz; return weights_tz;
} }
......
...@@ -154,6 +154,27 @@ class TestConv2DTransposeMKLDNNWithDilationsExplicitPad( ...@@ -154,6 +154,27 @@ class TestConv2DTransposeMKLDNNWithDilationsExplicitPad(
self.padding_algorithm = "EXPLICIT" self.padding_algorithm = "EXPLICIT"
class TestMKLDNNWithGroups(TestConv2DTransposeMKLDNNOp):
def init_test_case(self):
TestConv2DTransposeMKLDNNOp.init_test_case(self)
self.pad = [1, 1]
self.groups = 2
self.input_size = [2, 4, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 3, 3, 3]
class TestMKLDNNWithGroups_NHWC(TestConv2DTransposeMKLDNNOp):
def init_test_case(self):
TestConv2DTransposeMKLDNNOp.init_test_case(self)
self.pad = [1, 1]
self.groups = 2
self.input_size = [2, 5, 5, 4] # NHWC
f_c = self.input_size[-1]
self.filter_size = [f_c, 3, 3, 3]
self.data_format = 'NHWC'
if __name__ == '__main__': if __name__ == '__main__':
enable_static() enable_static()
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册