提交 0f1b30ef 编写于 作者: C chengduoZH

fix doc and unit test

上级 10bd9f68
...@@ -65,16 +65,17 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( ...@@ -65,16 +65,17 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
"Input", "Input",
"(Tensor) The input tensor of convolution transpose operator. " "(Tensor) The input tensor of convolution transpose operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the " "The format of input tensor is NCHW. Where N is batch size, C is the "
"number of input channels, H and W is the height and width of image."); "number of input channels, H is the height of the feature, and "
"W is the width of the feature.");
AddInput("Filter", AddInput("Filter",
"(Tensor) The filter tensor of convolution transpose operator." "(Tensor) The filter tensor of convolution transpose operator. "
"The format of the filter tensor is CMHW, where C is the number of " "The format of the filter tensor is CMHW, where C is the number of "
"output image channels, M is the number of input image channels, " "output image channels, M is the number of input image channels, "
"H and W is height and width of filter. " "H is the height of the filter, and W is the width of the filter. "
"We enforce groups number == 1 and padding == 0 in " "We enforce groups number == 1 and padding == 0 in "
"convolution transpose Scenario."); "the convolution transpose scenario.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator." "(Tensor) The output tensor of convolution transpose operator. "
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
...@@ -85,13 +86,15 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker( ...@@ -85,13 +86,15 @@ Conv2DTransposeOpMaker::Conv2DTransposeOpMaker(
"(vector defalut:{0, 0}), paddings of convolution transpose operator.") "(vector defalut:{0, 0}), paddings of convolution transpose operator.")
.SetDefault({0, 0}); .SetDefault({0, 0});
AddComment(R"DOC( AddComment(R"DOC(
Convolution2D Transpose Operator.
The convolution transpose operation calculates the output based on the input, filter The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape. parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
size, C is the number of channels, H and W is the height and size, C is the number of channels, H is the height of the feature, and
width of feature. Parameters(ksize, strides, paddings) are two elements. W is the width of the feature. Parameters(ksize, strides, paddings) are two elements.
These two elements represent height and width, respectively. These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different. The input(X) size and output(Out) size may be different.
Example: Example:
...@@ -109,25 +112,26 @@ Example: ...@@ -109,25 +112,26 @@ Example:
Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
framework::OpProto* proto, framework::OpAttrChecker* op_checker) framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput( AddInput("Input",
"Input",
"(Tensor) The input tensor of convolution transpose operator." "(Tensor) The input tensor of convolution transpose operator."
"The format of input tensor is NCDHW. Where N is batch size, C is " "The format of input tensor is NCDHW. Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and width of " "the number of channels, D is the depth of the feature, H is the "
"feature."); "height of the feature, and "
"W is the width of the feature.");
AddInput("Filter", AddInput("Filter",
"(Tensor) The filter tensor of convolution transpose operator." "(Tensor) The filter tensor of convolution transpose operator."
"The format of the filter tensor is CMDHW, where C is the number of " "The format of the filter tensor is CMDHW, where C is the number of "
"output image channels, M is the number of input image channels, " "output image channels, M is the number of input image channels, D "
"D, H and W is depth, height and width of filter. " "is the depth of the filter, H is the height of the filter, and "
"W is the width of the filter."
"We enforce groups number == 1 and padding == 0 in " "We enforce groups number == 1 and padding == 0 in "
"convolution transpose Scenario."); "the convolution3d transpose scenario.");
AddOutput("Output", AddOutput("Output",
"(Tensor) The output tensor of convolution transpose operator." "(Tensor) The output tensor of convolution transpose operator."
"The format of output tensor is also NCDHW." "The format of output tensor is also NCDHW."
"Where N is batch size, C is " "Where N is batch size, C is "
"the number of channels, D, H and W is the depth, height and " "the number of channels, D is the depth of the feature, H is the "
"width of feature."); "height of the feature, and W is the width of the feature.");
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
"(vector defalut:{1, 1, 1}), strides of convolution transpose operator.") "(vector defalut:{1, 1, 1}), strides of convolution transpose operator.")
...@@ -137,13 +141,16 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker( ...@@ -137,13 +141,16 @@ Conv3DTransposeOpMaker::Conv3DTransposeOpMaker(
"(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.") "(vector defalut:{0, 0, 0}), paddings of convolution transpose operator.")
.SetDefault({0, 0, 0}); .SetDefault({0, 0, 0});
AddComment(R"DOC( AddComment(R"DOC(
Convolution3D Transpose Operator.
The convolution transpose operation calculates the output based on the input, filter The convolution transpose operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the and strides, paddings, groups parameters. The size of each dimension of the
parameters is checked in the infer-shape. parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCDHW format. Where N is batch Input(Input, Filter) and output(Output) are in NCDHW format. Where N is batch
size, C is the number of channels, d, H and W is the depth, height and size, C is the number of channels, D is the depth of the feature,
width of feature. Parameters(ksize, strides, paddings) are three elements. H is the height of the feature, and W is the width of the feature.
Parameters(ksize, strides, paddings) are three elements.
These three elements represent depth, height and width, respectively. These three elements represent depth, height and width, respectively.
The input(X) size and output(Out) size may be different. The input(X) size and output(Out) size may be different.
Example: Example:
......
...@@ -175,6 +175,10 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> { ...@@ -175,6 +175,10 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
DDim filter_matrix_shape = {m, c * k_h * k_w}; DDim filter_matrix_shape = {m, c * k_h * k_w};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
if ((!input_grad) && (!filter_grad)) {
return;
}
// convolution transpose grad on input: // convolution transpose grad on input:
// im2col + gemm (similar to conv-forward) // im2col + gemm (similar to conv-forward)
// input need to compute gradient // input need to compute gradient
...@@ -265,7 +269,7 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> { ...@@ -265,7 +269,7 @@ class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
const int64_t o_h = output->dims()[3]; const int64_t o_h = output->dims()[3];
const int64_t o_w = output->dims()[4]; const int64_t o_w = output->dims()[4];
paddle::operators::math::Col2VolFunctor<Place, T> col2vol; math::Col2VolFunctor<Place, T> col2vol;
// use col_shape in the vol2col and col2vol calculation // use col_shape in the vol2col and col2vol calculation
DDim col_shape = {c, k_d, k_h, k_w, d, h, w}; DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
...@@ -349,7 +353,7 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> { ...@@ -349,7 +353,7 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
const int64_t o_w = output_grad->dims()[4]; const int64_t o_w = output_grad->dims()[4];
// Only vol2col functor required for bp to get to the right shape // Only vol2col functor required for bp to get to the right shape
paddle::operators::math::Vol2ColFunctor<Place, T> vol2col; math::Vol2ColFunctor<Place, T> vol2col;
// use col_shape in the vol2col and col2vol calculation // use col_shape in the vol2col and col2vol calculation
DDim col_shape = {c, k_d, k_h, k_w, d, h, w}; DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
...@@ -363,6 +367,10 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> { ...@@ -363,6 +367,10 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
DDim filter_matrix_shape = {m, c * k_d * k_h * k_w}; DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
if ((!input_grad) && (!filter_grad)) {
return;
}
// convolution transpose grad on input: // convolution transpose grad on input:
// vol2col + gemm (similar to conv-forward) // vol2col + gemm (similar to conv-forward)
// input need to compute gradient // input need to compute gradient
......
...@@ -58,36 +58,37 @@ class TestConv2dTransposeOp(OpTest): ...@@ -58,36 +58,37 @@ class TestConv2dTransposeOp(OpTest):
print 'check output here for', self.op_type print 'check output here for', self.op_type
self.check_output() self.check_output()
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose"
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
self.check_grad( self.check_grad(
['Filter'], ['Filter'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
def init_op_type(self):
self.op_type = "conv2d_transpose"
# ------------ test_cudnn ------------
class TestCudnn(TestConv2dTransposeOp): class TestCudnn(TestConv2dTransposeOp):
def init_op_type(self): def init_op_type(self):
self.op_type = "conv2d_transpose_cudnn" self.op_type = "conv2d_transpose_cudnn"
......
...@@ -65,20 +65,20 @@ class TestConv3dTransposeOp(OpTest): ...@@ -65,20 +65,20 @@ class TestConv3dTransposeOp(OpTest):
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.02)
def test_check_grad_no_filter(self): def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Filter'])) no_grad_set=set(['Filter']))
def test_check_grad_no_input(self): def test_check_grad_no_input(self):
self.check_grad( self.check_grad(
['Filter'], ['Filter'],
'Output', 'Output',
max_relative_error=0.05, max_relative_error=0.02,
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册