提交 b6f9ba48 编写于 作者: C chengduoZH

fix conv2d doc

上级 97e9dd72
...@@ -54,6 +54,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -54,6 +54,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) { for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
(dilations[i] * (filter_dims[i + 2] - 1) + 1) >
0,
"Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check "
"again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2], output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], paddings[i], dilations[i], paddings[i], paddings[i],
strides[i])); strides[i]));
...@@ -100,11 +106,11 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto, ...@@ -100,11 +106,11 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
Convolution Operator. Convolution Operator.
The convolution operation calculates the output based on the input, filter The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the and strides, paddings, groups, dilations parameters. The size of each dimension of the
parameters is checked in the infer-shape. parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is size, C is the number of channels, H is the height of the feature, and W is
the width of the feature. Parameters(ksize, strides, paddings) are two elements. the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements.
These two elements represent height and width, respectively. These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different. The input(X) size and output(Out) size may be different.
...@@ -115,8 +121,8 @@ Example: ...@@ -115,8 +121,8 @@ Example:
Output: Output:
Output shape: (N, C_out, H_out, W_out) Output shape: (N, C_out, H_out, W_out)
where where
H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1; H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1;
W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1; W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1;
)DOC"); )DOC");
} }
......
...@@ -39,6 +39,7 @@ class TestConv2dOp(OpTest): ...@@ -39,6 +39,7 @@ class TestConv2dOp(OpTest):
def setUp(self): def setUp(self):
self.init_op_type() self.init_op_type()
self.init_group() self.init_group()
self.init_dilation()
self.init_test_case() self.init_test_case()
conv2d_param = {'stride': self.stride, 'pad': self.pad} conv2d_param = {'stride': self.stride, 'pad': self.pad}
...@@ -80,12 +81,14 @@ class TestConv2dOp(OpTest): ...@@ -80,12 +81,14 @@ class TestConv2dOp(OpTest):
def init_test_case(self): def init_test_case(self):
self.pad = [0, 0] self.pad = [0, 0]
self.stride = [1, 1] self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0 assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3] self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self): def init_group(self):
self.groups = 1 self.groups = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册