Commit 206f32c1 authored by chengduoZH

Write the deconv2d kernel and the deconv3d kernel together

Parent 0f1b30ef
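For orientation before the diff: this commit folds the separate 2-D and 3-D transposed-convolution kernels into a single GemmConvTransposeKernel / GemmConvTransposeGradKernel pair, which strips the two leading channel dimensions from the filter and then branches on the remaining spatial rank to pick the im2col/col2im or vol2col/col2vol path. A minimal, framework-free sketch of that dispatch pattern (the function names below are illustrative stand-ins, not Paddle's API):

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative stand-ins for the real col2im / col2vol functors.
void Col2Im() { std::printf("2-D path: gemm + col2im\n"); }
void Col2Vol() { std::printf("3-D path: gemm + col2vol\n"); }

// One kernel body serves both ranks: drop the (m, c) leading dims of the
// filter and branch on how many spatial dims remain, mirroring the
// filter_shape_vec.size() check in the merged kernel.
void RunConvTranspose(const std::vector<int64_t>& filter_dims) {
  std::vector<int64_t> filter_shape_vec(filter_dims.begin() + 2,
                                        filter_dims.end());
  if (filter_shape_vec.size() == 2) {
    Col2Im();
  } else if (filter_shape_vec.size() == 3) {
    Col2Vol();
  }
}

int main() {
  RunConvTranspose({3, 6, 3, 3});     // conv2d_transpose filter: (m, c, k_h, k_w)
  RunConvTranspose({3, 6, 3, 3, 3});  // conv3d_transpose filter: (m, c, k_d, k_h, k_w)
  return 0;
}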
@@ -44,7 +44,7 @@ REGISTER_OP(conv2d_transpose_cudnn, ops::ConvTransposeOp,
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_cudnn,
-    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_cudnn_grad,
-    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
@@ -187,17 +187,17 @@ REGISTER_OP(conv2d_transpose, ops::ConvTransposeOp, ops::Conv2DTransposeOpMaker,
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose,
-    ops::GemmConv2DTransposeKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     conv2d_transpose_grad,
-    ops::GemmConv2DTransposeGradKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP(conv3d_transpose, ops::ConvTransposeOp, ops::Conv3DTransposeOpMaker,
             conv3d_transpose_grad, ops::ConvTransposeOpGrad);
 REGISTER_OP_CPU_KERNEL(
     conv3d_transpose,
-    ops::GemmConv3DTransposeKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
     conv3d_transpose_grad,
-    ops::GemmConv3DTransposeGradKernel<paddle::platform::CPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::CPUPlace, float>);
@@ -18,14 +18,14 @@ namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
     conv2d_transpose,
-    ops::GemmConv2DTransposeKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     conv2d_transpose_grad,
-    ops::GemmConv2DTransposeGradKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     conv3d_transpose,
-    ops::GemmConv3DTransposeKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
     conv3d_transpose_grad,
-    ops::GemmConv3DTransposeGradKernel<paddle::platform::GPUPlace, float>);
+    ops::GemmConvTransposeGradKernel<paddle::platform::GPUPlace, float>);
@@ -57,7 +57,7 @@ class ConvTransposeOpGrad : public framework::OperatorWithKernel {
 };
 template <typename Place, typename T>
-class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
+class GemmConvTransposeKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
@@ -70,24 +70,31 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     // groups will alway be disabled in conv2dtranspose.
     const int batch_size = static_cast<int>(input->dims()[0]);
-    const int64_t m = input->dims()[1];
-    const int64_t h = input->dims()[2];
-    const int64_t w = input->dims()[3];
-    const int64_t k_h = filter.dims()[2];
-    const int64_t k_w = filter.dims()[3];
-    const int64_t c = output->dims()[1];  // output channels
-    const int64_t o_h = output->dims()[2];
-    const int64_t o_w = output->dims()[3];
-    math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
-    // use col_shape in the im2col and col2im calculation
-    DDim col_shape = {c, k_h, k_w, h, w};
+    // input_shape_vec: {h, w} or {d, h, w}
+    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
+    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
+    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
+    filter_shape_vec.erase(filter_shape_vec.begin(),
+                           filter_shape_vec.begin() + 2);
+    // use col_shape in the im2col and col2im (or vol2col and col2vol)
+    // calculation
+    // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
+    std::vector<int64_t> col_shape_vec;
+    col_shape_vec.push_back(output->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+                         filter_shape_vec.end());
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+                         input_shape_vec.end());
+    DDim col_shape(framework::make_ddim(col_shape_vec));
     // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape = {c * k_h * k_w, h * w};
+    // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
+    DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -98,47 +105,61 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
-    DDim output_shape = {c, o_h, o_w};
-    DDim input_matrix_shape = {m, h * w};
-    // filter size: (m, c * k_h * k_w)
-    DDim filter_matrix_shape = {m, c * k_h * k_w};
+    // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
+    DDim output_shape =
+        framework::slice_ddim(output->dims(), 1, output->dims().size());
+    // input matrix size: (m, h * w) or (m, d * h * w)
+    DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
+    // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
+    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
     filter.Resize(filter_matrix_shape);
     output->mutable_data<T>(context.GetPlace());
     math::SetConstant<Place, T> set_zero;
     set_zero(context.device_context(), output, static_cast<T>(0));
-    // convolution transpose: gemm + col2im (similar to conv-backward on input)
+    // convolution transpose: gemm + col2im or col2vol (similar to conv-backward
+    // on input)
     for (int i = 0; i < batch_size; i++) {
-      // batch with size (m, h * w)
+      // batch with size (m, h * w) or (m, d * h * w)
       Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // output size: (c, o_h, o_w)
+      // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
       Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
       // col_matrix = filter * input_batch
-      // of shape (c * k_h * k_w, h * w)
+      // of shape (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
       math::matmul<Place, T>(context.device_context(), filter, true,
                              input_batch, false, static_cast<T>(1.0),
                              &col_matrix, static_cast<T>(0.0));
-      // col2im: col_matrix -> dy
-      // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
-      col2im(context.device_context(), output_batch, col, strides[0],
-             strides[1], 0, 0, 0, 0);
+      if (filter_shape_vec.size() == 2) {
+        // col2im: col_matrix -> dy
+        // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
+        math::Col2ImFunctor<math::ColFormat::kCFO, Place, T> col2im;
+        col2im(context.device_context(), output_batch, col, strides[0],
+               strides[1], 0, 0, 0, 0);
+      } else if (filter_shape_vec.size() == 3) {
+        // col2vol: col_matrix -> dy
+        // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
+        math::Col2VolFunctor<Place, T> col2vol;
+        col2vol(context.device_context(), output_batch, col, strides[0],
+                strides[1], strides[2], 0, 0, 0);
+      }
     }
   }
 };
 template <typename Place, typename T>
-class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
+class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
  public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* input = context.Input<Tensor>("Input");
    const Tensor* output_grad =
        context.Input<Tensor>(framework::GradVarName("Output"));
    // For filter, we do not use const pointer b/c we will do reshape,
    // but we should avoid modifying its value.
    Tensor filter = *context.Input<Tensor>("Filter");
@@ -147,38 +168,50 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
     Tensor* filter_grad =
         context.Output<Tensor>(framework::GradVarName("Filter"));
+    if ((!input_grad) && (!filter_grad)) return;
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
     // Actually, no paddings and groups allowed in conv transpose.
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     const int batch_size = static_cast<int>(input->dims()[0]);
-    const int64_t m = input->dims()[1];
-    const int64_t h = input->dims()[2];
-    const int64_t w = input->dims()[3];
-    const int64_t k_h = filter.dims()[2];
-    const int64_t k_w = filter.dims()[3];
-    const int64_t c = output_grad->dims()[1];  // output channels
-    const int64_t o_h = output_grad->dims()[2];
-    const int64_t o_w = output_grad->dims()[3];
-    // Only im2col functor required for bp to get to the right shape
-    math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
-    // use col_shape in the im2col and col2im calculation
-    DDim col_shape = {c, k_h, k_w, h, w};
-    DDim output_shape = {c, o_h, o_w};
-    DDim input_matrix_shape = {m, h * w};
-    DDim filter_matrix_shape = {m, c * k_h * k_w};
+    // input_shape_vec: {h, w} or {d, h, w}
+    std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
+    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
+    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
+    filter_shape_vec.erase(filter_shape_vec.begin(),
+                           filter_shape_vec.begin() + 2);
+    // use col_shape in the im2col and col2im (or vol2col and col2vol)
+    // calculation
+    // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
+    std::vector<int64_t> col_shape_vec;
+    col_shape_vec.push_back(output_grad->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+                         filter_shape_vec.end());
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+                         input_shape_vec.end());
+    DDim col_shape(framework::make_ddim(col_shape_vec));
+    // use col_matrix_shape in the gemm calculation
+    // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
+    DDim col_matrix_shape =
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+    // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
+    DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
+                                              output_grad->dims().size());
+    // input matrix size: (m, h * w) or (m, d * h * w)
+    DDim input_matrix_shape = {input->dims()[1], col_matrix_shape[1]};
+    // filter size: (m, c * k_h * k_w) or (m, c * k_d * k_h * k_w)
+    DDim filter_matrix_shape = {input->dims()[1], col_matrix_shape[0]};
     filter.Resize(filter_matrix_shape);
-    if ((!input_grad) && (!filter_grad)) {
-      return;
-    }
     // convolution transpose grad on input:
     // im2col + gemm (similar to conv-forward)
     // input need to compute gradient
@@ -190,7 +223,6 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
       // to call the matrix multiplication interface.
       Tensor col_matrix;
       col_matrix.ShareDataWith(col);
-      DDim col_matrix_shape = {c * k_h * k_w, h * w};
       col_matrix.Resize(col_matrix_shape);
       Tensor filter_grad_;
@@ -212,10 +244,21 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
         Tensor output_grad_batch =
             output_grad->Slice(i, i + 1).Resize(output_shape);
-        // im2col: dy -> col matrix
-        // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
-        im2col(context.device_context(), output_grad_batch, col, strides[0],
-               strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
+        if (filter_shape_vec.size() == 2) {
+          // im2col: dy -> col matrix
+          // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
+          math::Im2ColFunctor<math::ColFormat::kCFO, Place, T> im2col;
+          im2col(context.device_context(), output_grad_batch, col, strides[0],
+                 strides[1], paddings[0], paddings[0], paddings[1],
+                 paddings[1]);
+        } else if (filter_shape_vec.size() == 3) {
+          // vol2col: dy -> col_matrix
+          // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
+          math::Vol2ColFunctor<Place, T> vol2col;
+          vol2col(context.device_context(), output_grad_batch, col, strides[0],
+                  strides[1], strides[2], paddings[0], paddings[1],
+                  paddings[2]);
+        }
         if (input_grad) {
           // batch with size (m, h, w)
@@ -223,197 +266,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
               input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
           // gemm: dx = filter * dy
           // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, h * w)
-          math::matmul<Place, T>(context.device_context(), filter, false,
-                                 col_matrix, false, static_cast<T>(1.0),
-                                 &input_grad_batch, static_cast<T>(0.0));
-        }
-        if (filter_grad) {
-          // input batch
-          Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-          // gemm: d_filter = x * dy^T
-          // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w)
-          math::matmul<Place, T>(context.device_context(), in_batch, false,
-                                 col_matrix, true, static_cast<T>(1.0),
-                                 &filter_grad_, static_cast<T>(1.0));
-        }
-      }
-    }
-  }
-};
-template <typename Place, typename T>
-class GemmConv3DTransposeKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    // The filter will be reshaped, so it should not be constant pointer
-    Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* output = context.Output<Tensor>("Output");
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    // TODO(chengduo): Paddings can be added in future.
-    // groups will alway be disabled in conv3dtranspose.
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int64_t m = input->dims()[1];
-    const int64_t d = input->dims()[2];
-    const int64_t h = input->dims()[3];
-    const int64_t w = input->dims()[4];
-    const int64_t k_d = filter.dims()[2];
-    const int64_t k_h = filter.dims()[3];
-    const int64_t k_w = filter.dims()[4];
-    const int64_t c = output->dims()[1];  // output channels
-    const int64_t o_d = output->dims()[2];
-    const int64_t o_h = output->dims()[3];
-    const int64_t o_w = output->dims()[4];
-    math::Col2VolFunctor<Place, T> col2vol;
-    // use col_shape in the vol2col and col2vol calculation
-    DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
-    // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
-    Tensor col;
-    col.mutable_data<T>(col_shape, context.GetPlace());
-    // col_matrix shares the same piece of data with col,
-    // but will be reshaped into a two-dimensional matrix shape
-    // to call the matrix multiplication interface.
-    Tensor col_matrix;
-    col_matrix.ShareDataWith(col);
-    col_matrix.Resize(col_matrix_shape);
-    DDim output_shape = {c, o_d, o_h, o_w};
-    DDim input_matrix_shape = {m, d * h * w};
-    // filter size: (m, c * k_d * k_h * k_w)
-    DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
-    filter.Resize(filter_matrix_shape);
-    output->mutable_data<T>(context.GetPlace());
-    math::SetConstant<Place, T> set_zero;
-    set_zero(context.device_context(), output, static_cast<T>(0));
-    // convolution transpose: gemm + col2vol (similar to conv-backward on input)
-    for (int i = 0; i < batch_size; i++) {
-      // batch with size (m, d * h * w)
-      Tensor input_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
-      // output size: (c, o_d, o_h, o_w)
-      Tensor output_batch = output->Slice(i, i + 1).Resize(output_shape);
-      // col_matrix = filter * input_batch
-      // of shape (c * k_d * k_h * k_w, d * h * w)
-      math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, static_cast<T>(1.0),
-                             &col_matrix, static_cast<T>(0.0));
-      // col2vol: col_matrix -> dy
-      // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
-      col2vol(context.device_context(), output_batch, col, strides[0],
-              strides[1], strides[2], 0, 0, 0);
-    }
-  }
-};
-template <typename Place, typename T>
-class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& context) const override {
-    const Tensor* input = context.Input<Tensor>("Input");
-    const Tensor* output_grad =
-        context.Input<Tensor>(framework::GradVarName("Output"));
-    // For filter, we do not use const pointer b/c we will do reshape,
-    // but we should avoid modifying its value.
-    Tensor filter = *context.Input<Tensor>("Filter");
-    Tensor* input_grad =
-        context.Output<Tensor>(framework::GradVarName("Input"));
-    Tensor* filter_grad =
-        context.Output<Tensor>(framework::GradVarName("Filter"));
-    std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    // Actually, no paddings and groups allowed in conv transpose.
-    std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
-    const int batch_size = static_cast<int>(input->dims()[0]);
-    const int64_t m = input->dims()[1];
-    const int64_t d = input->dims()[2];
-    const int64_t h = input->dims()[3];
-    const int64_t w = input->dims()[4];
-    const int64_t k_d = filter.dims()[2];
-    const int64_t k_h = filter.dims()[3];
-    const int64_t k_w = filter.dims()[4];
-    const int64_t c = output_grad->dims()[1];  // output channels
-    const int64_t o_d = output_grad->dims()[2];
-    const int64_t o_h = output_grad->dims()[3];
-    const int64_t o_w = output_grad->dims()[4];
-    // Only vol2col functor required for bp to get to the right shape
-    math::Vol2ColFunctor<Place, T> vol2col;
-    // use col_shape in the vol2col and col2vol calculation
-    DDim col_shape = {c, k_d, k_h, k_w, d, h, w};
-    // use col_matrix_shape in the gemm calculation
-    DDim col_matrix_shape_f = {c * d * h * w, k_d * k_h * k_w};
-    DDim output_shape = {c, o_d, o_h, o_w};
-    DDim input_matrix_shape = {m, d * h * w};
-    DDim filter_matrix_shape = {m, c * k_d * k_h * k_w};
-    filter.Resize(filter_matrix_shape);
-    if ((!input_grad) && (!filter_grad)) {
-      return;
-    }
-    // convolution transpose grad on input:
-    // vol2col + gemm (similar to conv-forward)
-    // input need to compute gradient
-    if (input_grad || filter_grad) {
-      Tensor col;
-      col.mutable_data<T>(col_shape, context.GetPlace());
-      // col_matrix shares the same piece of data with col,
-      // but will be reshaped into a two-dimensional matrix shape
-      // to call the matrix multiplication interface.
-      Tensor col_matrix;
-      col_matrix.ShareDataWith(col);
-      DDim col_matrix_shape = {c * k_d * k_h * k_w, d * h * w};
-      col_matrix.Resize(col_matrix_shape);
-      Tensor filter_grad_;
-      math::SetConstant<Place, T> set_zero;
-      if (input_grad) {
-        input_grad->mutable_data<T>(context.GetPlace());
-        set_zero(context.device_context(), input_grad, static_cast<T>(0));
-      }
-      if (filter_grad) {  // filter size (m, c * k_d * k_h * k_w)
-        filter_grad->mutable_data<T>(context.GetPlace());
-        set_zero(context.device_context(), filter_grad, static_cast<T>(0));
-        filter_grad_ = *filter_grad;
-        filter_grad_.Resize(filter_matrix_shape);
-      }
-      for (int i = 0; i < batch_size; i++) {
-        // batch with size (c, o_d * o_h * o_w)
-        Tensor output_grad_batch =
-            output_grad->Slice(i, i + 1).Resize(output_shape);
-        // vol2col: dy -> col_matrix
-        // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
-        vol2col(context.device_context(), output_grad_batch, col, strides[0],
-                strides[1], strides[2], paddings[0], paddings[1], paddings[2]);
-        if (input_grad) {
-          // batch with size (m, d, h, w)
-          Tensor input_grad_batch =
-              input_grad->Slice(i, i + 1).Resize(input_matrix_shape);
-          // gemm: dx = filter * dy
+          // or
           // (m, c * k_d * k_h * k_w) * (c * k_d * k_h * k_w, d * h * w) -> (m,
           // d, h, w)
           math::matmul<Place, T>(context.device_context(), filter, false,
@@ -424,6 +277,8 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
           // input batch
           Tensor in_batch = input->Slice(i, i + 1).Resize(input_matrix_shape);
           // gemm: d_filter = x * dy^T
+          // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, k_h * k_w)
+          // or
           // (m, d * h * w) * (d * h * w, c * k_d * k_h * k_w) -> (m, c * k_d *
           // k_h * k_w)
           math::matmul<Place, T>(context.device_context(), in_batch, false,
@@ -434,6 +289,5 @@ class GemmConv3DTransposeGradKernel : public framework::OpKernel<T> {
     }
   }
 };
 }  // namespace operators
 }  // namespace paddle
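The other half of the consolidation is the shape bookkeeping: instead of hard-coded dims such as {c, k_h, k_w, h, w} and {c * k_h * k_w, h * w}, the merged kernels build col_shape as {c} followed by the filter's spatial dims and the input's spatial dims, then flatten it to a 2-D GEMM shape at the boundary between the filter part and the image part (filter_shape_vec.size() + 1). A standalone sketch of that arithmetic, assuming NCHW / NCDHW layouts and using plain std::vector in place of Paddle's DDim helpers:

#include <cstdint>
#include <cstdio>
#include <functional>
#include <numeric>
#include <utility>
#include <vector>

// col_shape = {c} + filter spatial dims + input spatial dims.
std::vector<int64_t> ColShape(int64_t c,
                              const std::vector<int64_t>& filter_spatial,
                              const std::vector<int64_t>& input_spatial) {
  std::vector<int64_t> col_shape{c};
  col_shape.insert(col_shape.end(), filter_spatial.begin(), filter_spatial.end());
  col_shape.insert(col_shape.end(), input_spatial.begin(), input_spatial.end());
  return col_shape;
}

// Rough equivalent of flatten_to_2d(col_shape, num_col_dims): everything before
// the split point multiplies into the row count, everything after into the columns.
std::pair<int64_t, int64_t> ColMatrixShape(const std::vector<int64_t>& col_shape,
                                           size_t num_col_dims) {
  int64_t rows = std::accumulate(col_shape.begin(), col_shape.begin() + num_col_dims,
                                 int64_t{1}, std::multiplies<int64_t>());
  int64_t cols = std::accumulate(col_shape.begin() + num_col_dims, col_shape.end(),
                                 int64_t{1}, std::multiplies<int64_t>());
  return {rows, cols};
}

int main() {
  // 2-D case: c = 6, 3x3 filter, 8x8 input -> col_matrix (54, 64).
  auto mat2d = ColMatrixShape(ColShape(6, {3, 3}, {8, 8}), 2 + 1);
  std::printf("2-D col_matrix: (%lld, %lld)\n",
              static_cast<long long>(mat2d.first), static_cast<long long>(mat2d.second));

  // 3-D case: c = 6, 3x3x3 filter, 4x8x8 input -> col_matrix (162, 256).
  auto mat3d = ColMatrixShape(ColShape(6, {3, 3, 3}, {4, 8, 8}), 3 + 1);
  std::printf("3-D col_matrix: (%lld, %lld)\n",
              static_cast<long long>(mat3d.first), static_cast<long long>(mat3d.second));
  return 0;
}

The GEMM then runs on exactly these 2-D views: the filter is reshaped to (m, rows), each input batch to (m, cols), and the resulting col_matrix of shape (rows, cols) is scattered back to the output with col2im or col2vol.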