diff --git a/paddle/fluid/operators/conv_mkldnn_op.cc b/paddle/fluid/operators/conv_mkldnn_op.cc
index 786bdb10a28977c448f5ca9b015262973d363ac1..154ff2bb209bb8f932c06caa319223ccf3314767 100644
--- a/paddle/fluid/operators/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/conv_mkldnn_op.cc
@@ -28,6 +28,46 @@ using mkldnn::stream;
 using platform::to_void_cast;
 using platform::GetMKLDNNFormat;
 
+inline void GetWeightsTz(std::vector<int>& weights_tz, int groups,  // NOLINT
+                         bool is_conv3d) {
+  if (groups > 1) {
+    if (is_conv3d) {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int dimension = weights_tz[2];
+      int height = weights_tz[3];
+      int width = weights_tz[4];
+      weights_tz.resize(6);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = dimension;
+      weights_tz[4] = height;
+      weights_tz[5] = width;
+    } else {
+      int output = weights_tz[0];
+      int input = weights_tz[1];
+      int height = weights_tz[2];
+      int width = weights_tz[3];
+      weights_tz.resize(5);
+      weights_tz[0] = groups;
+      weights_tz[1] = output / groups;
+      weights_tz[2] = input;
+      weights_tz[3] = height;
+      weights_tz[4] = width;
+    }
+  }
+}
+
+inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
+                                               int groups, bool is_conv3d) {
+  if (is_conv3d) {
+    return (groups == 1) ? format : mkldnn::memory::format::goidhw;
+  } else {
+    return (groups == 1) ? format : mkldnn::memory::format::goihw;
+  }
+}
+
 template <typename T>
 class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
@@ -53,7 +93,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        filter->format() != memory::format::format_undef,
                    "Wrong layout/format set for Filter tensor");
     PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
-                   "Input must be with 4 or 5dimensions, i.e. NCHW or NCDHW");
+                   "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
     PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
                    "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
     if (bias) {
@@ -87,33 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
     int g = std::max(groups, 1);
-    if (g > 1) {
-      if (is_conv3d) {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int d = weights_tz[2];
-        int h = weights_tz[3];
-        int w = weights_tz[4];
-        weights_tz.resize(6);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = d;
-        weights_tz[4] = h;
-        weights_tz[5] = w;
-      } else {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int h = weights_tz[2];
-        int w = weights_tz[3];
-        weights_tz.resize(5);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = h;
-        weights_tz[4] = w;
-      }
-    }
+    GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
     // Get unique name for storing MKLDNN primitives
@@ -126,12 +140,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     auto src_format = input->format();
     mkldnn::memory::format weights_format =
-        (g == 1) ? filter->format() : mkldnn::memory::format::goihw;
-
-    if (is_conv3d) {
-      weights_format =
-          (g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
-    }
+        GetWeightsFormat(filter->format(), g, is_conv3d);
 
     auto user_src_md = platform::MKLDNNMemDesc(
         {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
@@ -146,15 +155,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
 
-    weights_format =
-        (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
-
     if (is_conv3d) {
       chosen_memory_format =
           platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
-      weights_format =
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
     }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
 
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
@@ -397,43 +402,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> weights_tz =
         paddle::framework::vectorize2int(filter->dims());
     int g = std::max(groups, 1);
-    if (g > 1) {
-      if (is_conv3d) {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int d = weights_tz[2];
-        int h = weights_tz[3];
-        int w = weights_tz[4];
-        weights_tz.resize(6);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = d;
-        weights_tz[4] = h;
-        weights_tz[5] = w;
-      } else {
-        int o = weights_tz[0];
-        int i = weights_tz[1];
-        int h = weights_tz[2];
-        int w = weights_tz[3];
-        weights_tz.resize(5);
-        weights_tz[0] = g;
-        weights_tz[1] = o / g;
-        weights_tz[2] = i;
-        weights_tz[3] = h;
-        weights_tz[4] = w;
-      }
-    }
+    GetWeightsTz(weights_tz, g, is_conv3d);
     std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
 
     auto src_format = input->format();
     mkldnn::memory::format weights_format =
-        (g == 1) ? filter->format() : mkldnn::memory::format::goihw;
-
-    if (is_conv3d) {
-      weights_format =
-          (g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
-    }
+        GetWeightsFormat(filter->format(), g, is_conv3d);
 
     // Get an unique name from "argument" name of "Output" variable
     // as well as attributes of primitive to be created
@@ -461,15 +435,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto chosen_memory_format =
         platform::data_format_to_memory_format(data_format);
 
-    weights_format =
-        (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
-
     if (is_conv3d) {
       chosen_memory_format =
           platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
-      weights_format =
-          (g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
     }
+    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);
 
     auto src_md = platform::MKLDNNMemDesc(
         src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
diff --git a/paddle/fluid/operators/conv_op.cc b/paddle/fluid/operators/conv_op.cc
index 9abba5f512a3fff8a1c16885ff76b7c8b477adb1..d7b876628855b8b76b340cd1e6115896ead4aa6c 100644
--- a/paddle/fluid/operators/conv_op.cc
+++ b/paddle/fluid/operators/conv_op.cc
@@ -134,14 +134,14 @@ void Conv2DOpMaker::Make() {
            "The format of output tensor is X (one-dimensional) of size equal"
            "to the number of output channels. Only used with MKL-DNN.")
       .AsDispensable();
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution operator. "
-            "The format of output tensor is also NCHW.");
   AddInput("ResidualData",
            "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
+  AddOutput("Output",
+            "(Tensor) The output tensor of convolution operator. "
+            "The format of output tensor is also NCHW.");
   AddAttr<std::vector<int>>("strides",
                             "(vector<int> default:{1, 1}), the "
                             "strides(h_stride, w_stride) of "
@@ -251,14 +251,14 @@ void Conv3DOpMaker::Make() {
            "is the width of the filter."
            "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
-  AddOutput("Output",
-            "(Tensor) The output tensor of convolution operator."
-            "The format of output tensor is also NCDHW.");
   AddInput("ResidualData",
            "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
+  AddOutput("Output",
+            "(Tensor) The output tensor of convolution operator."
+            "The format of output tensor is also NCDHW.");
   AddAttr<std::vector<int>>("strides",
                             "(vector<int>, default:{1, 1, 1}), the "
                             "strides(d_stride, h_stride, w_stride) of "