提交 fe45f211 编写于 作者: W wanghaoshuang

1. Rename 'block_expand' to im2sequence

2. Refine code and doc
上级 09adb769
...@@ -12,21 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,21 +12,21 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/block_expand_op.h" #include "paddle/operators/im2sequence_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class BlockExpandOp : public framework::OperatorWithKernel { class Im2SequenceOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input of BlockExpandOp should not be null."); "Input(X) of Im2SequenceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output of BlockExpandOp op should not be null."); "Output(Out) of Im2SequenceOp op should not be null.");
auto in_dim = ctx->GetInputDim("X"); auto in_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(in_dim.size(), 4, PADDLE_ENFORCE_EQ(in_dim.size(), 4,
...@@ -55,9 +55,9 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -55,9 +55,9 @@ class BlockExpandOp : public framework::OperatorWithKernel {
} }
}; };
class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
BlockExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker) Im2SequenceOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", AddInput("X",
"(Tensor)The input tensor has NCHW format." "(Tensor)The input tensor has NCHW format."
...@@ -65,7 +65,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -65,7 +65,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
"C: channels" "C: channels"
"H: height" "H: height"
"W: width"); "W: width");
AddOutput("Out", "(LodTensor)The output data of block_expand op,"); AddOutput("Out", "(LodTensor)The output data of im2sequence op,");
AddAttr<int>("block_height", "(int)height of block."); AddAttr<int>("block_height", "(int)height of block.");
AddAttr<int>("block_width", "(int)width of block."); AddAttr<int>("block_width", "(int)width of block.");
AddAttr<int>("stride_height", "(int)height of stride."); AddAttr<int>("stride_height", "(int)height of stride.");
...@@ -73,7 +73,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -73,7 +73,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("padding_height", "(int)height of padding."); AddAttr<int>("padding_height", "(int)height of padding.");
AddAttr<int>("padding_width", "(int)width of padding."); AddAttr<int>("padding_width", "(int)width of padding.");
AddComment(R"DOC( AddComment(R"DOC(
Expand feature map to minibatch matrix. Convert feature map to minibatch matrix.
- matirx height is: output_height * output_width - matirx height is: output_height * output_width
- matrix width is: block_height * block_width * channels - matrix width is: block_height * block_width * channels
...@@ -133,7 +133,7 @@ output.lod = [[0, 4, 8]] ...@@ -133,7 +133,7 @@ output.lod = [[0, 4, 8]]
} }
}; };
class BlockExpandGradOp : public framework::OperatorWithKernel { class Im2SequenceGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -150,11 +150,11 @@ class BlockExpandGradOp : public framework::OperatorWithKernel { ...@@ -150,11 +150,11 @@ class BlockExpandGradOp : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(block_expand, ops::BlockExpandOp, ops::BlockExpandOpMaker, REGISTER_OP(im2sequence, ops::Im2SequenceOp, ops::Im2SequenceOpMaker,
block_expand_grad, ops::BlockExpandGradOp); im2sequence_grad, ops::Im2SequenceGradOp);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
block_expand, im2sequence,
ops::BlockExpandKernel<paddle::platform::CPUDeviceContext, float>); ops::Im2SequenceKernel<paddle::platform::CPUDeviceContext, float>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
block_expand_grad, im2sequence_grad,
ops::BlockExpandGradKernel<paddle::platform::CPUDeviceContext, float>); ops::Im2SequenceGradKernel<paddle::platform::CPUDeviceContext, float>);
...@@ -13,13 +13,13 @@ ...@@ -13,13 +13,13 @@
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/block_expand_op.h" #include "paddle/operators/im2sequence_op.h"
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
block_expand, im2sequence,
ops::BlockExpandKernel<paddle::platform::CUDADeviceContext, float>); ops::Im2SequenceKernel<paddle::platform::CUDADeviceContext, float>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
block_expand_grad, im2sequence_grad,
ops::BlockExpandGradKernel<paddle::platform::CUDADeviceContext, float>); ops::Im2SequenceGradKernel<paddle::platform::CUDADeviceContext, float>);
...@@ -14,11 +14,11 @@ ...@@ -14,11 +14,11 @@
#pragma once #pragma once
#include "paddle/operators/math/math_function.h" #include "paddle/framework/data_layout.h"
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/im2col.h" #include "paddle/operators/math/im2col.h"
#include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,13 +32,16 @@ inline int get_output_size(int img_size, int block_size, int stride, ...@@ -32,13 +32,16 @@ inline int get_output_size(int img_size, int block_size, int stride,
} }
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class BlockExpandKernel : public framework::OpKernel<T> { class Im2SequenceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* in = ctx.Input<Tensor>("X"); const Tensor* in = ctx.Input<Tensor>("X");
LoDTensor* out = ctx.Output<LoDTensor>("Out"); LoDTensor* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
// TODO(wanghaoshuang): Add layout checker after 'set_layout'
// being available for python API
// PADDLE_ENFORCE_EQ(in->layout(), framework::DataLayout::kNCHW,
// "Input(X) layout must be NCHW");
auto in_dim = in->dims(); auto in_dim = in->dims();
int batch_size = in_dim[0]; int batch_size = in_dim[0];
int img_channels = in_dim[1]; int img_channels = in_dim[1];
...@@ -80,8 +83,9 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -80,8 +83,9 @@ class BlockExpandKernel : public framework::OpKernel<T> {
// set lod information // set lod information
// TODO(wanghaoshuang): Move this to InferShape // TODO(wanghaoshuang): Move this to InferShape
framework::LoD lod(1); framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
for (int i = 0, offset = 0; i < batch_size + 1; ++i) { for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
lod[0].push_back(offset); lod[0][i] = offset;
offset += output_height * output_width; offset += output_height * output_width;
} }
out->set_lod(lod); out->set_lod(lod);
...@@ -89,7 +93,7 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -89,7 +93,7 @@ class BlockExpandKernel : public framework::OpKernel<T> {
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class BlockExpandGradKernel : public framework::OpKernel<T> { class Im2SequenceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<Tensor>("X"); auto* in = ctx.Input<Tensor>("X");
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
#
#Licensed under the Apache License, Version 2.0 (the "License");
#you may not use this file except in compliance with the License.
#You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
#Unless required by applicable law or agreed to in writing, software
#distributed under the License is distributed on an "AS IS" BASIS,
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.
import unittest import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
def get_output_shape(attrs, x): def get_output_shape(attrs, in_shape):
img_height = x.shape[2] img_height = in_shape[2]
img_width = x.shape[3] img_width = in_shape[3]
padding_height = attrs['padding_height'] padding_height = attrs['padding_height']
padding_width = attrs['padding_width'] padding_width = attrs['padding_width']
...@@ -73,8 +86,8 @@ def im2col(attrs, im, col): ...@@ -73,8 +86,8 @@ def im2col(attrs, im, col):
im_row_offset][im_col_offset] im_row_offset][im_col_offset]
def block_expand(inputs, attrs): def Im2Sequence(inputs, attrs):
output_height, output_width = get_output_shape(attrs, inputs) output_height, output_width = get_output_shape(attrs, inputs.shape)
img_channels = inputs.shape[1] img_channels = inputs.shape[1]
batch_size = inputs.shape[0] batch_size = inputs.shape[0]
out = np.zeros([ out = np.zeros([
...@@ -109,13 +122,12 @@ class TestBlockExpandOp(OpTest): ...@@ -109,13 +122,12 @@ class TestBlockExpandOp(OpTest):
def setUp(self): def setUp(self):
self.config() self.config()
self.op_type = "block_expand" self.op_type = "im2sequence"
#x = np.random.uniform(0.1, 1, x = np.random.uniform(0.1, 1, [
x = np.random.randint(0, 10, [
self.batch_size, self.img_channels, self.img_height, self.img_width self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32") ]).astype("float32")
out = block_expand(x, self.attrs) out = Im2Sequence(x, self.attrs)
self.inputs = {'X': x} self.inputs = {'X': x}
self.outputs = {'Out': out} self.outputs = {'Out': out}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册