提交 d97a732f 编写于 作者: Z zchen0211

deconv

上级 e59ca752
...@@ -30,7 +30,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -30,7 +30,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
auto filter_dims = ctx->GetInputDim("Filter"); auto filter_dims = ctx->GetInputDim("Filter");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides"); std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings"); std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups");
for (int i = 0; i < paddings.size(); ++i) { for (int i = 0; i < paddings.size(); ++i) {
PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op."); PADDLE_ENFORCE_EQ(paddings[i], 0, "No Padding allowed in deconv op.");
...@@ -41,9 +40,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -41,9 +40,6 @@ void Deconv2DOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0], PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[0],
"input and kernel input dimension should be equal."); "input and kernel input dimension should be equal.");
PADDLE_ENFORCE_EQ(groups, 1,
"The number of groups should be 1 in case of deconv op.");
auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2]; auto output_height = (in_dims[2] - 1) * strides[0] + filter_dims[2];
auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3]; auto output_width = (in_dims[3] - 1) * strides[1] + filter_dims[3];
ctx->SetOutputDim("Output", ctx->SetOutputDim("Output",
......
...@@ -83,7 +83,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> { ...@@ -83,7 +83,7 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
DDim col_shape = {C, K_H, K_W, H, W}; DDim col_shape = {C, K_H, K_W, H, W};
// use col_matrix_shape in the gemm calculation // use col_matrix_shape in the gemm calculation
DDim col_matrix_shape = {M * K_H * K_W, H * W}; DDim col_matrix_shape = {C * K_H * K_W, H * W};
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
...@@ -108,11 +108,11 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> { ...@@ -108,11 +108,11 @@ class GemmDeconv2DKernel : public framework::OpKernel<T> {
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
// batch with size (M, H * W) // batch with size (M, H * W)
Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape); Tensor input_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
// filter size: (M, C * K_H * K_W)
// output size: (C, O_H, O_W) // output size: (C, O_H, O_W)
Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape); Tensor output_batch = output->Slice<T>(i, i + 1).Resize(output_shape);
// filter size: (Co, Ci * Hf * Wf)
// col_matrix = filter * input_batch // col_matrix = filter * input_batch
// of shape (C * K_H * K_W, H * W) // of shape (C * K_H * K_W, H * W)
math::matmul<Place, T>(context.device_context(), filter, true, math::matmul<Place, T>(context.device_context(), filter, true,
...@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -132,8 +132,8 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
const Tensor* output_grad = const Tensor* output_grad =
context.Input<Tensor>(framework::GradVarName("Output")); context.Input<Tensor>(framework::GradVarName("Output"));
// For filter, we do not use const pointer // For filter, we do not use const pointer b/c we will do reshape
// but we should avoid // but we should avoid modifying its value
Tensor filter = *context.Input<Tensor>("Filter"); Tensor filter = *context.Input<Tensor>("Filter");
Tensor* input_grad = Tensor* input_grad =
...@@ -157,7 +157,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -157,7 +157,7 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
int O_H = output_grad->dims()[2]; int O_H = output_grad->dims()[2];
int O_W = output_grad->dims()[3]; int O_W = output_grad->dims()[3];
// Two functors required to get to the right shape // Only im2col functor required for bp to get to the right shape
paddle::operators::math::Im2ColFunctor< paddle::operators::math::Im2ColFunctor<
paddle::operators::math::ColFormat::kCFO, Place, T> paddle::operators::math::ColFormat::kCFO, Place, T>
im2col; im2col;
...@@ -166,15 +166,13 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -166,15 +166,13 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
DDim col_shape = {C, K_H, K_W, H, W}; DDim col_shape = {C, K_H, K_W, H, W};
// use col_matrix_shape in the gemm calculation // use col_matrix_shape in the gemm calculation
DDim col_matrix_shape = {C * K_H * K_W, H * W}; DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
Tensor col; Tensor col;
col.mutable_data<T>(col_shape, context.GetPlace()); col.mutable_data<T>(col_shape, context.GetPlace());
// col_matrix shares the same piece of data with col, // col_matrix shares the same piece of data with col,
// but will be reshaped into a two-dimensional matrix shape // but will be reshaped into a two-dimensional matrix shape
// to call the matrix multiplication interface. // to call the matrix multiplication interface.
Tensor col_matrix = col;
col_matrix.Resize(col_matrix_shape);
DDim output_shape = {C, O_H, O_W}; DDim output_shape = {C, O_H, O_W};
DDim input_matrix_shape = {M, H * W}; DDim input_matrix_shape = {M, H * W};
...@@ -186,6 +184,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -186,6 +184,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// im2col + gemm (similar to conv-forward) // im2col + gemm (similar to conv-forward)
// input need to compute gradient // input need to compute gradient
if (input_grad) { if (input_grad) {
Tensor col_matrix = col;
DDim col_matrix_shape = {C * K_H * K_W, H * W};
col_matrix.Resize(col_matrix_shape);
input_grad->mutable_data<T>(context.GetPlace()); input_grad->mutable_data<T>(context.GetPlace());
auto t = framework::EigenVector<T>::Flatten(*input_grad); auto t = framework::EigenVector<T>::Flatten(*input_grad);
t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0)); t.device(context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(0));
...@@ -194,14 +196,18 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -194,14 +196,18 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// batch with size (C, O_H * O_W) // batch with size (C, O_H * O_W)
Tensor output_grad_batch = Tensor output_grad_batch =
output_grad->Slice<T>(i, i + 1).Resize(output_shape); output_grad->Slice<T>(i, i + 1).Resize(output_shape);
// filter of size (M, C * K_H * K_W)
// batch with size (M, H, W) // batch with size (M, H, W)
Tensor input_grad_batch = Tensor input_grad_batch =
input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape); input_grad->Slice<T>(i, i + 1).Resize(input_matrix_shape);
// im2col: (C * K_H * K_W, H * W) // im2col: dy from (C, O_H, O_W) -> (C * K_H * K_W, H * W)
im2col(context.device_context(), output_grad_batch, col_matrix, im2col(context.device_context(), output_grad_batch, col_matrix,
strides[0], strides[1], paddings[0], paddings[1]); strides[0], strides[1], paddings[0], paddings[1]);
// gemm: dx = filter * dy // gemm: dx = filter * dy
// (M, C * K_H * K_W) * (C * K_H * K_W, H * W) -> (M, C, H)
math::matmul<Place, T>(context.device_context(), filter, false, math::matmul<Place, T>(context.device_context(), filter, false,
col_matrix, false, T(1.0), &input_grad_batch, col_matrix, false, T(1.0), &input_grad_batch,
T(0.0)); T(0.0));
...@@ -210,6 +216,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -210,6 +216,10 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// filter gradient required // filter gradient required
if (filter_grad) { if (filter_grad) {
Tensor col_matrix_f = col;
DDim col_matrix_shape_f = {C * H * W, K_H * K_W};
col_matrix_f.Resize(col_matrix_shape_f);
filter_grad->mutable_data<T>(context.GetPlace()); filter_grad->mutable_data<T>(context.GetPlace());
Tensor filter_grad_ = *filter_grad; Tensor filter_grad_ = *filter_grad;
filter_grad_.Resize(filter_matrix_shape); filter_grad_.Resize(filter_matrix_shape);
...@@ -223,10 +233,12 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> { ...@@ -223,10 +233,12 @@ class GemmDeconvGrad2DKernel : public framework::OpKernel<T> {
// input batch // input batch
Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape); Tensor in_batch = input->Slice<T>(i, i + 1).Resize(input_matrix_shape);
// im2col: (C * K_H * K_W, H * W) // im2col: (C * H * W, K_H * K_W)
im2col(context.device_context(), output_grad_batch, col_matrix, im2col(context.device_context(), output_grad_batch, col_matrix_f,
strides[0], strides[1], paddings[0], paddings[1]); strides[0], strides[1], paddings[0], paddings[1]);
// gemm: d_filter = x * y_grad^T // gemm: d_filter = x * y_grad^T
// (M, C * H * W) * (K_H * K_W, C * H * W) -> (M, C, H)
math::matmul<Place, T>(context.device_context(), in_batch, false, math::matmul<Place, T>(context.device_context(), in_batch, false,
col_matrix, true, T(1.0), &filter_grad_, T(1.0)); col_matrix, true, T(1.0), &filter_grad_, T(1.0));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册