提交 669191c9 编写于 作者: Y Yihua Xu

Implement conv3d with mkldnn library (test=develop)

上级 4f71a6ee
......@@ -52,10 +52,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
filter->format() != memory::format::format_undef,
"Wrong layout/format set for Filter tensor");
PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW");
PADDLE_ENFORCE(filter->dims().size() == 4,
"Filter must be with 4 dimensions, i.e. OIHW");
PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
"Input must be with 4 or 5dimensions, i.e. NCHW or NCDHW");
PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
"Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
if (bias) {
PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
bias->format() != memory::format::format_undef,
......@@ -71,9 +71,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
is_conv3d
? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
dilations[2] == 1
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
const T* input_data = input->data<T>();
......@@ -84,16 +88,31 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
if (g > 1) {
int o = weights_tz[0];
int i = weights_tz[1];
int h = weights_tz[2];
int w = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = h;
weights_tz[4] = w;
if (is_conv3d) {
int o = weights_tz[0];
int i = weights_tz[1];
int d = weights_tz[2];
int h = weights_tz[3];
int w = weights_tz[4];
weights_tz.resize(6);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = d;
weights_tz[4] = h;
weights_tz[5] = w;
} else {
int o = weights_tz[0];
int i = weights_tz[1];
int h = weights_tz[2];
int w = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = h;
weights_tz[4] = w;
}
}
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
......@@ -105,11 +124,19 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<primitive> pipeline;
auto src_format = input->format();
mkldnn::memory::format weights_format =
(g == 1) ? filter->format() : mkldnn::memory::format::goihw;
if (is_conv3d) {
weights_format =
(g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
}
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(),
(g == 1) ? filter->format() : mkldnn::memory::format::goihw);
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
......@@ -119,10 +146,20 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
weights_format =
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
weights_format =
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
}
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
std::vector<int> bias_tz; // TODO(mgallus): avoid empty vector creation.
// Currently used whenever bias is != nullptr.
auto dst_md = platform::MKLDNNMemDesc(
......@@ -263,8 +300,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
mkldnn::prop_kind fwd_prop_kind) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
memory::dims stride_dims = strides;
memory::dims padding_dims = paddings;
auto conv_desc = mkldnn::convolution_forward::desc(
fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
......@@ -288,8 +325,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
mkldnn::prop_kind fwd_prop_kind) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
memory::dims stride_dims = strides;
memory::dims padding_dims = paddings;
auto conv_desc = mkldnn::convolution_forward::desc(
fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
......@@ -349,6 +386,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const T* output_grad_data = output_grad->data<T>();
......@@ -358,8 +396,45 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1);
if (g > 1) {
if (is_conv3d) {
int o = weights_tz[0];
int i = weights_tz[1];
int d = weights_tz[2];
int h = weights_tz[3];
int w = weights_tz[4];
weights_tz.resize(6);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = d;
weights_tz[4] = h;
weights_tz[5] = w;
} else {
int o = weights_tz[0];
int i = weights_tz[1];
int h = weights_tz[2];
int w = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = g;
weights_tz[1] = o / g;
weights_tz[2] = i;
weights_tz[3] = h;
weights_tz[4] = w;
}
}
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
auto src_format = input->format();
mkldnn::memory::format weights_format =
(g == 1) ? filter->format() : mkldnn::memory::format::goihw;
if (is_conv3d) {
weights_format =
(g == 1) ? filter->format() : mkldnn::memory::format::goidhw;
}
// Get an unique name from "argument" name of "Output" variable
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
......@@ -372,9 +447,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Create user memory descriptors
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), filter->format());
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
auto user_diff_dst_md = platform::MKLDNNMemDesc(
{dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
......@@ -386,14 +461,24 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
weights_format =
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw;
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
weights_format =
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goidhw;
}
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto diff_src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......@@ -496,3 +581,9 @@ REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNGradOpKernel<float>);
REGISTER_OP_KERNEL(conv3d, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(conv3d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNGradOpKernel<float>);
......@@ -229,6 +229,10 @@ $$
}
void Conv3DOpMaker::Make() {
AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false "
"for training. Some layers may run faster when this is true.")
.SetDefault(false);
AddInput(
"Input",
"(Tensor) The input tensor of convolution operator. "
......@@ -247,6 +251,11 @@ void Conv3DOpMaker::Make() {
AddOutput("Output",
"(Tensor) The output tensor of convolution operator."
"The format of output tensor is also NCDHW.");
AddInput("ResidualData",
"(Tensor) Tensor with residual data "
"to which convolution output will be added."
"Used with fuse_residual_connection fusion.")
.AsDispensable();
AddAttr<std::vector<int>>("strides",
"(vector<int>, default:{1, 1, 1}), the "
"strides(d_stride, h_stride, w_stride) of "
......@@ -277,6 +286,13 @@ void Conv3DOpMaker::Make() {
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
.SetDefault(false);
AddAttr<bool>("fuse_residual_connection",
"(bool, default false) Only used in mkldnn kernel. Used "
"whenever convolution output is as an input to residual "
"connection.")
.SetDefault(false);
AddAttr<std::string>(
"data_format",
"(string, default NCHW) Only used in "
......
......@@ -113,6 +113,18 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
return mkldnn::memory::format::x;
} else if (dims_size == 2) {
return mkldnn::memory::format::nc;
} else if (dims_size == 3) {
if (data_format == mkldnn::memory::format::nchw) {
return mkldnn::memory::format::ncw;
} else if (data_format == mkldnn::memory::format::nhwc) {
return mkldnn::memory::format::nwc;
}
} else if (dims_size == 5) {
if (data_format == mkldnn::memory::format::nchw) {
return mkldnn::memory::format::ncdhw;
} else if (data_format == mkldnn::memory::format::nhwc) {
return mkldnn::memory::format::ndhwc;
}
}
return data_format;
}
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1
class TestMKLDNN(TestConv3dOp):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNCase1(TestCase1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNGroup1(TestWithGroup1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNGroup2(TestWithGroup2):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWith1x1(TestWith1x1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
def init_kernel_type(self):
self.use_mkldnn = True
self.data_format = "NCHW"
if __name__ == '__main__':
unittest.main()
......@@ -74,6 +74,8 @@ class TestConv3dOp(OpTest):
def setUp(self):
self.op_type = "conv3d"
self.use_cudnn = False
self.use_mkldnn = False
self.data_format = "AnyLayout"
self.dtype = np.float32
self.init_kernel_type()
self.init_group()
......@@ -83,8 +85,7 @@ class TestConv3dOp(OpTest):
conv3d_param = {
'stride': self.stride,
'pad': self.pad,
'dilations': self.dilations,
'data_format': 'AnyLayout' # TODO(dzhwinter) : should be fix latter
'dilations': self.dilations
}
input = np.random.random(self.input_size).astype(self.dtype)
......@@ -101,7 +102,9 @@ class TestConv3dOp(OpTest):
'paddings': self.pad,
'groups': self.groups,
'dilations': self.dilations,
'use_cudnn': self.use_cudnn
'use_cudnn': self.use_cudnn,
'use_mkldnn': self.use_mkldnn,
'data_format': self.data_format
}
self.outputs = {'Output': output}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册