提交 91db457f 编写于 作者: C chengduoZH

follow comments

上级 24a796fb
......@@ -87,11 +87,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"The format of output tensor is also NCDHW.");
AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
.SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.")
AddAttr<std::vector<int>>("paddings", "The paddings of convolution operator.")
.SetDefault({0, 0, 0});
AddAttr<int>(
"groups",
"group size of convolution operator. "
"The group size of convolution operator. "
"Refer to grouped convolution in Alex Krizhevsky's paper: "
"when group=2, the first half of the filters are only connected to the "
"first half of the input channels, and the second half only connected "
......
......@@ -93,10 +93,13 @@ class GemmConv3DKernel : public framework::OpKernel<T> {
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2],
input->dims()[3], input->dims()[4]};
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
framework::DDim input_shape = {
input->dims()[1], input->dims()[2], input->dims()[3],
input->dims()[4]}; // channel, depth, height, width
framework::DDim filter_matrix_shape = {
filter.dims()[0],
filter.numel() / filter.dims()[0]}; // filter_out_channel,
// filter_in_channel*filter_depth*filter_height*filter_width
filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = {
......@@ -177,15 +180,18 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2],
input->dims()[3], input->dims()[4]};
framework::DDim input_shape = {
input->dims()[1], input->dims()[2], input->dims()[3],
input->dims()[4]}; // channel, depth, height, width
framework::DDim output_matrix_shape = {output_grad->dims()[1],
output_grad->dims()[2] *
output_grad->dims()[3] *
output_grad->dims()[4]};
framework::DDim filter_matrix_shape = {filter.dims()[0],
filter.numel() / filter.dims()[0]};
framework::DDim filter_matrix_shape = {
filter.dims()[0],
filter.numel() / filter.dims()[0]}; // filter_out_channel,
// filter_in_channel*filter_depth*filter_height*filter_width
filter.Resize(filter_matrix_shape);
// convolution backward input operator: gemm + col2vol
......
......@@ -34,7 +34,7 @@ def conv3d_forward_naive(input, filter, group, conv_param):
for k in range(sub_out_c):
out[:, g * sub_out_c + k, d, i, j] = \
np.sum(input_pad_masked * f_sub[k, :, :, :, :],
axis=(1, 2, 3,4))
axis=(1, 2, 3, 4))
return out
......@@ -65,7 +65,6 @@ class TestConv3dOp(OpTest):
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(
['Input'],
'Output',
......@@ -80,8 +79,6 @@ class TestConv3dOp(OpTest):
no_grad_set=set(['Input']))
def init_test_case(self):
# self.groups = 1
# self.op_type = "conv3d"
self.pad = [0, 0, 0]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 5, 5, 5] # NCDHW
......@@ -98,8 +95,6 @@ class TestConv3dOp(OpTest):
class TestCase1(TestConv3dOp):
def init_test_case(self):
# self.groups = 1
# self.op_type = "conv3d"
self.pad = [1, 1, 1]
self.stride = [1, 1, 1]
self.input_size = [2, 3, 5, 5, 5] # NCDHW
......@@ -114,7 +109,6 @@ class TestCase1(TestConv3dOp):
self.op_type = "conv3d"
'''
class TestWithGroup1(TestConv3dOp):
def init_group(self):
self.groups = 3
......@@ -129,7 +123,7 @@ class TestWithGroup2(TestCase1):
def init_op_type(self):
self.op_type = "conv3d"
'''
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册