Unverified commit 5a9ae411, authored by Tao Luo, committed by GitHub

Merge pull request #12618 from sfraczek/sfraczek/fix-new-mkldnn-conv-tests

fix UT for mkldnn 0.15
@@ -280,12 +280,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      * ('any') which lets a primitive (convolution in this case) choose
      * the memory format preferred for best performance
      */
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    auto chosen_memory_format =
+        platform::data_format_to_memory_format(data_format);
+
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
     // create a conv primitive descriptor and save it for usage in backward
     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
@@ -423,16 +427,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
      * ('any') which lets a primitive (conv backward in this case) choose
      * the memory format preferred for best performance
      */
+    std::string data_format = ctx.Attr<std::string>("data_format");
+    auto chosen_memory_format =
+        platform::data_format_to_memory_format(data_format);
+
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), memory::format::any);
+        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
 
     // Retrieve conv_pd from device context
     auto conv_pd =
......
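For context on the change above, here is a minimal sketch written against the MKL-DNN 0.x C++ API, with made-up tensor sizes (the dims and this standalone program are illustrative assumptions, not part of the patch). memory::format::any leaves the layout for the convolution primitive descriptor to resolve at creation time, while an explicit format pins the layout up front; passing chosen_memory_format does the latter whenever the data_format attribute names a concrete layout.

#include <mkldnn.hpp>

int main() {
  using namespace mkldnn;
  // Hypothetical NCHW sizes: batch 2, 3 channels, 32x32 spatial.
  memory::dims src_tz = {2, 3, 32, 32};

  // 'any' defers the choice: a convolution primitive descriptor built
  // from this desc picks whatever layout it considers fastest.
  auto md_any = memory::desc(src_tz, memory::data_type::f32,
                             memory::format::any);

  // An explicit format fixes the layout up front; this is the effect of
  // passing chosen_memory_format instead of memory::format::any.
  auto md_nchw = memory::desc(src_tz, memory::data_type::f32,
                              memory::format::nchw);

  (void)md_any;
  (void)md_nchw;
  return 0;
}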
@@ -223,7 +223,7 @@ class MKLDNNHandler {
   static std::string GetHash(mkldnn::memory::dims& operand_dims,  // NOLINT
                              const std::string& suffix) {
     return dims2str(operand_dims) + suffix;
-  };
+  }
 
  protected:
   static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
@@ -251,5 +251,17 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
   return data_format;
 }
 
+inline mkldnn::memory::format data_format_to_memory_format(
+    const std::string& data_format) {
+  switch (framework::StringToDataLayout(data_format)) {
+    case framework::DataLayout::kNHWC:
+      return mkldnn::memory::format::nhwc;
+    case framework::DataLayout::kNCHW:
+      return mkldnn::memory::format::nchw;
+    default:
+      return mkldnn::memory::format::any;
+  }
+}
+
 }  // namespace platform
 }  // namespace paddle
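To make the new helper's behaviour concrete, here is a small self-contained sketch in plain C++ with stand-in names invented for illustration (the real helper dispatches through framework::StringToDataLayout rather than comparing strings): NHWC and NCHW map to fixed MKL-DNN formats, while any other layout, including the "AnyLayout" default the tests below set, falls through to the default case and yields 'any'.

#include <cassert>
#include <string>

// Stand-in for mkldnn::memory::format, for illustration only.
enum class MemoryFormat { nchw, nhwc, any };

// Mirrors data_format_to_memory_format above; the real helper parses the
// string into a framework::DataLayout first and switches on that.
MemoryFormat ToMemoryFormat(const std::string& data_format) {
  if (data_format == "NHWC") return MemoryFormat::nhwc;
  if (data_format == "NCHW") return MemoryFormat::nchw;
  return MemoryFormat::any;  // e.g. the tests' "AnyLayout" default
}

int main() {
  assert(ToMemoryFormat("NCHW") == MemoryFormat::nchw);
  assert(ToMemoryFormat("NHWC") == MemoryFormat::nhwc);
  assert(ToMemoryFormat("AnyLayout") == MemoryFormat::any);
  return 0;
}

Falling back to 'any' in the default case keeps the pre-patch behaviour for callers that do not request a specific layout.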
@@ -20,16 +20,19 @@ from test_conv2d_op import TestConv2dOp, TestWithPad, TestWithStride
 class TestMKLDNN(TestConv2dOp):
     def init_kernel_type(self):
         self.use_mkldnn = True
+        self.data_format = "NCHW"
 
 
 class TestMKLDNNWithPad(TestWithPad):
     def init_kernel_type(self):
         self.use_mkldnn = True
+        self.data_format = "NCHW"
 
 
 class TestMKLDNNWithStride(TestWithStride):
     def init_kernel_type(self):
         self.use_mkldnn = True
+        self.data_format = "NCHW"
 
 
 if __name__ == '__main__':
......
@@ -66,6 +66,7 @@ class TestConv2dOp(OpTest):
         self.op_type = "conv2d"
         self.use_cudnn = False
         self.use_mkldnn = False
+        self.data_format = "AnyLayout"
         self.dtype = np.float32
         self.init_kernel_type()
         self.init_group()
@@ -93,7 +94,8 @@ class TestConv2dOp(OpTest):
             'groups': self.groups,
             'dilations': self.dilations,
             'use_cudnn': self.use_cudnn,
-            'use_mkldnn': self.use_mkldnn
+            'use_mkldnn': self.use_mkldnn,
+            'data_format': self.data_format
         }
 
         self.outputs = {'Output': output}
@@ -101,59 +103,35 @@ class TestConv2dOp(OpTest):
         return core.is_compiled_with_cuda() and self.use_cudnn
 
     def test_check_output(self):
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-5)
-        else:
-            self.check_output()
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_output_with_place(place, atol=1e-5)
 
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place,
-                set(['Input', 'Filter']),
-                'Output',
-                max_relative_error=0.02)
-        else:
-            self.check_grad(
-                set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
 
     def test_check_grad_no_filter(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ['Input'],
-                'Output',
-                max_relative_error=0.02,
-                no_grad_set=set(['Filter']))
-        else:
-            self.check_grad(
-                ['Input'],
-                'Output',
-                max_relative_error=0.02,
-                no_grad_set=set(['Filter']))
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['Input'],
+            'Output',
+            max_relative_error=0.02,
+            no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ['Filter'],
-                'Output',
-                max_relative_error=0.02,
-                no_grad_set=set(['Input']))
-        else:
-            self.check_grad(
-                ['Filter'],
-                'Output',
-                max_relative_error=0.02,
-                no_grad_set=set(['Input']))
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['Filter'],
+            'Output',
+            max_relative_error=0.02,
+            no_grad_set=set(['Input']))
 
     def init_test_case(self):
         self.pad = [0, 0]
......