Commit d04c8538 authored by Y yangyaming

Refine the .cc and .h files and make the unit tests more readable.

Parent 0d9ba3da
@@ -25,13 +25,15 @@ class ExpandOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized.");
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null.");
     std::vector<int> expand_times =
-        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+        ctx->Attrs().Get<std::vector<int>>("expand_times");
     auto x_dims = ctx->GetInputDim("X");
     PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims.size()), expand_times.size(),
-                      "The number of Attr(expandTimes)'s value must be equal "
+                      "The number of Attr(expand_times)'s value must be equal "
                       "to the rank of Input(X).");
     PADDLE_ENFORCE_LE(x_dims.size(), 6,
                       "The rank of Input(X) must not be greater than 6.");
@@ -39,13 +41,15 @@ class ExpandOp : public framework::OperatorWithKernel {
     std::vector<int64_t> out_shape(x_dims.size());
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_GE(expand_times[i], 1,
-                        "Each value of Attr(expandTimes) should not be "
+                        "Each value of Attr(expand_times) should not be "
                         "less than 1.");
       out_shape[i] = x_dims[i] * expand_times[i];
     }
     ctx->SetOutputDim("Out", framework::make_ddim(out_shape));
-    ctx->ShareLoD("X", "Out");
+    if (out_shape[0] == x_dims[0]) {
+      ctx->ShareLoD("X", "Out");
+    }
   }
 };
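
The new guard only shares LoD when the first dimension is left untouched. A minimal Python sketch of that rule (mine, not part of the commit), under the assumption that LoD offsets describe how dimension 0 is partitioned into sequences:

```python
# Hedged sketch: LoD (level-of-detail) offsets index dimension 0 of a tensor,
# so they stay valid only if that dimension is not expanded.
x_dims = (12, 14)
expand_times = (1, 2)

out_shape = tuple(d * t for d, t in zip(x_dims, expand_times))
share_lod = out_shape[0] == x_dims[0]  # True here: dim 0 was not expanded
print(out_shape, share_lod)            # (12, 28) True
```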
@@ -61,13 +65,13 @@ class ExpandOpMaker : public framework::OpProtoAndCheckerMaker {
               "The rank of Output(Out) is same as Input(X) except that each "
               "dimension size of Output(Out) is equal to corresponding "
               "dimension size of Input(X) multiplying corresponding value of "
-              "Attr(expandTimes).");
-    AddAttr<std::vector<int>>("expandTimes",
+              "Attr(expand_times).");
+    AddAttr<std::vector<int>>("expand_times",
                               "Expand times number for each dimension.");
     AddComment(R"DOC(
 Expand operator tiles the input by given times number. You should set times
-number for each dimension by providing attribute 'expandTimes'. The rank of X
-should be in [1, 6]. Please notice that size of 'expandTimes' must be same with
+number for each dimension by providing attribute 'expand_times'. The rank of X
+should be in [1, 6]. Please notice that size of 'expand_times' must be same with
 X's rank.
 )DOC");
   }
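
The DOC block above is the whole contract of the operator. As a sanity check, a small numpy sketch of the same rule (mine, mirroring the `np.tile` reference that the unit tests below use):

```python
import numpy as np

# Each dimension i of X is tiled expand_times[i] times, so
# out_dims[i] == x_dims[i] * expand_times[i].
x = np.random.random((2, 3)).astype("float32")
expand_times = [2, 3]

out = np.tile(x, expand_times)
assert out.shape == (2 * 2, 3 * 3)
```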
@@ -82,16 +86,17 @@ class ExpandGradOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
                    "Input(Out@GRAD) should not be null.");
     auto x_dims = ctx->GetInputDim("X");
     std::vector<int> expand_times =
-        ctx->Attrs().Get<std::vector<int>>("expandTimes");
+        ctx->Attrs().Get<std::vector<int>>("expand_times");
     auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
     for (size_t i = 0; i < expand_times.size(); ++i) {
       PADDLE_ENFORCE_EQ(x_dims[i] * expand_times[i], out_dims[i],
                         "Each dimension size of Input(Out@GRAD) should be "
                         "equal to multiplication of corresponding dimension "
-                        "size of Input(X) and Attr(expandTimes) value.");
+                        "size of Input(X) and Attr(expand_times) value.");
     }
     auto x_grad_name = framework::GradVarName("X");
...
@@ -25,14 +25,17 @@
 #include "paddle/framework/op_registry.h"
 #include "paddle/framework/operator.h"
+#define MAX_RANK_SUPPORTED 6
 #define EXPAND_TEMPLATE(z, n, data) \
   case n + 1: {                     \
     Expand<n + 1>(context);         \
     break;                          \
   }
 #define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~)
-#define COND(n) BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, 6), BOOST_PP_MOD(n, 6))
+#define COND(n)                                               \
+  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
+                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
 #define EXPAND_GRAD_CASE(n)                                        \
   case n: {                                                        \
     ExpandBackward<n>(context, reshape_dims_vec, reduce_dims_vec); \
@@ -46,7 +49,6 @@ namespace paddle {
 namespace operators {
 using Tensor = framework::Tensor;
-
 template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
@@ -60,7 +62,7 @@ class ExpandKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto rank = context.Input<Tensor>("X")->dims().size();
     switch (rank) {
-      REP_EXPAND_TEMPLATE(6)
+      REP_EXPAND_TEMPLATE(MAX_RANK_SUPPORTED)
       default:
         PADDLE_ENFORCE(false,
                        "Only support tensor with rank being between 1 and 6.");
@@ -71,7 +73,7 @@ class ExpandKernel : public framework::OpKernel<T> {
   template <int Rank>
   void Expand(const framework::ExecutionContext& context) const {
     auto* in0 = context.Input<Tensor>("X");
-    auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
+    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
     auto* out0 = context.Output<Tensor>("Out");
     Eigen::DSizes<int, Rank> bcast_dims;
     auto x_dims = in0->dims();
@@ -91,8 +93,14 @@ class ExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in0 = context.Input<Tensor>("X");
-    auto& expand_times = context.Attr<std::vector<int>>("expandTimes");
+    auto& expand_times = context.Attr<std::vector<int>>("expand_times");
     auto x_dims = in0->dims();
+    // 1. reshape_dims_vec is the broadcast parameter. For each dimension i,
+    //    if expand_times[i] > 1 and x_dims[i] > 1, dimension i will be split
+    //    into two dimensions [expand_times[i], x_dims[i]].
+    // 2. reduce_dims_vec is the dimension parameter to compute gradients. For
+    //    each dimension expanded, the gradients should be summed to the
+    //    original size.
     std::vector<int> reshape_dims_vec;
     std::vector<int> reduce_dims_vec;
     for (size_t i = 0; i < expand_times.size(); ++i) {
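
The two comments added above describe the entire backward trick. A hedged numpy sketch of it (mine, not part of the commit), for a rank-2 case where both dimensions are expanded:

```python
import numpy as np

x_dims = (2, 3)
expand_times = (2, 3)
dout = np.ones((4, 9), dtype="float32")  # gradient w.r.t. Out

# Split each expanded dimension i into [expand_times[i], x_dims[i]] ...
reshape_dims = (2, 2, 3, 3)  # axes ordered [t0, x0, t1, x1]
# ... then sum over the expand_times axes to fold the tiled copies back.
reduce_dims = (0, 2)
dx = dout.reshape(reshape_dims).sum(axis=reduce_dims)
assert dx.shape == x_dims  # gradient has X's original shape
```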
@@ -110,7 +118,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
       }
     }
-    int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7;
+    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
+               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
     // no need reduce, just copy
     if (reduce_dims_vec.size() == 0) {
       auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
@@ -132,8 +141,8 @@ class ExpandGradKernel : public framework::OpKernel<T> {
   void ExpandBackward(const framework::ExecutionContext& context,
                       const std::vector<int>& reshape_dims_vec,
                       const std::vector<int>& reduce_dims_vec) const {
-    size_t reshape_size = Dims / 6 + 1;
-    size_t reduce_size = Dims % 6 + 1;
+    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
+    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
     PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                       "Inconsistent size between template Dims and "
                       "reshape dimensions.");
@@ -145,11 +154,11 @@ class ExpandGradKernel : public framework::OpKernel<T> {
     auto x = EigenVector<T>::Flatten(*(context.Input<Tensor>("X")));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims / 6 + 1> reshape_dims;
+    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims % 6 + 1> reduce_dims;
+    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
...
@@ -7,7 +7,7 @@ class TestExpandOpRank1(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random(12).astype("float32")}
-        self.attrs = {'expandTimes': [2]}
+        self.attrs = {'expand_times': [2]}
         output = np.tile(self.inputs['X'], 2)
         self.outputs = {'Out': output}
@@ -18,11 +18,11 @@ class TestExpandOpRank1(OpTest):
         self.check_grad(['X'], 'Out')


-class TestExpandOpRank2_1(OpTest):
+class TestExpandOpRank2_Corner(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((12, 14)).astype("float32")}
-        self.attrs = {'expandTimes': [1, 1]}
+        self.attrs = {'expand_times': [1, 1]}
         output = np.tile(self.inputs['X'], (1, 1))
         self.outputs = {'Out': output}
@@ -33,11 +33,11 @@ class TestExpandOpRank2_1(OpTest):
         self.check_grad(['X'], 'Out')


-class TestExpandOpRank2_2(OpTest):
+class TestExpandOpRank2(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((12, 14)).astype("float32")}
-        self.attrs = {'expandTimes': [2, 3]}
+        self.attrs = {'expand_times': [2, 3]}
         output = np.tile(self.inputs['X'], (2, 3))
         self.outputs = {'Out': output}
@@ -48,11 +48,11 @@ class TestExpandOpRank2_2(OpTest):
         self.check_grad(['X'], 'Out')


-class TestExpandOpRank3_1(OpTest):
+class TestExpandOpRank3_Corner(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
-        self.attrs = {'expandTimes': [1, 1, 1]}
+        self.attrs = {'expand_times': [1, 1, 1]}
         output = np.tile(self.inputs['X'], (1, 1, 1))
         self.outputs = {'Out': output}
@@ -63,11 +63,11 @@ class TestExpandOpRank3_1(OpTest):
         self.check_grad(['X'], 'Out')


-class TestExpandOpRank3_2(OpTest):
+class TestExpandOpRank3(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")}
-        self.attrs = {'expandTimes': [2, 1, 4]}
+        self.attrs = {'expand_times': [2, 1, 4]}
         output = np.tile(self.inputs['X'], (2, 1, 4))
         self.outputs = {'Out': output}
@@ -82,7 +82,7 @@ class TestExpandOpRank4(OpTest):
     def setUp(self):
         self.op_type = "expand"
         self.inputs = {'X': np.random.random((2, 4, 5, 7)).astype("float32")}
-        self.attrs = {'expandTimes': [3, 2, 1, 2]}
+        self.attrs = {'expand_times': [3, 2, 1, 2]}
         output = np.tile(self.inputs['X'], (3, 2, 1, 2))
         self.outputs = {'Out': output}
...