diff --git a/paddle/operators/block_expand_op.cc b/paddle/operators/block_expand_op.cc index b3fad3c81f444e0654125d2afd96a014715654e7..49c7011fe1fe04226f761ca3b98a585ae193eb30 100644 --- a/paddle/operators/block_expand_op.cc +++ b/paddle/operators/block_expand_op.cc @@ -109,7 +109,18 @@ class BlockExpandGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override { + using namespace framework; + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output of BlockExpandOp op should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + + auto in_dim = ctx->GetInputDim("X"); + + ctx->SetOutputDim(GradVarName("Out"), in_dim); + } }; } // namespace operators @@ -117,7 +128,7 @@ class BlockExpandGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker, - block_expand_grad, ops::BlockExpandOpGrad); + block_expand_grad, ops::BlockExpandGradOp); REGISTER_OP_CPU_KERNEL( block_expand, ops::BlockExpandKernel); REGISTER_OP_CPU_KERNEL( diff --git a/paddle/operators/block_expand_op.cu b/paddle/operators/block_expand_op.cu index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..492ac0c9b2ed2e1092e3eda4504576d9083f5dcb 100644 --- a/paddle/operators/block_expand_op.cu +++ b/paddle/operators/block_expand_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/block_expand_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL( + block_expand, ops::BlockExpandKernel); +REGISTER_OP_GPU_KERNEL( + block_expand_grad, + ops::BlockExpandGradKernel); diff --git a/paddle/operators/block_expand_op.h b/paddle/operators/block_expand_op.h index bd6b307852031a463c46f8c2c87eb615ef37616f..b272582883b5c891a47367d578c79c0d45fc334b 100644 --- a/paddle/operators/block_expand_op.h +++ b/paddle/operators/block_expand_op.h @@ -69,12 +69,12 @@ class BlockExpandKernel : public framework::OpKernel { stride_width, padding_height, padding_width, outputHeight, outputWidth); for (int i = 0; i < N; i++) { - Tensor src = in->Slice(i, i + 1).Resize(C, img_height, img_width); - Tensor dst = out->Slice(i, i + 1).Resize(outputHeight, outputWidth, C, - block_height, block_width); - math::Im2ColFunctor( - ctx, src, dst, stride_height, stride_width, padding_height, - padding_width); + Tensor src = in->Slice(i, i + 1).Resize({C, img_height, img_width}); + Tensor dst = out->Slice(i, i + 1).Resize( + {outputHeight, outputWidth, C, block_height, block_width}); + math::Im2ColFunctor f; + f(ctx.device_context(), src, dst, stride_height, stride_width, + padding_height, padding_width); } } }; @@ -84,6 +84,40 @@ class BlockExpandGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { using namespace framework; + auto* in = ctx.Input("X"); + auto* out = ctx.Input("Out"); + auto* out_grad = ctx.Output(GradVarName("Out")); + out_grad->mutable_data(ctx.GetPlace()); + + auto in_dim = in->dims(); + int N = in_dim[0]; + int C = in_dim[1]; + int img_height = in_dim[2]; + int img_width = in_dim[3]; + + int block_height = ctx.Attr("blockHeight"); + int block_width = ctx.Attr("blockWidth"); + int stride_height = ctx.Attr("strideHeight"); + int stride_width = ctx.Attr("strideWidth"); + int padding_height = ctx.Attr("paddingHeight"); + int padding_width = ctx.Attr("paddingWidth"); + + int outputHeight = 0; + int outputWidth = 0; + + get_blockexpand_output_shape( + img_height, img_width, block_height, block_width, stride_height, + stride_width, padding_height, padding_width, outputHeight, outputWidth); + + for (int i = 0; i < N; i++) { + Tensor dst = + out_grad->Slice(i, i + 1).Resize({C, img_height, img_width}); + Tensor src = out->Slice(i, i + 1).Resize( + {outputHeight, outputWidth, C, block_height, block_width}); + math::Im2ColFunctor f; + f(ctx.device_context(), src, dst, stride_height, stride_width, + padding_height, padding_width); + } } };