Commit 4aae1fff authored by chengduoZH

fix conv3d_gemm, unit test and follow comments

Parent c2fbf8c5
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv3d_op.h"
@@ -52,7 +52,7 @@ void Conv3DOp::InferShape(framework::InferShapeContext* ctx) const {
output_shape.push_back(OutputSizeConv3d(in_dims[i + 2], filter_dims[i],
paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
}
void Conv3DOpGrad::InferShape(framework::InferShapeContext* ctx) const {
......
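The output shape pushed above follows the usual convolution size rule. A minimal Python sketch, assuming OutputSizeConv3d implements the standard formula (its definition is not part of this hunk; output_size_conv3d below is a hypothetical stand-in):

def output_size_conv3d(input_size, filter_size, padding, stride):
    # Standard convolution output-size formula (assumed to match OutputSizeConv3d).
    return (input_size - filter_size + 2 * padding) // stride + 1

# e.g. the test's default 5x5x5 input, 3x3x3 filter, pad 0, stride 1 -> 3x3x3 output
assert output_size_conv3d(5, 3, 0, 1) == 3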
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/conv3d_op.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
......
@@ -3,85 +3,59 @@
import numpy as np
from op_test import OpTest
def conv3d_forward_naive(input, filter, group, conv_param):
    # Naive NCDHW 3-D convolution used as the reference output for the conv3d op.
    in_n, in_c, in_d, in_h, in_w = input.shape
    out_c, f_c, f_d, f_h, f_w = filter.shape
    assert f_c * group == in_c
    assert np.mod(out_c, group) == 0
    sub_out_c = out_c / group

    stride, pad = conv_param['stride'], conv_param['pad']
    out_d = 1 + (in_d + 2 * pad[0] - f_d) / stride[0]
    out_h = 1 + (in_h + 2 * pad[1] - f_h) / stride[1]
    out_w = 1 + (in_w + 2 * pad[2] - f_w) / stride[2]
    out = np.zeros((in_n, out_c, out_d, out_h, out_w))

    input_pad = np.pad(input, ((0, ), (0, ), (pad[0], ), (pad[1], ),
                               (pad[2], )),
                       mode='constant',
                       constant_values=0)
    for d in range(out_d):
        for i in range(out_h):
            for j in range(out_w):
                for g in range(group):
                    input_pad_masked = \
                        input_pad[:, g * f_c:(g + 1) * f_c,
                                  d * stride[0]:d * stride[0] + f_d,
                                  i * stride[1]:i * stride[1] + f_h,
                                  j * stride[2]:j * stride[2] + f_w]
                    f_sub = filter[g * sub_out_c:(g + 1) * sub_out_c, :, :, :, :]
                    for k in range(sub_out_c):
                        out[:, g * sub_out_c + k, d, i, j] = \
                            np.sum(input_pad_masked * f_sub[k, :, :, :, :],
                                   axis=(1, 2, 3, 4))
    return out
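As an illustration (not part of the diff), the reference implementation above can be called directly with the default configuration the test uses further down: groups=1, stride [1, 1, 1], pad [0, 0, 0], NCDHW input [2, 3, 5, 5, 5], filter [6, 3, 3, 3, 3]. Python 2 semantics are assumed, matching the xrange usage in this file.

import numpy as np

input = np.random.random((2, 3, 5, 5, 5)).astype("float32")
filter = np.random.random((6, 3, 3, 3, 3)).astype("float32")
out = conv3d_forward_naive(input, filter, 1,
                           {'stride': [1, 1, 1], 'pad': [0, 0, 0]})
print(out.shape)  # (2, 6, 3, 3, 3): each spatial dim shrinks to (5 - 3) / 1 + 1 = 3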
class TestConv3dOp(OpTest):
    def setUp(self):
        self.init_groups()
        self.op_type = "conv3d"
        batch_size = 2
        input_channels = 3
        input_depth = 5
        input_height = 5
        input_width = 5
        output_channels = 6
        filter_depth = 3
        filter_height = 3
        filter_width = 3
        stride = 1
        padding = 0
        output_depth = (input_depth - filter_depth + 2 * padding) / stride + 1
        output_height = (input_height - filter_height + 2 * padding) / stride + 1
        output_width = (input_width - filter_width + 2 * padding) / stride + 1
        input = np.random.random((batch_size, input_channels, input_depth,
                                  input_height, input_width)).astype("float32")
        filter = np.random.random(
            (output_channels, input_channels / self.groups, filter_depth,
             filter_height, filter_width)).astype("float32")
        output = np.ndarray((batch_size, output_channels, output_depth,
                             output_height, output_width))
        self.init_group()
        self.init_op_type()
        self.init_test_case()

        conv3d_param = {'stride': self.stride, 'pad': self.pad}
        input = np.random.random(self.input_size).astype("float32")
        filter = np.random.random(self.filter_size).astype("float32")
        output = conv3d_forward_naive(input, filter, self.groups, conv3d_param)

        self.inputs = {'Input': input, 'Filter': filter}
        self.attrs = {
            'strides': [1, 1, 1],
            'paddings': [0, 0, 0],
            'strides': self.stride,
            'paddings': self.pad,
            'groups': self.groups
        }
        output_group_channels = output_channels / self.groups
        input_group_channels = input_channels / self.groups
        for batchid in xrange(batch_size):
            for group in xrange(self.groups):
                for outchannelid in range(group * output_group_channels,
                                          (group + 1) * output_group_channels):
                    for deepid in xrange(output_depth):
                        for rowid in xrange(output_height):
                            for colid in xrange(output_width):
                                start_d = (deepid * stride) - padding
                                start_h = (rowid * stride) - padding
                                start_w = (colid * stride) - padding
                                output_value = 0.0
                                for inchannelid in range(
                                        group * input_group_channels,
                                        (group + 1) * input_group_channels):
                                    for fdeepid in xrange(filter_depth):
                                        for frowid in xrange(filter_height):
                                            for fcolid in xrange(filter_width):
                                                input_value = 0.0
                                                indeepid = start_d + fdeepid
                                                inrowid = start_h + frowid
                                                incolid = start_w + fcolid
                                                if ((indeepid >= 0 and indeepid < input_depth) and
                                                        (inrowid >= 0 and inrowid < input_height) and
                                                        (incolid >= 0 and incolid < input_width)):
                                                    input_value = input[batchid][
                                                        inchannelid][indeepid][
                                                            inrowid][incolid]
                                                filter_value = filter[outchannelid][
                                                    inchannelid % input_group_channels][
                                                        fdeepid][frowid][fcolid]
                                                output_value += input_value * filter_value
                                output[batchid][outchannelid][deepid][rowid][
                                    colid] = output_value
        self.outputs = {'Output': output}

    def test_check_output(self):
@@ -105,14 +79,30 @@ class TestConv3dOp(OpTest):
                        max_relative_error=0.05,
                        no_grad_set=set(['Input']))

    def init_groups(self):
    def init_test_case(self):
        # self.groups = 1
        # self.op_type = "conv3d"
        self.pad = [0, 0, 0]
        self.stride = [1, 1, 1]
        self.input_size = [2, 3, 5, 5, 5]  # NCDHW
        assert np.mod(self.input_size[1], self.groups) == 0
        f_c = self.input_size[1] / self.groups
        self.filter_size = [6, f_c, 3, 3, 3]

    def init_group(self):
        self.groups = 1

    def init_op_type(self):
        self.op_type = "conv3d"


class TestWithGroup(TestConv3dOp):
    def init_groups(self):
    def init_group(self):
        self.groups = 3

    def init_op_type(self):
        self.op_type = "conv3d"
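For the grouped case above, a small sketch (not from the diff) of how init_test_case's shapes split across groups when groups is 3; integer division follows the Python 2 style used in this test:

groups = 3
in_c, out_c = 3, 6           # from input_size [2, 3, 5, 5, 5] and the 6-channel filter
f_c = in_c / groups          # 1 input channel per group -> filter_size becomes [6, 1, 3, 3, 3]
sub_out_c = out_c / groups   # 2 output channels produced by each group
assert f_c * groups == in_c and sub_out_c * groups == out_c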
if __name__ == '__main__':
    unittest.main()