Commit 43aad989 authored by zchen0211

deconv

Parent 5ec55e79
@@ -100,4 +100,5 @@ REGISTER_OP(deconv2d, ops::Deconv2DOp, ops::Deconv2DOpMaker, deconv2d_grad,
 REGISTER_OP_CPU_KERNEL(
     deconv2d, ops::GemmDeconv2DKernel<paddle::platform::CPUPlace, float>);
 REGISTER_OP_CPU_KERNEL(
-    deconv2d_grad, ops::GemmConv2DKernel<paddle::platform::CPUPlace, float>);
+    deconv2d_grad,
+    ops::GemmDeconvGrad2DKernel<paddle::platform::CPUPlace, float>);
@@ -18,6 +18,7 @@
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    deconv2d, ops::GemmConvGrad2DKernel<paddle::platform::GPUPlace, float>);
+    deconv2d, ops::GemmDeconv2DKernel<paddle::platform::GPUPlace, float>);
 REGISTER_OP_GPU_KERNEL(
-    deconv2d_grad, ops::GemmConv2DKernel<paddle::platform::GPUPlace, float>);
+    deconv2d_grad,
+    ops::GemmDeconvGrad2DKernel<paddle::platform::GPUPlace, float>);
@@ -80,10 +80,10 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
         col2im;
     // use col_shape in the im2col and col2im calculation
-    framework::DDim col_shape = {C, K_H, K_W, H, W};
+    DDim col_shape = {C, K_H, K_W, H, W};
     // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {M * K_H * K_W, H * W};
+    DDim col_matrix_shape = {M * K_H * K_W, H * W};
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
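The forward kernel here factors deconvolution into one GEMM followed by col2im per batch item. Below is a minimal 1-D, single-channel sketch of that factorization, my own illustration rather than the Paddle code; it assumes the no-padding output size O_H = (H - 1) * stride + K, and with channels the GEMM result would be the (C * K_H * K_W, H * W) col_matrix named in the kernel's comments.

// 1-D, single-channel sketch of the gemm + col2im factorization
// used by the deconv forward pass (illustration only, not Paddle code).
#include <cstdio>
#include <vector>

int main() {
  const int stride = 2;
  std::vector<double> x = {1.0, -2.0, 0.5};  // input, H = 3
  std::vector<double> w = {0.3, 1.0, -0.7};  // filter, K = 3
  const int H = x.size(), K = w.size();
  const int O_H = (H - 1) * stride + K;      // no-padding output size

  // "gemm": col(K, H) = w (K, 1) * x (1, H); with channels this is the
  // filter^T * input product of shape (C * K_H * K_W, H * W).
  std::vector<double> col(K * H);
  for (int k = 0; k < K; ++k)
    for (int i = 0; i < H; ++i) col[k * H + i] = w[k] * x[i];

  // col2im: scatter-add each column entry into y[i * stride + k].
  std::vector<double> y(O_H, 0.0);
  for (int k = 0; k < K; ++k)
    for (int i = 0; i < H; ++i) y[i * stride + k] += col[k * H + i];

  for (int j = 0; j < O_H; ++j) std::printf("y[%d] = %+.2f\n", j, y[j]);
  return 0;
}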
@@ -124,7 +124,6 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
   }
 };
-/*
 template <typename Place, typename T>
 class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
  public:
@@ -143,8 +142,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
         context.Output<Tensor>(framework::GradVarName("Filter"));
     std::vector<int> strides = context.Attr<std::vector<int>>("strides");
-    // no paddings and groups allowed in deconv
+    // Actually, no paddings and groups allowed in deconv
     std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
     int N = input->dims()[0];
     int M = input->dims()[1];
@@ -154,19 +153,23 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     int K_H = filter.dims()[2];
     int K_W = filter.dims()[3];
-    int C = output->dims()[1];  // output channels
-    int O_H = output->dims()[2];
-    int O_W = output->dims()[3];
+    int C = output_grad->dims()[1];  // output channels
+    int O_H = output_grad->dims()[2];
+    int O_W = output_grad->dims()[3];
+    // Two functors required to get to the right shape
     paddle::operators::math::Col2ImFunctor<
         paddle::operators::math::ColFormat::kCFO, Place, T>
         col2im;
+    paddle::operators::math::Im2ColFunctor<
+        paddle::operators::math::ColFormat::kCFO, Place, T>
+        im2col;
     // use col_shape in the im2col and col2im calculation
-    framework::DDim col_shape = {C, K_H, K_W, H, W};
+    DDim col_shape = {C, K_H, K_W, H, W};
     // use col_matrix_shape in the gemm calculation
-    framework::DDim col_matrix_shape = {M * K_H * K_W, H * W};
+    DDim col_matrix_shape = {C * K_H * K_W, H * W};
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
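The grad kernel drives both gradient paths through im2col on output_grad, producing the (C * K_H * K_W, H * W) col_matrix named above. To make that layout concrete, here is a minimal single-channel im2col sketch with stride and padding as parameters; it illustrates the patch-per-column ordering as I read it, and is not the Paddle Im2ColFunctor itself.

// Single-channel im2col sketch (illustration only, not Paddle code).
#include <cstdio>
#include <vector>

// im: (H, W) row-major; returns col: (K_H * K_W, out_h * out_w).
std::vector<float> im2col(const std::vector<float>& im, int H, int W,
                          int K_H, int K_W, int stride, int pad) {
  int out_h = (H + 2 * pad - K_H) / stride + 1;
  int out_w = (W + 2 * pad - K_W) / stride + 1;
  std::vector<float> col(K_H * K_W * out_h * out_w, 0.f);
  for (int kh = 0; kh < K_H; ++kh)
    for (int kw = 0; kw < K_W; ++kw)
      for (int oh = 0; oh < out_h; ++oh)
        for (int ow = 0; ow < out_w; ++ow) {
          int ih = oh * stride + kh - pad;  // source pixel for this entry
          int iw = ow * stride + kw - pad;
          float v = (ih >= 0 && ih < H && iw >= 0 && iw < W)
                        ? im[ih * W + iw]
                        : 0.f;  // zero-padding outside the image
          col[((kh * K_W + kw) * out_h + oh) * out_w + ow] = v;
        }
  return col;
}

int main() {
  // 3x3 image, 2x2 kernel, stride 1, no padding -> col is (4, 4).
  std::vector<float> im = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  auto col = im2col(im, 3, 3, 2, 2, /*stride=*/1, /*pad=*/0);
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) std::printf("%4.0f", col[r * 4 + c]);
    std::printf("\n");
  }
  return 0;
}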
@@ -179,37 +182,60 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
     DDim output_shape = {C, O_H, O_W};
     DDim input_matrix_shape = {M, H * W};
-    DDim filter_matrix_shape = {M, C* K_H * K_W};
+    DDim filter_matrix_shape = {M, C * K_H * K_W};
     filter.Resize(filter_matrix_shape);
-    // deconvolution: gemm + col2im (similar to conv-backward on input)
-    output->mutable_data<T>(context.GetPlace());
-    auto t = framework::EigenVector<T>::Flatten(*output);
-    t.device(context.GetEigenDevice<Place>()) =
-        t.constant(static_cast<T>(0));
-    for (int i = 0; i < N; i++) {
-      // batch with size (M, H * W)
-      Tensor input_batch =
-          input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
-      // output size: (C, O_H, O_W)
-      Tensor output_batch =
-          output->Slice<T>(i, i + 1).Resize(output_shape);
-      // filter size: (Co, Ci * Hf * Wf)
-      // col_matrix = filter * input_batch
-      // of shape (C * K_H * K_W, H * W)
-      math::matmul<Place, T>(context.device_context(), filter, true,
-                             input_batch, false, T(1.0), &col_matrix,
-                             T(0.0));
-      col2im(context.device_context(), output_batch, col_matrix, strides[0],
-             strides[1], 0, 0);
-    }
+    // deconvolution grad on input:
+    // im2col + gemm (similar to conv-forward)
+    // compute the gradient of input if required
+    if (input_grad) {
+      input_grad->mutable_data<T>(context.GetPlace());
+      auto t = framework::EigenVector<T>::Flatten(*input_grad);
+      t.device(context.GetEigenDevice<Place>()) =
+          t.constant(static_cast<T>(0));
+      for (int i = 0; i < N; i++) {
+        // batch with size (C, O_H * O_W)
+        Tensor output_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_shape);
+        // batch with size (M, H, W)
+        Tensor input_grad_batch =
+            input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+        // im2col: (C * K_H * K_W, H * W)
+        im2col(context.device_context(), output_grad_batch, col_matrix,
+               strides[0], strides[1], paddings[0], paddings[1]);
+        // gemm: dx = filter * dy
+        math::matmul<Place, T>(context.device_context(), filter, false,
+                               col_matrix, false, T(1.0), &input_grad_batch,
+                               T(0.0));
+      }
+    }
+    // filter gradient required
+    if (filter_grad) {
+      filter_grad->mutable_data<T>(context.GetPlace());
+      Tensor filter_grad_ = *filter_grad;
+      filter_grad_.Resize(filter_matrix_shape);
+      auto t = framework::EigenVector<T>::Flatten(filter_grad_);
+      t.device(context.GetEigenDevice<Place>()) =
+          t.constant(static_cast<T>(0));
+      for (int i = 0; i < N; ++i) {
+        // batch with size (C, O_H, O_W)
+        Tensor output_grad_batch =
+            output_grad->Slice<T>(i, i + 1).Resize(output_shape);
+        // input batch
+        Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
+        // im2col: (C * K_H * K_W, H * W)
+        im2col(context.device_context(), output_grad_batch, col_matrix,
+               strides[0], strides[1], paddings[0], paddings[1]);
+        // gemm: d_filter = x * y_grad^T
+        math::matmul<Place, T>(context.device_context(), in_batch, false,
+                               col_matrix, true, T(1.0), &filter_grad_,
+                               T(1.0));
+      }
+    }
   }
 };
-*/
 }  // namespace operators
 }  // namespace paddle
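The two GEMM comments in the new code, "dx = filter * dy" and "d_filter = x * y_grad^T", encode the chain rule for a no-padding transposed convolution: the input gradient is a conv-forward-style correlation of the output gradient with the filter, and the filter gradient contracts the input against the same im2col patches. The toy 1-D program below, a sketch of my own rather than Paddle code, checks the dx contraction against a numeric gradient; the dw contraction dw[k] = sum_i x[i] * dy[i*s + k] follows the same pattern.

// Toy 1-D check: for y = deconv(x, w) with stride S and no padding,
// dx[i] = sum_k w[k] * dy[i*S + k]  (the "dx = filter * dy" gemm).
#include <cstdio>
#include <vector>

static const int S = 2;  // stride

std::vector<double> deconv(const std::vector<double>& x,
                           const std::vector<double>& w) {
  std::vector<double> y((x.size() - 1) * S + w.size(), 0.0);
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t k = 0; k < w.size(); ++k) y[i * S + k] += x[i] * w[k];
  return y;
}

int main() {
  std::vector<double> x = {1.0, -2.0, 0.5}, w = {0.3, 1.0, -0.7};
  std::vector<double> g = {1, 2, -1, 0.5, 3, -2, 1};  // dL/dy, |y| = 7

  for (size_t i = 0; i < x.size(); ++i) {
    // Analytic input gradient via the conv-forward-style formula.
    double dx = 0.0;
    for (size_t k = 0; k < w.size(); ++k) dx += w[k] * g[i * S + k];

    // Numeric check on L = dot(g, deconv(x, w)) by central differences.
    double eps = 1e-6, xi = x[i];
    x[i] = xi + eps;
    auto yp = deconv(x, w);
    x[i] = xi - eps;
    auto ym = deconv(x, w);
    x[i] = xi;
    double num = 0.0;
    for (size_t j = 0; j < yp.size(); ++j) num += g[j] * (yp[j] - ym[j]);
    num /= 2 * eps;
    std::printf("dx[%zu]: analytic %+.6f numeric %+.6f\n", i, dx, num);
  }
  return 0;
}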