Unverified · Commit 54fcafb5 authored by Tao Luo, committed by GitHub

Merge pull request #14707 from yihuaxu/develop_4f71a6ee_conv3d_mkldnn_opt

Implement conv3d with mkldnn library
@@ -28,6 +28,46 @@ using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
inline void GetWeightsTz(std::vector<int>& weights_tz, int groups,  // NOLINT
                         bool is_conv3d) {
  if (groups > 1) {
    if (is_conv3d) {
      int output = weights_tz[0];
      int input = weights_tz[1];
      int dimension = weights_tz[2];
      int height = weights_tz[3];
      int width = weights_tz[4];
      weights_tz.resize(6);
      weights_tz[0] = groups;
      weights_tz[1] = output / groups;
      weights_tz[2] = input;
      weights_tz[3] = dimension;
      weights_tz[4] = height;
      weights_tz[5] = width;
    } else {
      int output = weights_tz[0];
      int input = weights_tz[1];
      int height = weights_tz[2];
      int width = weights_tz[3];
      weights_tz.resize(5);
      weights_tz[0] = groups;
      weights_tz[1] = output / groups;
      weights_tz[2] = input;
      weights_tz[3] = height;
      weights_tz[4] = width;
    }
  }
}

inline mkldnn::memory::format GetWeightsFormat(mkldnn::memory::format format,
                                               int groups, bool is_conv3d) {
  if (is_conv3d) {
    return (groups == 1) ? format : mkldnn::memory::format::goidhw;
  } else {
    return (groups == 1) ? format : mkldnn::memory::format::goihw;
  }
}
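A minimal sketch (not part of the patch) of what the two helpers above do to a grouped 3-D filter. The wrapper function and the filter shape are illustrative assumptions; only GetWeightsTz, GetWeightsFormat and the MKL-DNN 0.x memory formats are taken from the code above.

// Illustrative sketch: assumes the two helpers above and <mkldnn.hpp> are in scope.
#include <cassert>
#include <vector>

void GroupedWeightsSketch() {
  // A hypothetical OIDHW filter: 64 output channels, 32 input channels, 3x3x3 kernel.
  std::vector<int> weights_tz = {64, 32, 3, 3, 3};
  int groups = 2;

  GetWeightsTz(weights_tz, groups, /*is_conv3d=*/true);
  // Reshaped to GOIDHW {2, 32, 32, 3, 3, 3}: the output channels are split per group.
  assert(weights_tz.size() == 6);
  assert(weights_tz[0] == 2 && weights_tz[1] == 32);

  // Grouped 3-D weights map to goidhw; with groups == 1 the original format is kept.
  auto fmt = GetWeightsFormat(mkldnn::memory::format::oidhw, groups, true);
  assert(fmt == mkldnn::memory::format::goidhw);
}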
template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 public:

@@ -52,10 +92,10 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    PADDLE_ENFORCE(filter->layout() == DataLayout::kMKLDNN &&
                       filter->format() != memory::format::format_undef,
                   "Wrong layout/format set for Filter tensor");
    PADDLE_ENFORCE(input->dims().size() == 4 || input->dims().size() == 5,
                   "Input must be with 4 or 5 dimensions, i.e. NCHW or NCDHW");
    PADDLE_ENFORCE(filter->dims().size() == 4 || filter->dims().size() == 5,
                   "Filter must be with 4 or 5 dimensions, i.e. OIHW or OIDHW");
    if (bias) {
      PADDLE_ENFORCE(bias->layout() == DataLayout::kMKLDNN &&
                         bias->format() != memory::format::format_undef,
@@ -71,9 +111,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
    int groups = ctx.Attr<int>("groups");
    bool is_conv3d = strides.size() == 3U;
    // TODO(tpatejko): add support for dilation
    PADDLE_ENFORCE(
        is_conv3d
            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
                  dilations[2] == 1
            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
        "dilation in convolution is not implemented yet");
    const T* input_data = input->data<T>();

@@ -83,18 +127,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    std::vector<int> weights_tz =
        paddle::framework::vectorize2int(filter->dims());
    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g, is_conv3d);
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());

    // Get unique name for storing MKLDNN primitives

@@ -105,11 +138,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    std::vector<primitive> pipeline;

    auto src_format = input->format();
    mkldnn::memory::format weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);

    auto user_src_md = platform::MKLDNNMemDesc(
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
    auto user_weights_md = platform::MKLDNNMemDesc(
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
    /* create memory descriptor for convolution without specified format
     * ('any') which lets a primitive (convolution in this case) choose

@@ -119,10 +155,16 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

    if (is_conv3d) {
      chosen_memory_format =
          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
    }
    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);

    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    std::vector<int> bias_tz;  // TODO(mgallus): avoid empty vector creation.
                               // Currently used whenever bias is != nullptr.
    auto dst_md = platform::MKLDNNMemDesc(
@@ -263,8 +305,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      const mkldnn::engine& engine, const bool fuse_relu,
      const bool fuse_residual_conn,
      mkldnn::prop_kind fwd_prop_kind) const {
    memory::dims stride_dims = strides;
    memory::dims padding_dims = paddings;

    auto conv_desc = mkldnn::convolution_forward::desc(
        fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,

@@ -288,8 +330,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
      const mkldnn::engine& engine, const bool fuse_relu,
      const bool fuse_residual_conn,
      mkldnn::prop_kind fwd_prop_kind) const {
    memory::dims stride_dims = strides;
    memory::dims padding_dims = paddings;

    auto conv_desc = mkldnn::convolution_forward::desc(
        fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
@@ -349,6 +391,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
    int groups = ctx.Attr<int>("groups");
    bool is_conv3d = strides.size() == 3U;
    const T* input_data = input->data<T>();
    const T* filter_data = filter->data<T>();
    const T* output_grad_data = output_grad->data<T>();

@@ -358,8 +401,14 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
    std::vector<int> weights_tz =
        paddle::framework::vectorize2int(filter->dims());
    int g = std::max(groups, 1);
    GetWeightsTz(weights_tz, g, is_conv3d);
    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());

    auto src_format = input->format();
    mkldnn::memory::format weights_format =
        GetWeightsFormat(filter->format(), g, is_conv3d);
    // Get an unique name from "argument" name of "Output" variable
    // as well as attributes of primitive to be created
    // This name will be used as key when saving info into device context

@@ -372,9 +421,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    // Create user memory descriptors
    auto user_src_md = platform::MKLDNNMemDesc(
        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
    auto user_weights_md = platform::MKLDNNMemDesc(
        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
    auto user_diff_dst_md = platform::MKLDNNMemDesc(
        {dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
@@ -386,14 +435,20 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    auto chosen_memory_format =
        platform::data_format_to_memory_format(data_format);

    if (is_conv3d) {
      chosen_memory_format =
          platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
    }
    weights_format = GetWeightsFormat(chosen_memory_format, g, is_conv3d);

    auto src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto diff_src_md = platform::MKLDNNMemDesc(
        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
    auto weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    auto diff_weights_md = platform::MKLDNNMemDesc(
        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
    auto diff_dst_md = platform::MKLDNNMemDesc(
        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

@@ -500,3 +555,13 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNOpKernel<float>);

REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
                                    ::paddle::platform::CPUPlace, FP32,
                                    ops::kConvMKLDNNFP32,
                                    ops::ConvMKLDNNGradOpKernel<float>);
@@ -134,14 +134,14 @@ void Conv2DOpMaker::Make() {
           "The format of output tensor is X (one-dimensional) of size equal"
           "to the number of output channels. Only used with MKL-DNN.")
      .AsDispensable();
  AddInput("ResidualData",
           "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
  AddOutput("Output",
            "(Tensor) The output tensor of convolution operator. "
            "The format of output tensor is also NCHW.");
  AddAttr<std::vector<int>>("strides",
                            "(vector<int> default:{1, 1}), the "
                            "strides(h_stride, w_stride) of "

@@ -232,6 +232,10 @@ $$
}
void Conv3DOpMaker::Make() {
  AddAttr<bool>("is_test",
                "(bool, default false) Set to true for inference only, false "
                "for training. Some layers may run faster when this is true.")
      .SetDefault(false);
  AddInput(
      "Input",
      "(Tensor) The input tensor of convolution operator. "
@@ -247,6 +251,11 @@ void Conv3DOpMaker::Make() {
           "is the width of the filter."
           "If the groups attribute is greater than 1, C equals the number of "
           "input image channels divided by the groups.");
  AddInput("ResidualData",
           "(Tensor) Tensor with residual data "
           "to which convolution output will be added."
           "Used with fuse_residual_connection fusion.")
      .AsDispensable();
  AddOutput("Output",
            "(Tensor) The output tensor of convolution operator."
            "The format of output tensor is also NCDHW.");
@@ -280,6 +289,13 @@ void Conv3DOpMaker::Make() {
  AddAttr<bool>("use_mkldnn",
                "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
  AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
      .SetDefault(false);
  AddAttr<bool>("fuse_residual_connection",
                "(bool, default false) Only used in mkldnn kernel. Used "
                "whenever convolution output is as an input to residual "
                "connection.")
      .SetDefault(false);
  AddAttr<std::string>(
      "data_format",
      "(string, default NCHW) Only used in "
......
@@ -113,6 +113,18 @@ inline mkldnn::memory::format MKLDNNFormatForSize(
    return mkldnn::memory::format::x;
  } else if (dims_size == 2) {
    return mkldnn::memory::format::nc;
  } else if (dims_size == 3) {
    if (data_format == mkldnn::memory::format::nchw) {
      return mkldnn::memory::format::ncw;
    } else if (data_format == mkldnn::memory::format::nhwc) {
      return mkldnn::memory::format::nwc;
    }
  } else if (dims_size == 5) {
    if (data_format == mkldnn::memory::format::nchw) {
      return mkldnn::memory::format::ncdhw;
    } else if (data_format == mkldnn::memory::format::nhwc) {
      return mkldnn::memory::format::ndhwc;
    }
  }
  return data_format;
}
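Not part of the patch: a short usage sketch of the extended MKLDNNFormatForSize helper above. It assumes the function is visible through paddle/fluid/platform/mkldnn_helper.h in the paddle::platform namespace, as in this file; the variable names are made up.

// Sketch only: picking a rank-specific MKL-DNN layout from a canonical one.
auto fmt_5d = paddle::platform::MKLDNNFormatForSize(5, mkldnn::memory::format::nchw);
// fmt_5d == mkldnn::memory::format::ncdhw  (NCHW request on a 5-D tensor)
auto fmt_3d = paddle::platform::MKLDNNFormatForSize(3, mkldnn::memory::format::nhwc);
// fmt_3d == mkldnn::memory::format::nwc    (NHWC request on a 3-D tensor)
auto fmt_1d = paddle::platform::MKLDNNFormatForSize(1, mkldnn::memory::format::nchw);
// fmt_1d == mkldnn::memory::format::x      (rank-1 tensors map to x)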
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from test_conv3d_op import TestConv3dOp, TestCase1, TestWithGroup1, TestWithGroup2, TestWith1x1, TestWithInput1x1Filter1x1
class TestMKLDNN(TestConv3dOp):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.data_format = "NCHW"


class TestMKLDNNCase1(TestCase1):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.data_format = "NCHW"


class TestMKLDNNGroup1(TestWithGroup1):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.data_format = "NCHW"


class TestMKLDNNGroup2(TestWithGroup2):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.data_format = "NCHW"


class TestMKLDNNWith1x1(TestWith1x1):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.data_format = "NCHW"


class TestMKLDNNWithInput1x1Filter1x1(TestWithInput1x1Filter1x1):
    def init_kernel_type(self):
        self.use_mkldnn = True
        self.data_format = "NCHW"


if __name__ == '__main__':
    unittest.main()
@@ -74,6 +74,8 @@ class TestConv3dOp(OpTest):
    def setUp(self):
        self.op_type = "conv3d"
        self.use_cudnn = False
        self.use_mkldnn = False
        self.data_format = "AnyLayout"
        self.dtype = np.float32
        self.init_kernel_type()
        self.init_group()
@@ -83,8 +85,7 @@ class TestConv3dOp(OpTest):
        conv3d_param = {
            'stride': self.stride,
            'pad': self.pad,
            'dilations': self.dilations
        }

        input = np.random.random(self.input_size).astype(self.dtype)
@@ -101,7 +102,9 @@ class TestConv3dOp(OpTest):
            'paddings': self.pad,
            'groups': self.groups,
            'dilations': self.dilations,
            'use_cudnn': self.use_cudnn,
            'use_mkldnn': self.use_mkldnn,
            'data_format': self.data_format
        }
        self.outputs = {'Output': output}
@@ -109,59 +112,35 @@ class TestConv3dOp(OpTest):
        return core.is_compiled_with_cuda() and self.use_cudnn

    def test_check_output(self):
        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
        self.check_output_with_place(place, atol=1e-5)

    def test_check_grad(self):
        if self.dtype == np.float16:
            return
        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
        self.check_grad_with_place(
            place, {'Input', 'Filter'}, 'Output', max_relative_error=0.03)

    def test_check_grad_no_filter(self):
        if self.dtype == np.float16:
            return
        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
        self.check_grad_with_place(
            place, ['Input'],
            'Output',
            max_relative_error=0.03,
            no_grad_set=set(['Filter']))

    def test_check_grad_no_input(self):
        if self.dtype == np.float16:
            return
        place = core.CUDAPlace(0) if self.testcudnn() else core.CPUPlace()
        self.check_grad_with_place(
            place, ['Input'],
            'Output',
            max_relative_error=0.03,
            no_grad_set=set(['Input']))
    def init_test_case(self):
        self.pad = [0, 0, 0]
......