提交 b6f9ba48 编写于 作者: C chengduoZH

fix conv2d doc

上级 97e9dd72
......@@ -54,6 +54,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
for (size_t i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE(in_dims[i + 2] + 2 * paddings[i] -
(dilations[i] * (filter_dims[i + 2] - 1) + 1) >
0,
"Due to the settings of paddings, filter_dims and "
"dilations, the output size is less than 0, please check "
"again.");
output_shape.push_back(OutputSize(in_dims[i + 2], filter_dims[i + 2],
dilations[i], paddings[i], paddings[i],
strides[i]));
......@@ -100,11 +106,11 @@ Conv2DOpMaker::Conv2DOpMaker(framework::OpProto* proto,
Convolution Operator.
The convolution operation calculates the output based on the input, filter
and strides, paddings, groups parameters. The size of each dimension of the
and strides, paddings, groups, dilations parameters. The size of each dimension of the
parameters is checked in the infer-shape.
Input(Input, Filter) and output(Output) are in NCHW format. Where N is batch
size, C is the number of channels, H is the height of the feature, and W is
the width of the feature. Parameters(ksize, strides, paddings) are two elements.
the width of the feature. Parameters(ksize, strides, paddings, dilations) are two elements.
These two elements represent height and width, respectively.
The input(X) size and output(Out) size may be different.
......@@ -115,8 +121,8 @@ Example:
Output:
Output shape: (N, C_out, H_out, W_out)
where
H_out = (H_in - filter_size[0] + 2 * paddings[0]) / strides[0] + 1;
W_out = (W_in - filter_size[1] + 2 * paddings[1]) / strides[1] + 1;
H_out = (H_in + 2 * paddings[0] - (dilations[0]*(filter_size[0] - 1) + 1)) / strides[0] + 1;
W_out = (W_in + 2 * paddings[1] - (dilations[1]*(filter_size[1] - 1) + 1)) / strides[1] + 1;
)DOC");
}
......
......@@ -39,6 +39,7 @@ class TestConv2dOp(OpTest):
def setUp(self):
self.init_op_type()
self.init_group()
self.init_dilation()
self.init_test_case()
conv2d_param = {'stride': self.stride, 'pad': self.pad}
......@@ -80,12 +81,14 @@ class TestConv2dOp(OpTest):
def init_test_case(self):
self.pad = [0, 0]
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
assert np.mod(self.input_size[1], self.groups) == 0
f_c = self.input_size[1] / self.groups
self.filter_size = [6, f_c, 3, 3]
def init_dilation(self):
self.dilations = [1, 1]
def init_group(self):
self.groups = 1
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册