提交 d74bb6ab 编写于 作者: S Sylwester Fraczek

fix ut for mkldnn 0.15 - added forcing layout NCHW in mkldnn conv tests

上级 c1446342
......@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
......@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto diff_src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// Retrieve conv_pd from device context
auto conv_pd =
......
......@@ -223,7 +223,7 @@ class MKLDNNHandler {
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
};
}
protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
......@@ -251,5 +251,17 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
return data_format;
}
inline mkldnn::memory::format data_format_to_memory_format(
const std::string& data_format) {
switch (framework::StringToDataLayout(data_format)) {
case framework::DataLayout::kNHWC:
return mkldnn::memory::format::nhwc;
case framework::DataLayout::kNCHW:
return mkldnn::memory::format::nchw;
default:
return mkldnn::memory::format::any;
}
}
} // namespace platform
} // namespace paddle
......@@ -20,16 +20,19 @@ from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride
class TestMKLDNN(TestConv2dOp):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWithPad(TestWithPad):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWithStride(TestWithStride):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
if __name__ == '__main__':
......
......@@ -66,6 +66,7 @@ class TestConv2dOp(OpTest):
self.op_type = "conv2d"
self.use_cudnn = False
self.use_mkldnn = False
self.data_format = "AnyLayout"
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
......@@ -93,7 +94,8 @@ class TestConv2dOp(OpTest):
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
}
self.outputs = {'Output': output}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册