提交 91db457f 作者: chengduoZH

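Address review comments: polish the conv3d attribute descriptions, annotate tensor shapes in the GEMM conv3d kernels, and clean up the conv3d unit tests.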

上级 24a796fb
...@@ -87,11 +87,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto, ...@@ -87,11 +87,11 @@ Conv3DOpMaker::Conv3DOpMaker(framework::OpProto* proto,
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW.");
AddAttr<std::vector<int>>("strides", "strides of convolution operator.") AddAttr<std::vector<int>>("strides", "strides of convolution operator.")
.SetDefault({1, 1, 1}); .SetDefault({1, 1, 1});
AddAttr<std::vector<int>>("paddings", "paddings of convolution operator.") AddAttr<std::vector<int>>("paddings", "The paddings of convolution operator.")
.SetDefault({0, 0, 0}); .SetDefault({0, 0, 0});
AddAttr<int>( AddAttr<int>(
"groups", "groups",
"group size of convolution operator. " "The group size of convolution operator. "
"Refer to grouped convolution in Alex Krizhevsky's paper: " "Refer to grouped convolution in Alex Krizhevsky's paper: "
"when group=2, the first half of the filters are only connected to the " "when group=2, the first half of the filters are only connected to the "
"first half of the input channels, and the second half only connected " "first half of the input channels, and the second half only connected "
......
...@@ -93,10 +93,13 @@ class GemmConv3DKernel : public framework::OpKernel<T> { ...@@ -93,10 +93,13 @@ class GemmConv3DKernel : public framework::OpKernel<T> {
Tensor col_matrix = col; Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2], framework::DDim input_shape = {
input->dims()[3], input->dims()[4]}; input->dims()[1], input->dims()[2], input->dims()[3],
framework::DDim filter_matrix_shape = {filter.dims()[0], input->dims()[4]}; // channel, depth, height, width
filter.numel() / filter.dims()[0]}; framework::DDim filter_matrix_shape = {
filter.dims()[0],
filter.numel() / filter.dims()[0]}; // filter_out_channel,
// filter_in_channel*filter_depth*filter_height*filter_width
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
framework::DDim output_matrix_shape = { framework::DDim output_matrix_shape = {
...@@ -177,15 +180,18 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> { ...@@ -177,15 +180,18 @@ class GemmConvGrad3DKernel : public framework::OpKernel<T> {
Tensor col_matrix = col; Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
framework::DDim input_shape = {input->dims()[1], input->dims()[2], framework::DDim input_shape = {
input->dims()[3], input->dims()[4]}; input->dims()[1], input->dims()[2], input->dims()[3],
input->dims()[4]}; // channel, depth, height, width
framework::DDim output_matrix_shape = {output_grad->dims()[1], framework::DDim output_matrix_shape = {output_grad->dims()[1],
output_grad->dims()[2] * output_grad->dims()[2] *
output_grad->dims()[3] * output_grad->dims()[3] *
output_grad->dims()[4]}; output_grad->dims()[4]};
framework::DDim filter_matrix_shape = {filter.dims()[0], framework::DDim filter_matrix_shape = {
filter.numel() / filter.dims()[0]}; filter.dims()[0],
filter.numel() / filter.dims()[0]}; // filter_out_channel,
// filter_in_channel*filter_depth*filter_height*filter_width
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
// convolution backward input operator: gemm + col2vol // convolution backward input operator: gemm + col2vol
......
...@@ -34,7 +34,7 @@ def conv3d_forward_naive(input, filter, group, conv_param): ...@@ -34,7 +34,7 @@ def conv3d_forward_naive(input, filter, group, conv_param):
for k in range(sub_out_c): for k in range(sub_out_c):
out[:, g * sub_out_c + k, d, i, j] = \ out[:, g * sub_out_c + k, d, i, j] = \
np.sum(input_pad_masked * f_sub[k, :, :, :, :], np.sum(input_pad_masked * f_sub[k, :, :, :, :],
axis=(1, 2, 3,4)) axis=(1, 2, 3, 4))
return out return out
...@@ -65,7 +65,6 @@ class TestConv3dOp(OpTest): ...@@ -65,7 +65,6 @@ class TestConv3dOp(OpTest):
self.check_grad( self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05) set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad( self.check_grad(
['Input'], ['Input'],
'Output', 'Output',
...@@ -80,8 +79,6 @@ class TestConv3dOp(OpTest): ...@@ -80,8 +79,6 @@ class TestConv3dOp(OpTest):
no_grad_set=set(['Input'])) no_grad_set=set(['Input']))
def init_test_case(self): def init_test_case(self):
# self.groups = 1
# self.op_type = "conv3d"
self.pad = [0, 0, 0] self.pad = [0, 0, 0]
self.stride = [1, 1, 1] self.stride = [1, 1, 1]
self.input_size = [2, 3, 5, 5, 5] # NCDHW self.input_size = [2, 3, 5, 5, 5] # NCDHW
...@@ -98,8 +95,6 @@ class TestConv3dOp(OpTest): ...@@ -98,8 +95,6 @@ class TestConv3dOp(OpTest):
class TestCase1(TestConv3dOp): class TestCase1(TestConv3dOp):
def init_test_case(self): def init_test_case(self):
# self.groups = 1
# self.op_type = "conv3d"
self.pad = [1, 1, 1] self.pad = [1, 1, 1]
self.stride = [1, 1, 1] self.stride = [1, 1, 1]
self.input_size = [2, 3, 5, 5, 5] # NCDHW self.input_size = [2, 3, 5, 5, 5] # NCDHW
...@@ -114,7 +109,6 @@ class TestCase1(TestConv3dOp): ...@@ -114,7 +109,6 @@ class TestCase1(TestConv3dOp):
self.op_type = "conv3d" self.op_type = "conv3d"
'''
class TestWithGroup1(TestConv3dOp): class TestWithGroup1(TestConv3dOp):
def init_group(self): def init_group(self):
self.groups = 3 self.groups = 3
...@@ -129,7 +123,7 @@ class TestWithGroup2(TestCase1): ...@@ -129,7 +123,7 @@ class TestWithGroup2(TestCase1):
def init_op_type(self): def init_op_type(self):
self.op_type = "conv3d" self.op_type = "conv3d"
'''
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册