diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index ce45dd58419ab20cccf00544288b79d869515578..154ff2bb209bb8f932c06caa319223ccf3314767 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -28,6 +28,46 @@ using mkldnn::stream;
 using platform::to_void_cast;
 using platform::GetMKLDNNFormat;
 
+inline void GetWeightsTz(std::vector<int>& weights_tz, int groups,  // NOLINT
+                         bool is_conv3d) {
+  if (groups > 1) {
+    if (is_conv3d) {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int dimension = weights_tz[2];
+      int height = weights_tz[3];
+      int width = weights_tz[4];
+      weights_tz.resize(6);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = dimension;
+      weights_tz[4] = height;
+      weights_tz[5] = width;
+    } else {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int height = weights_tz[2];
+      int width = weights_tz[3];
+      weights_tz.resize(5);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = height;
+      weights_tz[4] = width;
+    }
+  }
+}
+
+inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
+                                               int groups, bool is_conv3d) {
+  if (is_conv3d) {
+    return (groups == 1) ? format : mkldnn::memory::format::goidhw;
+  } else {
+    return (groups == 1) ? format : mkldnn::memory::format::goihw;
+  }
+}
+
 template <typename T>
 class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -52,10 +92,10 @@
     PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                        filter->format() != memory::format::format_undef,
                    "Wrong layout/format set for Filter tensor");
-    PADDLE_ENFORCE(input->dims().size() == 4,
-                   "Input must be with 4 dimensions, i.e. NCHW");
-    PADDLE_ENFORCE(filter->dims().size() == 4,
-                   "Filter must be with 4 dimensions, i.e. OIHW");
+    PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
+                   "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
+    PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
+                   "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
     if (bias) {
       PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
                          bias->format() != memory::format::format_undef,
@@ -71,9 +111,13 @@
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
     int groups = ctx.Attr<int>("groups");
 
+    bool is_conv3d = strides.size() == 3U;
     // TODO(tpatejko): add support for dilation
     PADDLE_ENFORCE(
-        dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
+        is_conv3d
+            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
+                  dilations[2] == 1
+            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
         "dilation in convolution is not implemented yet");
 
     const T* input_data = input->data<T>();
@@ -83,18 +127,7 @@
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
     int g = std::max(groups, 1);
-    if (g > 1) {
-      int o = weights_tz[0];
-      int i = weights_tz[1];
-      int h = weights_tz[2];
-      int w = weights_tz[3];
-      weights_tz.resize(5);
-      weights_tz[0] = g;
-      weights_tz[1] = o / g;
-      weights_tz[2] = i;
-      weights_tz[3] = h;
-      weights_tz[4] = w;
-    }
+    GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
     // Get unique name for storing MKLDNN primitives
@@ -105,11 +138,14 @@
 
     std::vector<primitive> pipeline;
 
+    auto src_format = input->format();
+    mkldnn::memory::format weights_format =
+        GetWeightsFormat(filter->format(), g, is_conv3d);
+
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
     auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(),
-        (g == 1) ? filter->format() : mkldnn::memory::format::goihw);
+        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
 
     /* create memory descriptor for convolution without specified format
      * ('any') which lets a primitive (convolution in this case) choose
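Note on `GetWeightsTz`: for grouped convolution it reshapes the filter dims from OIHW/OIDHW to GOIHW/GOIDHW by splitting the output-channel axis. Below is a minimal standalone sketch of the same reshaping, compacted into a single generic "prepend the group axis" step that is behaviorally equivalent to the branchy helper above; the dims and group count are made-up example values, not values from the patch:

```cpp
#include <iostream>
#include <vector>

// Behaviorally equivalent, compact replica of GetWeightsTz: prepend a group
// axis and divide the output-channel count by the number of groups.
void GetWeightsTzSketch(std::vector<int>* weights_tz, int groups) {
  if (groups > 1) {
    std::vector<int> grouped;
    grouped.push_back(groups);                     // G
    grouped.push_back((*weights_tz)[0] / groups);  // O / G
    for (std::size_t i = 1; i < weights_tz->size(); ++i) {
      grouped.push_back((*weights_tz)[i]);         // I, (D,) H, W unchanged
    }
    *weights_tz = grouped;
  }
}

int main() {
  // Hypothetical conv3d filter: O=64, I=16, D=H=W=3 (OIDHW), groups=2.
  std::vector<int> tz = {64, 16, 3, 3, 3};
  GetWeightsTzSketch(&tz, 2);
  for (int d : tz) std::cout << d << ' ';  // prints: 2 32 16 3 3 3
  std::cout << '\n';
  return 0;
}
```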
@@ -119,10 +155,16 @@
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
 
+    if (is_conv3d) {
+      chosen_memory_format =
+          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+    }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
+
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
     std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
                                // Currently used whenever bias is != nullptr.
     auto dst_md = platform::MKLDNNMemDesc(
@@ -263,8 +305,8 @@
                        const mkldnn::engine& engine, const bool fuse_relu,
                        const bool fuse_residual_conn,
                        mkldnn::prop_kind fwd_prop_kind) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
+    memory::dims stride_dims = strides;
+    memory::dims padding_dims = paddings;
 
     auto conv_desc = mkldnn::convolution_forward::desc(
         fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
@@ -288,8 +330,8 @@
                        const mkldnn::engine& engine, const bool fuse_relu,
                        const bool fuse_residual_conn,
                        mkldnn::prop_kind fwd_prop_kind) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
+    memory::dims stride_dims = strides;
+    memory::dims padding_dims = paddings;
 
     auto conv_desc = mkldnn::convolution_forward::desc(
         fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
@@ -349,6 +391,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     int groups = ctx.Attr<int>("groups");
 
+    bool is_conv3d = strides.size() == 3U;
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
     const T* output_grad_data = output_grad->data<T>();
@@ -358,8 +401,14 @@
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
+    int g = std::max(groups, 1);
+    GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
+    auto src_format = input->format();
+    mkldnn::memory::format weights_format =
+        GetWeightsFormat(filter->format(), g, is_conv3d);
+
     // Get an unique name from "argument" name of "Output" variable
     // as well as attributes of primitive to be created
     // This name will be used as key when saving info into device context
@@ -372,9 +421,9 @@
 
     // Create user memory descriptors
     auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
+        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
     auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
+        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
     auto user_diff_dst_md = platform::MKLDNNMemDesc(
         {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
@@ -386,14 +435,20 @@
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
 
+    if (is_conv3d) {
+      chosen_memory_format =
+          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+    }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
+
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto diff_src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
     auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
     auto diff_weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
     auto diff_dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
@@ -500,3 +555,13 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
                                     ::paddle::platform::CPUPlace, FP32,
                                     ops::kConvMKLDNNFP32,
                                     ops::ConvMKLDNNGradOpKernel<float>);
+
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
+                                    ::paddle::platform::CPUPlace, FP32,
+                                    ops::kConvMKLDNNFP32,
+                                    ops::ConvMKLDNNOpKernel<float>);
+
+REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
+                                    ::paddle::platform::CPUPlace, FP32,
+                                    ops::kConvMKLDNNFP32,
+                                    ops::ConvMKLDNNGradOpKernel<float>);
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 7455b9492f054b32ee7fb1fc90b1a344367ceb81..d7b876628855b8b76b340cd1e6115896ead4aa6c 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -134,14 +134,14 @@
           "The format of output tensor is X (one-dimensional) of size equal"
           "to the number of output channels. Only used with MKL-DNN.")
       .AsDispensable();
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution operator. "
-            "The format of output tensor is also NCHW.");
   AddInput("ResidualData",
           "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
+  AddOutput("Output",
+            "(Tensor) The output tensor of convolution operator. "
+            "The format of output tensor is also NCHW.");
   AddAttr<std::vector<int>>("strides",
                             "(vector<int> default:{1, 1}), the "
                             "strides(h_stride, w_stride) of "
@@ -232,6 +232,10 @@ $$
 }
 
 void Conv3DOpMaker::Make() {
+  AddAttr<bool>("is_test",
+                "(bool, default false) Set to true for inference only, false "
+                "for training. Some layers may run faster when this is true.")
+      .SetDefault(false);
   AddInput(
       "Input",
       "(Tensor) The input tensor of convolution operator. "
@@ -247,6 +251,11 @@
       "is the width of the filter."
      "If the groups attribute is greater than 1, C equals the number of "
      "input image channels divided by the groups.");
+  AddInput("ResidualData",
+           "(Tensor) Tensor with residual data "
+           "to which convolution output will be added."
+           "Used with fuse_residual_connection fusion.")
+      .AsDispensable();
   AddOutput("Output",
             "(Tensor) The output tensor of convolution operator."
             "The format of output tensor is also NCDHW.");
@@ -280,6 +289,13 @@
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
+  AddAttr<bool>("fuse_residual_connection",
+                "(bool, default false) Only used in mkldnn kernel. Used "
+                "whenever convolution output is as an input to residual "
+                "connection.")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h
index 167bd4e81d0ddbbba260417b460d083dbeb932b6..e53064893ee89f663a76483b92de32b318b6c61f 100644
--- a/paddle/fluid/platform/mkldnn_helper.h
+++ b/paddle/fluid/platform/mkldnn_helper.h
@@ -113,6 +113,18 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
     return mkldnn::memory::format::x;
   } else if (dims_size == 2) {
     return mkldnn::memory::format::nc;
+  } else if (dims_size == 3) {
+    if (data_format == mkldnn::memory::format::nchw) {
+      return mkldnn::memory::format::ncw;
+    } else if (data_format == mkldnn::memory::format::nhwc) {
+      return mkldnn::memory::format::nwc;
+    }
+  } else if (dims_size == 5) {
+    if (data_format == mkldnn::memory::format::nchw) {
+      return mkldnn::memory::format::ncdhw;
+    } else if (data_format == mkldnn::memory::format::nhwc) {
+      return mkldnn::memory::format::ndhwc;
+    }
   }
   return data_format;
 }
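The mkldnn_helper.h hunk above adds a dims-size fallback so a 4-d layout hint degrades gracefully for 3-d and 5-d tensors. A rough self-contained sketch of that mapping follows, using a local enum in place of mkldnn::memory::format so it compiles without the MKL-DNN headers; the enum and function names are stand-ins, not the library API:

```cpp
#include <cassert>
#include <cstddef>

// Local stand-ins for the mkldnn::memory::format tags referenced in the patch.
enum class Format { x, nc, ncw, nwc, nchw, nhwc, ncdhw, ndhwc };

// Mirrors the dims_size dispatch added to MKLDNNFormatForSize: a 4-d layout
// hint is remapped to its 3-d (ncw/nwc) or 5-d (ncdhw/ndhwc) counterpart.
Format FormatForSize(std::size_t dims_size, Format data_format) {
  if (dims_size == 1) return Format::x;
  if (dims_size == 2) return Format::nc;
  if (dims_size == 3) {
    if (data_format == Format::nchw) return Format::ncw;
    if (data_format == Format::nhwc) return Format::nwc;
  } else if (dims_size == 5) {
    if (data_format == Format::nchw) return Format::ncdhw;
    if (data_format == Format::nhwc) return Format::ndhwc;
  }
  return data_format;  // 4-d (and anything unrecognized) passes through
}

int main() {
  assert(FormatForSize(5, Format::nchw) == Format::ncdhw);  // conv3d input
  assert(FormatForSize(3, Format::nhwc) == Format::nwc);    // 1-d spatial case
  assert(FormatForSize(4, Format::nchw) == Format::nchw);   // conv2d unchanged
  return 0;
}
```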
diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0e1265e142b800587599783367eca2203033bf1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_conv3d_mkldnn_op.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+
+from test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1
+
+
+class TestMKLDNN(TestConv3dOp):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNCase1(TestCase1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNGroup1(TestWithGroup1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNGroup2(TestWithGroup2):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNWith1x1(TestWith1x1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.data_format = "NCHW"
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_conv3d_op.py b/python/paddle/fluid/tests/unittests/test_conv3d_op.py
index 69c5ab7a4a4cbd552d27dcb07052d46752eeb54a..c6b749fe09b18b1d704f45a5a5b3adbd5c6a6d0b 100644
--- a/python/paddle/fluid/tests/unittests/test_conv3d_op.py
+++ b/python/paddle/fluid/tests/unittests/test_conv3d_op.py
@@ -74,6 +74,8 @@
     def setUp(self):
         self.op_type = "conv3d"
         self.use_cudnn = False
+        self.use_mkldnn = False
+        self.data_format = "AnyLayout"
         self.dtype = np.float32
         self.init_kernel_type()
         self.init_group()
@@ -83,8 +85,7 @@
         conv3d_param = {
             'stride': self.stride,
             'pad': self.pad,
-            'dilations': self.dilations,
-            'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
+            'dilations': self.dilations
         }
 
         input = np.random.random(self.input_size).astype(self.dtype)
@@ -101,7 +102,9 @@
             'paddings': self.pad,
             'groups': self.groups,
             'dilations': self.dilations,
-            'use_cudnn': self.use_cudnn
+            'use_cudnn': self.use_cudnn,
+            'use_mkldnn': self.use_mkldnn,
+            'data_format': self.data_format
         }
         self.outputs = {'Output': output}
 
@@ -109,59 +112,35 @@
         return core.is_compiled_with_cuda() and self.use_cudnn
 
     def test_check_output(self):
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_output_with_place(place, atol=1e-5)
-        else:
-            self.check_output()
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_output_with_place(place, atol=1e-5)
 
     def test_check_grad(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place,
-                set(['Input', 'Filter']),
-                'Output',
-                max_relative_error=0.03)
-        else:
-            self.check_grad(
-                set(['Input', 'Filter']), 'Output', max_relative_error=0.03)
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03)
 
     def test_check_grad_no_filter(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ['Input'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Filter']))
-        else:
-            self.check_grad(
-                ['Input'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Filter']))
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['Input'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
         if self.dtype == np.float16:
             return
-        if self.testcudnn():
-            place = core.CUDAPlace(0)
-            self.check_grad_with_place(
-                place, ['Filter'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Input']))
-        else:
-            self.check_grad(
-                ['Filter'],
-                'Output',
-                max_relative_error=0.03,
-                no_grad_set=set(['Input']))
+        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
+        self.check_grad_with_place(
+            place, ['Filter'],
+            'Output',
+            max_relative_error=0.03,
+            no_grad_set=set(['Input']))
 
     def init_test_case(self):
         self.pad = [0, 0, 0]
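(Note: the rewritten test_check_grad_no_input must check the 'Filter' gradient while 'Input' sits in no_grad_set, as the removed lines show; passing ['Input'] there would contradict its own no_grad_set.)

The rewritten tests above rely on OpTest's finite-difference gradient checking with max_relative_error=0.03. A minimal sketch of that general technique for a scalar function, assuming nothing about OpTest's internals (the names and example function are illustrative only):

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <iostream>

// Central finite difference: f'(x) ~= (f(x + h) - f(x - h)) / (2h).
template <typename F>
double NumericGrad(F f, double x, double h = 1e-5) {
  return (f(x + h) - f(x - h)) / (2.0 * h);
}

// Relative error with a small floor so near-zero gradients do not divide by 0.
double RelativeError(double a, double b) {
  return std::fabs(a - b) / std::max({std::fabs(a), std::fabs(b), 1e-8});
}

int main() {
  auto f = [](double x) { return x * x * x; };  // f(x) = x^3
  const double x = 2.0;
  const double analytic = 3.0 * x * x;          // f'(x) = 3x^2 = 12
  const double numeric = NumericGrad(f, x);
  std::cout << "relative error: " << RelativeError(analytic, numeric) << "\n";
  assert(RelativeError(analytic, numeric) < 0.03);  // tolerance used by the tests
  return 0;
}
```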