From 48556ba3bbf228cbe418c3a0634df9f7c147b211 Mon Sep 17 00:00:00 2001 From: gongweibao Date: Wed, 11 Oct 2017 12:39:53 +0000 Subject: [PATCH] add block_expand_op --- paddle/operators/block_expand_op.cc | 80 ++++++++++++++++++++++++++ paddle/operators/block_expand_op.cu | 0 paddle/operators/block_expand_op.h | 89 +++++++++++++++++++++++++++++ 3 files changed, 169 insertions(+) create mode 100644 paddle/operators/block_expand_op.cc create mode 100644 paddle/operators/block_expand_op.cu create mode 100644 paddle/operators/block_expand_op.h diff --git a/paddle/operators/block_expand_op.cc b/paddle/operators/block_expand_op.cc new file mode 100644 index 00000000000..0b36dc1ae54 --- /dev/null +++ b/paddle/operators/block_expand_op.cc @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/block_expand_op.h" + +namespace paddle { +namespace operators { + +class BlockExpandOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("block"), + "Input(block) of BlockExpandOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("padding"), + "Input(padding) of BlockExpandOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("stride"), + "Input(stride) of BlockExpandOp should not be null."); + // ctx->SetOutputDim("Out", {1}); + } +}; + +class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BlockExpandOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("block", "The input of block_expand op"); + AddOutput("stride", "The output of block_expand op"); + AddComment(R"DOC( +Expand feature map to minibatch matrix. +- matrix width is: blockH_ * blockW_ * channels_ +- matirx height is: outputH_ * outputW_ + +outputH\_ = 1 + (2paddingH\_ + imgSizeH\_ - blockH\_ + strideH\_ - 1) / + strideH\_ \\ +outputW\_ = 1 + (2paddingW\_ + imgSizeW\_ - blockW\_ + strideW\_ - 1) / + strideW\_ + +The expand method is the same with ExpandConvLayer, but saved the transposed +value. After expanding, output_.sequenceStartPositions will store timeline. +The number of time steps are outputH_outputW_ and the dimension of each +time step is blockH_ * blockW_ * channels_. This layer can be used after +convolution neural network, and before recurrent neural network. +)DOC"); + } +}; + +class BlockExpandGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker, + block_expand_grad, ops::BlockExpandOpGrad); +REGISTER_OP_CPU_KERNEL( + block_expand, ops::BlockExpanddKernel); +REGISTER_OP_CPU_KERNEL( + block_expand_grad, + ops::BlockExpandGradKernel); diff --git a/paddle/operators/block_expand_op.cu b/paddle/operators/block_expand_op.cu new file mode 100644 index 00000000000..e69de29bb2d diff --git a/paddle/operators/block_expand_op.h b/paddle/operators/block_expand_op.h new file mode 100644 index 00000000000..54a9c5354f1 --- /dev/null +++ b/paddle/operators/block_expand_op.h @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/operators/math/math_function.h" + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class BlockExpandKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using namespace framework; + const Tensor* input = context.Input("input"); + const Tensor* filter = context.Input("filter"); + const Tensor* stride = context.Input("stride"); + const Tensor* padding = context.Input("padding"); + Tensor* out = context.Output("Out"); + + auto input_dim = input->dims(); + size_t N = input_dim[0]; + size_t C = input_dim[1]; + PADDLE_ENFORCE_GE(N, 1, "Input batchsize must >= 1."); + PADDLE_ENFORCE_EQ(input_dim.size(), 4, "Input format must be NCHW."); + + size_t input_height = input_dim[2]; + size_t input_height = input_dim[3]; + + size_t filter_height = filter[0]; + size_t filter_width = filter[1]; + + size_t output_height = 1 + + (input_height + 2 * padding_height - block_height() + + stride_height - 1) / + stride_height; + + size_t output_width = + 1 + + (input_width + 2 * padding_width - block_width() + stride_width - 1) / + stride_width; + + Tensor col; + if (clo_format = KCFO) { + col.Resize( + {N, C, filter_height, filter_width, output_height, output_width}); + } else { + col.Resize( + {N, output_height, output_width, C, filter_height, filter_width}); + } + + for (size_t i = 0; i < N; i++) { + Im2ColFunctor(ctx, one_img, col, stride[0], + stride[1], padding[0], padding[1]); + } + } +}; + +template +class BlockExpandGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + using Tensor = framework::Tensor; + /* + int x_num_col_dims = ctx.template Attr("x_num_col_dims"); + int y_num_col_dims = ctx.template Attr("y_num_col_dims"); + const Tensor* x = ctx.Input("X"); + const Tensor* y = ctx.Input("Y"); + */ + } +}; + +} // namespace operators +} // namespace paddle -- GitLab