Commit 669c0df6 authored by Yibing Liu

Add groups for conv transpose

Parent 8b1b7564
@@ -32,6 +32,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
   std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
   std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
+  int groups = ctx->Attrs().Get<int>("groups");

   PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
                  "ConvTransposeOp intput should be 4-D or 5-D tensor.");
@@ -48,10 +49,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
                  "ConvTransposeOp paddings dimension and dilations "
                  "dimension should be the same.");
   PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
-                    "In ConvTransposeOp, The input channel should be the same "
-                    "as the number of filters.");
+                    "In ConvTransposeOp, The number of input channels should "
+                    "be equal to the number of filter' channels.");

-  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1]});
+  std::vector<int64_t> output_shape({in_dims[0], filter_dims[1] * groups});
   for (size_t i = 0; i < strides.size(); ++i) {
     auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
     output_shape.push_back((in_dims[i + 2] - 1) * strides[i] - 2 * paddings[i] +
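With groups, `filter_dims[1]` is the per-group output-channel count, and InferShape multiplies it back by `groups` for the output tensor. A minimal Python sketch of the same shape rule (the helper name and argument layout are illustrative, not Paddle API):

```python
# Sketch of the conv-transpose output-shape rule used in InferShape above.
def conv_transpose_output_shape(in_shape, filter_shape, strides, paddings,
                                dilations, groups=1):
    # in_shape: (N, C_in, H, W); filter_shape: (C_in, C_out/groups, k_h, k_w)
    assert in_shape[1] == filter_shape[0]
    out = [in_shape[0], filter_shape[1] * groups]
    for i, s in enumerate(strides):
        filter_extent = dilations[i] * (filter_shape[i + 2] - 1) + 1
        out.append((in_shape[i + 2] - 1) * s - 2 * paddings[i] + filter_extent)
    return out

# e.g. groups=2: input (2, 4, 5, 5), filter (4, 3, 3, 3), stride 1, pad 1
print(conv_transpose_output_shape([2, 4, 5, 5], [4, 3, 3, 3],
                                  [1, 1], [1, 1], [1, 1], groups=2))
# -> [2, 6, 5, 5]
```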
@@ -102,7 +103,10 @@ void Conv2DTransposeOpMaker::Make() {
   AddOutput("Output",
             "(Tensor) The output tensor of convolution transpose operator. "
             "The format of output tensor is also NCHW.");
+  AddAttr<int>("groups",
+               "(int default:1), the groups number of the convolution "
+               "transpose operator. ")
+      .SetDefault(1);
   AddAttr<std::vector<int>>("dilations",
                             "(vector<int> default:{1, 1}), the "
                             "dilations(h_dilation, w_dilation) of convolution "
......
@@ -70,7 +70,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
-    // groups will alway be disabled in conv2dtranspose.
+    int groups = context.Attr<int>("groups");

     const int batch_size = static_cast<int>(input->dims()[0]);
@@ -81,10 +81,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
-    // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
+    // col_shape_vec: {c/g, k_h, k_w, h, w} or {c/g, k_d, k_h, k_w, d, h, w}
     size_t data_dim = filter_shape_vec.size() - 2;
     std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
-    col_shape_vec[0] = output->dims()[1];
+    col_shape_vec[0] = output->dims()[1] / groups;
     for (size_t j = 0; j < data_dim; ++j) {
       col_shape_vec[j + 1] = filter_shape_vec[j + 2];
       col_shape_vec[j + 1 + data_dim] = input_shape_vec[j + 2];
@@ -92,7 +92,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     DDim col_shape(framework::make_ddim(col_shape_vec));

     // use col_matrix_shape in the gemm calculation
-    // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
+    // size: (c/g * k_h * k_w, h * w) or (c/g * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape = framework::flatten_to_2d(col_shape, data_dim + 1);

     Tensor col;
@@ -111,7 +111,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     // input matrix size: (m, h * w) or (m, d * h * w)
     DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};

-    // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
+    // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
     DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
     filter.Resize(filter_matrix_shape);
@@ -121,6 +121,8 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
     auto blas = math::GetBlas<DeviceContext, T>(dev_ctx);
     set_zero(dev_ctx, output, static_cast<T>(0));

+    int in_step = static_cast<int>(input->dims()[1]) / groups;
+    int out_step = static_cast<int>(output->dims()[1]) / groups;
     math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
     math::Col2VolFunctor<DeviceContext, T> col2vol;
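For a concrete picture of the per-group bookkeeping set up above, here is a small sketch with assumed sizes (2-D case, groups = 2); the names mirror the kernel's variables, but this is plain Python, not Paddle API:

```python
# Assumed example: 4 input channels, 6 output channels, 3x3 filter, 5x5 map.
c_in, c_out, k_h, k_w, h, w, groups = 4, 6, 3, 3, 5, 5, 2

col_shape = (c_out // groups, k_h, k_w, h, w)         # {c/g, k_h, k_w, h, w}
col_matrix_shape = (col_shape[0] * k_h * k_w, h * w)  # (c/g * k_h * k_w, h * w)
in_step = c_in // groups    # input channels consumed per group
out_step = c_out // groups  # output channels produced per group

print(col_shape, col_matrix_shape, in_step, out_step)
# -> (3, 3, 3, 5, 5) (27, 25) 2 3
```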
@@ -133,22 +135,29 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
       // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
       Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);

-      // col_matrix = filter * input_batch
-      // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
-      blas.MatMul(filter, true, input_batch, false, static_cast<T>(1.0),
-                  &col_matrix, static_cast<T>(0.0));
-      if (data_dim == 2U) {
-        // col2im: col_matrix -> dy
-        // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
-        col2im(dev_ctx, col, dilations, strides,
-               std::vector<int>{paddings[0], paddings[1], paddings[0],
-                                paddings[1]},
-               &output_batch);
-      } else if (data_dim == 3U) {
-        // col2vol: col_matrix -> dy
-        // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
-        col2vol(dev_ctx, col, dilations, strides, paddings, &output_batch);
+      for (int g = 0; g < groups; g++) {
+        Tensor in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
+        Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
+        Tensor out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
+
+        // col_matrix = filter_slice * input_slice
+        // of shape (c/g * k_h * k_w, h * w)
+        // or (c/g * k_d * k_h * k_w, d * h * w)
+        blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
+                    &col_matrix, static_cast<T>(0.0));
+
+        if (data_dim == 2U) {
+          // col2im: col_matrix -> dy
+          // from (c/g * k_h * k_w, h * w) to (c/g, o_h, o_w)
+          col2im(dev_ctx, col, dilations, strides,
+                 std::vector<int>{paddings[0], paddings[1], paddings[0],
+                                  paddings[1]},
+                 &out_slice);
+        } else if (data_dim == 3U) {
+          // col2vol: col_matrix -> dy
+          // from (c/g * k_d * k_h * k_w, d * h * w) to (c/g, o_d, o_h, o_w)
+          col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice);
+        }
       }
     }
   }
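The grouped loop replaces the single GEMM with one smaller GEMM per group. The following NumPy sketch shows only the matrix shapes involved (random data, assumed sizes, not Paddle API; the real kernel then scatters `col_matrix` back with col2im/col2vol instead of asserting):

```python
import numpy as np

# filter viewed as (m, c/g * k_h * k_w) with m = input channels,
# input_batch viewed as (m, h * w); groups = 2 in this sketch.
m, groups, c_per_g, k_h, k_w, h, w = 4, 2, 3, 3, 3, 5, 5
in_step = m // groups

filter_mat = np.random.rand(m, c_per_g * k_h * k_w)
input_mat = np.random.rand(m, h * w)

for g in range(groups):
    filter_slice = filter_mat[g * in_step:(g + 1) * in_step]  # (m/g, c/g*k_h*k_w)
    in_slice = input_mat[g * in_step:(g + 1) * in_step]       # (m/g, h*w)
    col_matrix = filter_slice.T @ in_slice                    # (c/g*k_h*k_w, h*w)
    # col2im / col2vol would scatter-add col_matrix into the g-th
    # out_step-channel slice of the output feature map.
    assert col_matrix.shape == (c_per_g * k_h * k_w, h * w)
```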
@@ -174,6 +183,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
+    int groups = context.Attr<int>("groups");

     const int batch_size = static_cast<int>(input->dims()[0]);
@@ -205,9 +215,11 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     // input matrix size: (m, h * w) or (m, d * h * w)
     DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};

-    // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
-    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
+    // filter size: (m, c/g * k_h * k_w) or (m, c/g * k_d * k_h * k_w)
+    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0] / groups};
     filter.Resize(filter_matrix_shape);
+    int in_step = static_cast<int>(input->dims()[1]) / groups;
+    int col_step = static_cast<int>(col_matrix_shape[0]) / groups;

     // convolution transpose grad on input:
     // im2col + gemm (similar to conv-forward)
@@ -233,7 +245,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
     }
-    if (filter_grad) {  // filter size (m, c, k_h, k_w)
+    if (filter_grad) {  // filter size (m, c/g, k_h, k_w)
       filter_grad->mutable_data<T>(context.GetPlace());
       set_zero(dev_ctx, filter_grad, static_cast<T>(0));
       filter_grad_ = *filter_grad;
@@ -268,8 +280,17 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // or
           // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
           // d, h, w)
-          blas.MatMul(filter, false, col_matrix, false, static_cast<T>(1.0),
-                      &input_grad_batch, static_cast<T>(0.0));
+          for (int g = 0; g < groups; g++) {
+            Tensor input_grad_slice =
+                input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
+            Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
+            Tensor col_matrix_slice =
+                col_matrix.Slice(g * col_step, (g + 1) * col_step);
+            blas.MatMul(filter_slice, false, col_matrix_slice, false,
+                        static_cast<T>(1.0), &input_grad_slice,
+                        static_cast<T>(0.0));
+          }
         }
         if (filter_grad) {
           // input batch
@@ -279,8 +300,17 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
           // or
           // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
          // k_h * k_w)
-          blas.MatMul(in_batch, false, col_matrix, true, static_cast<T>(1.0),
-                      &filter_grad_, static_cast<T>(1.0));
+          for (int g = 0; g < groups; g++) {
+            Tensor in_batch_slice =
+                in_batch.Slice(g * in_step, (g + 1) * in_step);
+            Tensor filter_grad_slice =
+                filter_grad_.Slice(g * in_step, (g + 1) * in_step);
+            Tensor col_matrix_slice =
+                col_matrix.Slice(g * col_step, (g + 1) * col_step);
+            blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
+                        static_cast<T>(1.0), &filter_grad_slice,
+                        static_cast<T>(1.0));
+          }
         }
       }
     }
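The backward pass follows the same slicing pattern: both grouped GEMMs work on `in_step`-sized channel slices of the filter/input and `col_step`-sized slices of the im2col'd output gradient. A shape-level NumPy sketch under assumed sizes (not Paddle API):

```python
import numpy as np

# col_matrix stands in for the im2col'd output gradient, (c * k_h * k_w, h * w).
m, c, k_h, k_w, h, w, groups = 4, 6, 3, 3, 5, 5, 2
in_step = m // groups
col_step = (c // groups) * k_h * k_w

filter_mat = np.random.rand(m, col_step)       # (m, c/g * k_h * k_w)
in_batch = np.random.rand(m, h * w)            # (m, h * w)
col_matrix = np.random.rand(groups * col_step, h * w)

input_grad = np.zeros((m, h * w))
filter_grad = np.zeros_like(filter_mat)
for g in range(groups):
    col_slice = col_matrix[g * col_step:(g + 1) * col_step]
    # d(input), group g: (m/g, c/g*k_h*k_w) x (c/g*k_h*k_w, h*w) -> (m/g, h*w)
    input_grad[g * in_step:(g + 1) * in_step] = (
        filter_mat[g * in_step:(g + 1) * in_step] @ col_slice)
    # d(filter), group g, accumulated (beta = 1 in the kernel):
    # (m/g, h*w) x (h*w, c/g*k_h*k_w) -> (m/g, c/g*k_h*k_w)
    filter_grad[g * in_step:(g + 1) * in_step] += (
        in_batch[g * in_step:(g + 1) * in_step] @ col_slice.T)
```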
......
@@ -21,8 +21,11 @@ from op_test import OpTest


 def conv2dtranspose_forward_naive(input_, filter_, attrs):
     in_n, in_c, in_h, in_w = input_.shape
-    f_c, out_c, f_h, f_w = filter_.shape
+    f_c, f_out_c, f_h, f_w = filter_.shape
+    groups = attrs['groups']
     assert in_c == f_c
+    out_c = f_out_c * groups
+    sub_in_c = in_c / groups

     stride, pad, dilations = attrs['strides'], attrs['paddings'], attrs[
         'dilations']
@@ -36,15 +39,21 @@ def conv2dtranspose_forward_naive(input_, filter_, attrs):
     for n in range(in_n):
         for i in range(in_h):
             for j in range(in_w):
-                input_masked = input_[n, :, i, j]  # (c)
-                input_masked = np.reshape(input_masked, (in_c, 1, 1))
-                input_masked = np.tile(input_masked, (1, f_h, f_w))
-
-                for k in range(out_c):
-                    tmp_out = np.sum(input_masked * filter_[:, k, :, :], axis=0)
-                    i1, i2 = i * stride[0], i * stride[0] + d_bolck_h
-                    j1, j2 = j * stride[0], j * stride[0] + d_bolck_h
-                    out[n, k, i1:i2:dilations[0], j1:j2:dilations[1]] += tmp_out
+                for g in range(groups):
+                    input_masked = input_[n, g * sub_in_c:(g + 1) * sub_in_c, i,
+                                          j]  # (c)
+                    input_masked = np.reshape(input_masked, (sub_in_c, 1, 1))
+                    input_masked = np.tile(input_masked, (1, f_h, f_w))
+
+                    for k in range(f_out_c):
+                        tmp_out = np.sum(
+                            input_masked *
+                            filter_[g * sub_in_c:(g + 1) * sub_in_c, k, :, :],
+                            axis=0)
+                        i1, i2 = i * stride[0], i * stride[0] + d_bolck_h
+                        j1, j2 = j * stride[0], j * stride[0] + d_bolck_h
+                        out[n, g * f_out_c + k, i1:i2:dilations[0], j1:j2:
+                            dilations[1]] += tmp_out

     out = out[:, :, pad[0]:out_h - pad[0], pad[1]:out_w - pad[1]]
     return out
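A hedged usage example of the naive reference above, using the same configuration as the new TestWithGroups case below: groups = 2 splits the 4 input channels into two groups of 2, each group produces 3 output channels, so the result has 6 channels, and with 3x3 filters, stride 1, pad 1 and dilation 1 the 5x5 spatial size is preserved.

```python
import numpy as np

attrs = {
    'strides': [1, 1],
    'paddings': [1, 1],
    'dilations': [1, 1],
    'groups': 2,
}
input_ = np.random.random((2, 4, 5, 5)).astype("float32")   # NCHW
filter_ = np.random.random((4, 3, 3, 3)).astype("float32")  # (f_c, f_out_c, k, k)

out = conv2dtranspose_forward_naive(input_, filter_, attrs)
assert out.shape == (2, 6, 5, 5)
```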
@@ -64,6 +73,7 @@ class TestConv2dTransposeOp(OpTest):
         self.attrs = {
             'strides': self.stride,
             'paddings': self.pad,
+            'groups': self.groups,
             'dilations': self.dilations,
             'use_cudnn': self.use_cudnn,
             'data_format': 'AnyLayout'  # TODO(dzhwinter) : should be fix latter
@@ -127,6 +137,7 @@ class TestConv2dTransposeOp(OpTest):
         self.pad = [0, 0]
         self.stride = [1, 1]
         self.dilations = [1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5]  # NCHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3]
@@ -140,16 +151,29 @@ class TestWithPad(TestConv2dTransposeOp):
         self.pad = [1, 1]
         self.stride = [1, 1]
         self.dilations = [1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5]  # NCHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3]


+class TestWithGroups(TestConv2dTransposeOp):
+    def init_test_case(self):
+        self.pad = [1, 1]
+        self.stride = [1, 1]
+        self.dilations = [1, 1]
+        self.groups = 2
+        self.input_size = [2, 4, 5, 5]  # NCHW
+        f_c = self.input_size[1]
+        self.filter_size = [f_c, 3, 3, 3]
+
+
 class TestWithStride(TestConv2dTransposeOp):
     def init_test_case(self):
         self.pad = [1, 1]
         self.stride = [2, 2]
         self.dilations = [1, 1]
+        self.groups = 1
         self.input_size = [2, 3, 5, 5]  # NCHW
         f_c = self.input_size[1]
         self.filter_size = [f_c, 6, 3, 3]
@@ -159,6 +183,7 @@ class TestWithDilation(TestConv2dTransposeOp):
     def init_test_case(self):
         self.pad = [1, 1]
         self.stride = [1, 1]
+        self.groups = 1
         self.dilations = [2, 2]
         self.input_size = [2, 3, 5, 5]  # NCHW
         f_c = self.input_size[1]
@@ -176,6 +201,7 @@ class TestCUDNNWithPad(TestWithPad):
     def init_test_case(self):
         self.pad = [1, 1]
         self.stride = [1, 1]
+        self.groups = 1
         self.dilations = [1, 1]
         self.input_size = [2, 3, 5, 5]  # NCHW
         f_c = self.input_size[1]
@@ -190,6 +216,7 @@ class TestCUDNNWithStride(TestWithStride):
     def init_test_case(self):
         self.pad = [1, 1]
         self.stride = [2, 2]
+        self.groups = 1
         self.dilations = [1, 1]
         self.input_size = [2, 3, 5, 5]  # NCHW
         f_c = self.input_size[1]
......