(lower_bound));
+ return *this;
+ }
+
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 8aa6728a95bc464ab8884986f0cec6c817d3303b..c762811dfc190b255e0a3389885a081ce8315caf 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,20 +2,20 @@
## Motivation
-In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
-
+In a neural network, the backpropagation algorithm follows the chain rule, so we need to compose the gradient operators/expressions together according to the chain rule. Every forward network needs a backward network to construct the full computation graph; the operator/expression's backward pass will be generated with respect to the forward pass.
+
## Backward Operator Registry
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients.
+A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs, and output gradients and then calculate their input gradients.
| | forward operator | backward operator
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
+For example, we have a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -27,17 +27,17 @@ REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
## Backward Operator Creating
-Given a certain forward operator, we can get its corresponding backward opeartor by calling:
+Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
+```
The function `BuildGradOp` will sequentially execute the following processes (a short sketch follows the list):
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` to map `inputs`, except those that are not necessary for gradient computing.
3. Add the forward inputs' gradient variables into map `outputs`, and add the forward outputs' gradient variables into map `inputs`.
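A minimal, self-contained sketch (plain C++ maps instead of the framework's actual data structures) of what steps 2 and 3 produce for a `mul` operator with inputs `X`, `Y` and output `Out`; the variable names are illustrative only:

```cpp
#include <map>
#include <string>
#include <vector>

int main() {
  using VarList = std::vector<std::string>;
  // Step 2: copy the forward operator's inputs_ and outputs_ into `inputs`.
  std::map<std::string, VarList> inputs = {
      {"X", {"x"}}, {"Y", {"y"}}, {"Out", {"out"}}};
  // Step 3: forward outputs' gradients are added to `inputs`,
  //         forward inputs' gradients become the backward op's `outputs`.
  inputs["Out@GRAD"] = {"out@GRAD"};
  std::map<std::string, VarList> outputs = {{"X@GRAD", {"x@GRAD"}},
                                            {"Y@GRAD", {"y@GRAD"}}};
  return 0;
}
```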
@@ -49,31 +49,31 @@ A backward network is a series of backward operators. The main idea of building
In our design, the network itself is also a kind of operator. So the operators contained in a big network may themselves be smaller networks.
-given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
+Given a forward network, it generates the backward network. We only care about the gradients: `OutputGradients` and `InputGradients`.
1. Op
- when the input forward network is a Op, return its gradient Operator Immediately.
+ when the input forward network is an Op, return its gradient operator immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to forward NetOp.
+ when the input forward network is a NetOp, it needs to call the backward functions of its sub NetOps/Operators recursively. During this process, we need to collect the `OutputGradients` names according to the forward NetOp.
- **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwirte their shared input variable.
+ **shared variable**. As illustrated in the pictures, two operators' `Output` `Gradient` will overwrite their shared input variable.
- 
+ 
- 1. shared variable in two operators.
+ 1. Shared variable in operators.
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator replace the overwirte links.
+ Sharing a variable between operators, or using the same input variable in multiple operators, leads to duplicate gradient variables. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator to replace the overwrite links (a small sketch follows the figure below).
- 
+ 
- 2. replace shared variable gradient with `Add` Operator
+ 2. Replace the shared variable's gradient with an `Add` operator.
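To make the rename-and-add step concrete, here is a minimal, runnable sketch; the `@RENAME@` suffix scheme is illustrative only and may differ from what the framework actually generates:

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Two backward operators initially both write the shared gradient "x@GRAD".
  std::vector<std::string> grad_outputs = {"x@GRAD", "x@GRAD"};

  // Rename the duplicates so that neither overwrites the other ...
  for (size_t i = 0; i < grad_outputs.size(); ++i) {
    grad_outputs[i] += "@RENAME@" + std::to_string(i);
  }
  // ... then insert a generic add operator that merges them back.
  std::cout << "add(" << grad_outputs[0] << ", " << grad_outputs[1]
            << ") -> x@GRAD" << std::endl;
  return 0;
}
```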
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de79743bb0390d66b8999f2e8342a51d14a9..fc3d508553c0e966978b28d58127bdbff10d45f1 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list);
}
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+ int rank = src.size();
+ return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+ product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c523948b1d437615aa0e9bfecb5e25569296..ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+// Reshape a tensor to a matrix. The matrix's first dimension(column length)
+// will be the product of tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
} // namespace framework
} // namespace paddle
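For reference, a small usage sketch of the two helpers declared above (the shapes are chosen arbitrarily):

```cpp
// Sketch only; assumes the ddim.h declarations above are visible.
#include "paddle/framework/ddim.h"

void flatten_examples() {
  using namespace paddle::framework;
  DDim d = make_ddim({2, 3, 4, 5});
  // The first `num_col_dims` dimensions collapse into dim 0, the rest into dim 1.
  DDim m = flatten_to_2d(d, /*num_col_dims=*/2);  // -> (6, 20)
  DDim v = flatten_to_1d(d);                      // -> (120)
  (void)m;
  (void)v;
}
```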
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c56e0632414a5bbc754d35bfa9ce6a5..54bbeafcabdeeb1e2c1017c156b3512c83dada3a 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
template
-struct EigenMatrix : public EigenTensor {};
+struct EigenMatrix : public EigenTensor {
+ static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+
+ static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+ int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+};
template
struct EigenVector : public EigenTensor {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b1a202826e10e84c21ac8874df9e378..bc4a2db32cfba66bef2c444e1f822e0d2a57b91e 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
}
}
+TEST(Eigen, MatrixReshape) {
+ Tensor t;
+ float* p = t.mutable_data<float>({2, 3, 6, 4}, platform::CPUPlace());
+ for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+ p[i] = static_cast<float>(i);
+ }
+
+ EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);
+
+ ASSERT_EQ(2 * 3, em.dimension(0));
+ ASSERT_EQ(6 * 4, em.dimension(1));
+
+ for (int i = 0; i < 2 * 3; i++) {
+ for (int j = 0; j < 6 * 4; j++) {
+ ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+ }
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle
index 2b658085d6a55d368c320051ba7f94ec2900f13c..ede3bca30ae17d5af52505fd94dc2f79b23b57e0 100644
Binary files a/paddle/framework/images/duplicate_op2.graffle and b/paddle/framework/images/duplicate_op2.graffle differ
diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png
index c5588015d1450fd8c1bda3580680d884494868bb..4e872dc2caf3b0cbd0d5176f11a14801b538dc86 100644
Binary files a/paddle/framework/images/duplicate_op2.png and b/paddle/framework/images/duplicate_op2.png differ
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 643f875491724bf443bd7727391734377ee6180c..ce938b21437195fed8c1adad4329fd139f3f96ab 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
template
friend struct EigenTensor;
+ template
+ friend struct EigenMatrix;
+
template
friend struct EigenVector;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 94f436294f350e2a39785a09959efb3b17bd00a5..637f04ae0037bd402d855b8bcde8087bfe8328d1 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -148,5 +148,13 @@ inline Tensor& Tensor::Resize(const DDim& dims) {
inline const DDim& Tensor::dims() const { return dims_; }
+template <typename T>
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+ Tensor res;
+ res.ShareDataWith<T>(src);
+ res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+ return res;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5caeebccf710334e854faf785ef0f64063..55302ea47120f420e952b26830c8ea4cbcce6435 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
+
+TEST(Tensor, ReshapeToMatrix) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+ for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+ src_ptr[i] = i;
+ }
+ Tensor res = ReshapeToMatrix<int>(src, 2);
+ ASSERT_EQ(res.dims()[0], 2 * 3);
+ ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206ee3cbc5421238574c7f310011ccaa5..f7a80e23e1bd49549bec57b360587adc6b423794 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
const ImageConfig& conf = config_.inputs(0).image_conf();
imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+ imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+ if (0 == imageD_) imageD_ = conf.img_size_z();
if (imageH_ == 0 && imageW_ == 0) {
imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imageW_ = conf.img_size();
} else {
getOutput().setFrameHeight(imageH_);
getOutput().setFrameWidth(imageW_);
+ getOutput().setFrameDepth(imageD_);
}
- imgPixels_ = imageH_ * imageW_;
+ imgPixels_ = imageH_ * imageW_ * imageD_;
}
} // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d96bbd49481a7ed135be6888688627e..e721d2d267a31cae46407673b8b1281e87055608 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
/// Height or width of input image feature.
/// Both of them are 1 if the input is fully-connected layer.
+ int imageD_;
int imageH_;
int imageW_;
/// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f..49a9540c0b6e36b59ed786287ff5c4569b69a6a5 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
- hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+ hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
- imageH_,
+ imageH_ * imageD_,
imageW_);
}
}
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 92cd61cdd515d5c693df086c9575a5f197c00cee..d7eee6eaf078dab8d48adc4c7ee758a433672ac6 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,10 +24,12 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
+ size_t inD = img_conf.img_size_z();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
+ inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
-
+ int d = inputLayers_[0]->getOutput().getFrameDepth();
+ d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight();
- if (h != 0) inDims_.setDim(2, h);
+ if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index d1f3bc241fa621cb0070125980996e8627e40fd6..0e6be2df9ef5f0fae8ed2b0c65ac6c032fe45ab1 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+ TestConfig config;
+ const int CHANNELS = 10;
+ const int IMG_SIZE = 16;
+ const int IMG_SIZE_Y = 8;
+ const int IMG_SIZE_Z = 8;
+ size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+ config.layerConfig.set_type(type);
+ config.layerConfig.set_size(size);
+ config.layerConfig.set_active_type("sigmoid");
+ config.biasSize = CHANNELS;
+ config.inputDefs.push_back({INPUT_DATA,
+ "layer_0",
+ /* dim= */ size,
+ /* paraSize= */ CHANNELS});
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+ config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+
+ LayerInputConfig* input = config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(CHANNELS);
+ img_conf->set_img_size(IMG_SIZE);
+ img_conf->set_img_size_y(IMG_SIZE_Y);
+ img_conf->set_img_size_z(IMG_SIZE_Z);
+
+ testLayerGrad(config,
+ "batch_norm",
+ 64,
+ /* trans= */ trans,
+ useGpu,
+ /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+ testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+ testBatchNorm3DLayer("batch_norm", false, true);
+ if (hl_get_cudnn_lib_version() >= int(4000)) {
+ testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+ }
+#endif
+}
+
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc
index be956bf3b320d6beacdb0d2ca742c3e854194b19..7d9d4fa519d1c690feacbadc5175aeab49082282 100644
--- a/paddle/operators/identity_op.cc
+++ b/paddle/operators/identity_op.cc
@@ -18,17 +18,20 @@
namespace paddle {
namespace operators {
-// identity is a alias of scale op. This is also a example for creating a alias
-// operator.
+// The identity operator is an alias of the scale operator. This is also an
+// example for creating an alias for an existing operator.
template
class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IdentityOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
- AddInput("X", "input tensor of identity op");
- AddOutput("Out", "output tensor of identity op");
- AddComment("identity operator. Just a alias of scale op which scale = 1.0");
+ AddInput("X", "The input tensor of identity operator.");
+ AddOutput("Out", "The output tensor of identity operator.");
+ AddComment(R"DOC(
+The identity operator is an alias of the scale operator
+with the attribute scale fixed to 1.0.
+)DOC");
}
};
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index ee5fb98acdfd109369b8953ee1a19599d914e290..186a33edcec88bd5e51091a524a778eeb27ad526 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -71,8 +71,12 @@ void testIm2col() {
context =
new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
} else {
+#ifndef PADDLE_ONLY_CPU
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+#else
+ PADDLE_THROW("no GPU support");
+#endif // PADDLE_ONLY_CPU
}
im2col(input, output_cfo, stride, stride, padding, padding, context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2e9b7a965ff9f99e787bb8315010823..710a56a0e8e2d17162d7d000df226f1537104eb9 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input<Tensor>("X")->dims();
- auto dim1 = ctx.Input<Tensor>("Y")->dims();
- PADDLE_ENFORCE_EQ(dim0.size(), 2,
- "input X(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("X"));
- PADDLE_ENFORCE_EQ(dim1.size(), 2,
- "input Y(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("Y"));
+ auto x_dims = ctx.Input<Tensor>("X")->dims();
+ auto y_dims = ctx.Input<Tensor>("Y")->dims();
+ int x_num_col_dims = Attr<int>("x_num_col_dims");
+ int y_num_col_dims = Attr<int>("y_num_col_dims");
+
+ PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+ "The rank of input tensor X(%s) should be larger than "
+ "`mul_op`'s `x_num_col_dims`.",
+ ctx.op().Input("X"));
+ PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+ "The rank of input tensor Y(%s) should be larger than "
+ "`mul_op`'s `y_num_col_dims`.",
+ ctx.op().Input("Y"));
+
+ auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+ auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
PADDLE_ENFORCE_EQ(
- dim0[1], dim1[0],
+ x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
- ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
+ ctx.Output<Tensor>("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
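A brief worked example of the new shape inference, with shapes made up for illustration: `X` of shape (2, 3, 4, 5) with `x_num_col_dims = 2` flattens to a 6 x 20 matrix, `Y` of shape (4, 5, 7) with `y_num_col_dims = 2` flattens to a 20 x 7 matrix, so `Out` is resized to (6, 7):

```cpp
// Sketch mirroring MulOp::InferShape above with concrete shapes.
#include "paddle/framework/ddim.h"

void mul_shape_example() {
  using namespace paddle::framework;
  DDim x_mat = flatten_to_2d(make_ddim({2, 3, 4, 5}), 2);  // (6, 20)
  DDim y_mat = flatten_to_2d(make_ddim({4, 5, 7}), 2);     // (20, 7)
  // x_mat[1] == y_mat[0] (both 20), so the check passes and
  // Out gets shape {x_mat[0], y_mat[1]} = {6, 7}.
  DDim out = make_ddim({6, 7});
  (void)x_mat; (void)y_mat; (void)out;
}
```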
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
+ AddAttr<int>(
+ "x_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
+ in that case, tensors will be reshaped to a matrix. The matrix's first
+ dimension (column length) will be the product of the tensor's first
+ `num_col_dims` dimensions, and the matrix's second dimension (row length)
+ will be the product of the tensor's last `rank - num_col_dims` dimensions.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
+ AddAttr<int>(
+ "y_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
+ in that case, tensors will be reshaped to a matrix. Just like input `X`.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *y_grad = ctx.Output<Tensor>(framework::GradVarName("Y"));
- PADDLE_ENFORCE(x_dims[0] == out_dims[0],
- "Out@GRAD M X N must equal to X dims 0, M ");
- PADDLE_ENFORCE(y_dims[1] == out_dims[1],
- "Out@GRAD M X N must equal to Y dims 1, N ");
+
+ auto x_mat_dims =
+ framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
+ auto y_mat_dims =
+ framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims"));
+
+ PADDLE_ENFORCE_EQ(
+ x_mat_dims[0], out_dims[0],
+ "The first dimension of Out@GRAD must equal to the first dimension of "
+ "the first operand.");
+ PADDLE_ENFORCE_EQ(
+ y_mat_dims[1], out_dims[1],
+ "The second dimension of Out@GRAD must equal to the second "
+ "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3470e39a5ebd0394ba05629553a5075..3c01f868bda8cba488b3403df456d63d6b082fa6 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
+ You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* x = context.Input<Tensor>("X");
- auto* y = context.Input<Tensor>("Y");
- auto* z = context.Output<Tensor>("Out");
+ const Tensor* x = context.Input<Tensor>("X");
+ const Tensor* y = context.Input<Tensor>("Y");
+ Tensor* z = context.Output<Tensor>("Out");
+ const Tensor x_matrix =
+ x->dims().size() > 2
+ ? framework::ReshapeToMatrix<T>(
+ *x, context.template Attr<int>("x_num_col_dims"))
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2
+ ? framework::ReshapeToMatrix<T>(
+ *y, context.template Attr<int>("y_num_col_dims"))
+ : *y;
+
z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
- math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
+ math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, z, 0,
+ device_context);
}
};
@@ -45,23 +57,39 @@ template
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto* x = ctx.Input<Tensor>("X");
- auto* y = ctx.Input<Tensor>("Y");
- auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
+ int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
+ int y_num_col_dims = ctx.template Attr<int>("y_num_col_dims");
+ const Tensor* x = ctx.Input<Tensor>("X");
+ const Tensor* y = ctx.Input<Tensor>("Y");
+ const Tensor x_matrix =
+ x->dims().size() > 2 ? framework::ReshapeToMatrix<T>(*x, x_num_col_dims)
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2 ? framework::ReshapeToMatrix<T>(*y, y_num_col_dims)
+ : *y;
+ const Tensor* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
- auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
- auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
+ Tensor* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
+ Tensor* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
auto* device_context =
const_cast<platform::DeviceContext*>(ctx.device_context_);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
+ Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
+ *dx, x_num_col_dims)
+ : *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
- math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
+ math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
+ device_context);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
+ Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
+ *dy, y_num_col_dims)
+ : *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
- math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
+ math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
+ device_context);
}
}
};
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b404315a9f041e21d79b75fd06307e33f7f9..fa8f0ff1a858143af427b51025279c726f1628e0 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input<Tensor>("X")->dims();
- auto dim1 = ctx.Input<Tensor>("b")->dims();
-
- PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
- PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
- PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
- PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
- ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
+ auto x_dims = ctx.Input<Tensor>("X")->dims();
+ auto b_dims = ctx.Input<Tensor>("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
+ PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+ ctx.Output<Tensor>("Out")->Resize(x_dims);
}
};
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
- auto dims0 = ctx.Input<Tensor>("X")->dims();
- auto dims1 = ctx.Input<Tensor>("b")->dims();
- PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+ auto x_dims = ctx.Input<Tensor>("X")->dims();
+ auto b_dims = ctx.Input<Tensor>("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *db = ctx.Output<Tensor>(framework::GradVarName("b"));
- if (dx) dx->Resize(dims0);
- if (db) db->Resize(dims1);
+ if (dx) dx->Resize(x_dims);
+ if (db) db->Resize(b_dims);
}
};
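A quick illustration of the relaxed shape check, with shapes invented for the example: for `X` of shape (2, 3, 4) and `b` of shape (3, 4), `num_col_dims` is 3 - 2 = 1 and `slice_ddim(x_dims, 1, 3)` yields (3, 4), which matches `b`, so the bias is broadcast over the leading dimension:

```cpp
// Sketch mirroring the checks in RowwiseAddOp::InferShape with fixed shapes.
#include "paddle/framework/ddim.h"

void rowwise_add_shape_example() {
  using namespace paddle::framework;
  DDim x_dims = make_ddim({2, 3, 4});
  DDim b_dims = make_ddim({3, 4});
  int num_col_dims = x_dims.size() - b_dims.size();              // 1
  DDim tail = slice_ddim(x_dims, num_col_dims, x_dims.size());   // (3, 4)
  // tail == b_dims, so the enforce passes and Out keeps X's shape (2, 3, 4).
  (void)tail;
}
```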
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f2947f37b71e81c0fa592b0c66b19c640..35774b940926f77167b8f19597027e74d3477e5b 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
-
- auto input = EigenMatrix<T>::From(*context.Input<Tensor>("X"));
- auto bias = EigenVector<T>::From(*context.Input<Tensor>("b"));
- auto output = EigenMatrix<T>::From(*out);
+ int num_col_dims = context.Input<Tensor>("X")->dims().size() -
+ context.Input<Tensor>("b")->dims().size();
+ auto input =
+ EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_col_dims);
+ auto bias = EigenVector<T>::Flatten(*context.Input<Tensor>("b"));
+ auto output = EigenMatrix<T>::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
@@ -54,12 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
+ int num_col_dims = context.Input<Tensor>("X")->dims().size() -
+ context.Input<Tensor>("b")->dims().size();
- auto out_grad = EigenMatrix<T>::From(*dout);
+ auto out_grad = EigenMatrix<T>::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice<Place>();
+
if (dx) {
dx->mutable_data<T>(context.GetPlace());
- EigenMatrix<T>::From(*dx).device(place) = out_grad;
+ EigenMatrix<T>::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
index 3d82b345829b0a554a204ada91c807e42b71dc58..ea991f683d841b3dc4624a0d8aa3c88367fd3c6d 100644
--- a/paddle/operators/scale_op.cc
+++ b/paddle/operators/scale_op.cc
@@ -44,11 +44,13 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
The equation is: Out = scale*X
)DOC");
- AddAttr("scale", "scale of scale operator.").SetDefault(1.0);
+ AddAttr("scale", "The scaling factor of the scale operator.")
+ .SetDefault(1.0);
}
};
-// Scale Op's gradient is scale op, too.
+// The operator to calculate gradients of a scale operator is just the scale
+// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template
class ScaleGradOp : public NetOp {
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 7d062ad67c048bc6bef68121f86334eb3f1efe92..7166b2f60be8a6088ab3a81686f7bed1b7181d97 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -51,7 +51,7 @@ the other dimensions in the K-dimensional vector input. Then the ratio of the
exponential of the given dimension and the sum of exponential values of all
the other dimensions is the output of the softmax operator.
-For each row `i` and each column `j` in X, we have:
+For each row `i` and each column `j` in input X, we have:
Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
)DOC");
@@ -64,14 +64,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
- "Input(Y@GRAD) should not be null");
- PADDLE_ENFORCE(ctx.Input<Tensor>("Y")->dims() ==
- ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
- "the shape of Input(0) and Input(1) should be the same");
+ "Input(Y@GRAD) should not be null.");
+ PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("Y")->dims(),
+ ctx.Input<Tensor>(framework::GradVarName("Y"))->dims(),
+ "Input(Y) and its gradients should have the same shape.");
+
ctx.Output<Tensor>(framework::GradVarName("X"))
- ->Resize(ctx.Input<Tensor>("Y")->dims());
+ ->Resize(ctx.Input<Tensor>("X")->dims());
}
};
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 4fa6b59540498638c3b7df639ae10a66c0fa1c16..8a3a5ab927c0e2937936fcc973f000d4d95c3dbc 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -28,12 +28,12 @@ template
class SoftmaxKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto input = context.Input<Tensor>("X");
- auto output = context.Output<Tensor>("Y");
- output->mutable_data<T>(context.GetPlace());
+ auto X = context.Input<Tensor>("X");
+ auto Y = context.Output<Tensor>("Y");
+ Y->mutable_data<T>(context.GetPlace());
- auto logits = EigenMatrix<T>::From(*input);
- auto softmax = EigenMatrix<T>::From(*output);
+ auto logits = EigenMatrix<T>::From(*X);
+ auto softmax = EigenMatrix<T>::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..38d2f0a09aec751734864947a2f3cfa20107e22f
--- /dev/null
+++ b/paddle/operators/top_k_op.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/top_k_op.h"
+
+namespace paddle {
+namespace operators {
+
+class TopkOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+ "Input of TopkOP must be initialized.");
+ auto *input = ctx.Input<Tensor>("X");
+ const int k = static_cast<int>(ctx.Attr<int>("k"));
+
+ PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
+ PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
+ PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+ "input must have >= k columns");
+
+ framework::DDim dims = input->dims();
+ dims[dims.size() - 1] = k;
+ ctx.Output<Tensor>("Out")->Resize(dims);
+ ctx.Output<Tensor>("Indices")->Resize(dims);
+ }
+};
+
+class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "The input of Topk op");
+ AddOutput("Out", "The output tensor of Topk op");
+ AddOutput("Indices", "The indices of Topk elements of input");
+ AddComment(
+ R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
+
+ For matrices, computes the top k entries in each row. )DOC");
+ AddAttr<int>("k",
+ "Number of top elements to look for along the last "
+ "dimension (along each row for matrices).")
+ .SetDefault(1);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
+REGISTER_OP_CPU_KERNEL(top_k,
+ ops::TopkKernel<paddle::platform::CPUPlace, float>);
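As a self-contained illustration of the CPU kernel's per-row strategy (this is not the operator code itself), the top-k of one row via `std::partial_sort` over (value, index) pairs looks like this:

```cpp
#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  const std::vector<float> row = {1.f, 3.f, 2.f, 4.f};
  const size_t k = 2;

  std::vector<std::pair<float, size_t>> vec;
  for (size_t j = 0; j < row.size(); ++j) vec.emplace_back(row[j], j);

  // Same comparator as TopkKernel: order the first k entries by value, descending.
  std::partial_sort(vec.begin(), vec.begin() + k, vec.end(),
                    [](const std::pair<float, size_t>& l,
                       const std::pair<float, size_t>& r) {
                      return l.first > r.first;
                    });

  for (size_t j = 0; j < k; ++j)
    std::printf("value %.1f at index %zu\n", vec[j].first, vec[j].second);
  // Prints: value 4.0 at index 3, then value 3.0 at index 1.
  return 0;
}
```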
diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..afe4d149c53819c45e20353bc9d16393f3f61e0f
--- /dev/null
+++ b/paddle/operators/top_k_op.cu
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+struct Pair {
+ __device__ __forceinline__ Pair() {}
+ __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+
+ __device__ __forceinline__ void set(T value, int id) {
+ v = value;
+ this->id = id;
+ }
+
+ __device__ __forceinline__ void operator=(const Pair& in) {
+ v = in.v;
+ id = in.id;
+ }
+
+ __device__ __forceinline__ bool operator<(const T value) const {
+ return (v < value);
+ }
+
+ __device__ __forceinline__ bool operator<(const Pair& in) const {
+ return (v < in.v) || ((v == in.v) && (id > in.id));
+ }
+
+ __device__ __forceinline__ bool operator>(const Pair& in) const {
+ return (v > in.v) || ((v == in.v) && (id < in.id));
+ }
+
+ T v;
+ int id;
+};
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p,
+ int beam_size) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* src,
+ bool& firstStep, bool& is_empty,
+ Pair& max, int dim,
+ const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, src, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, src, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* val,
+ int* col, bool& firstStep,
+ bool& is_empty, Pair& max,
+ int dim, const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, val, col, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, val, col, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid,
+ Pair topk[], T** topVal,
+ int** topIds, int& beam, int& k,
+ const int tid, const int warp) {
+ while (true) {
+ __syncthreads();
+ if (tid < BlockSize / 2) {
+ if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
+ maxid[tid] = tid + BlockSize / 2;
+ } else {
+ maxid[tid] = tid;
+ }
+ }
+ __syncthreads();
+ for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
+ if (tid < stride) {
+ if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
+ maxid[tid] = maxid[tid + stride];
+ }
+ }
+ __syncthreads();
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ **topVal = sh_topk[maxid[0]].v;
+ **topIds = sh_topk[maxid[0]].id;
+ (*topVal)++;
+ (*topIds)++;
+ }
+ if (tid == maxid[0]) beam++;
+ if (--k == 0) break;
+ __syncthreads();
+
+ if (tid == maxid[0]) {
+ if (beam < MaxLength) {
+ sh_topk[tid] = topk[beam];
+ }
+ }
+ if (maxid[0] / 32 == warp) {
+ if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
+ }
+ }
+}
+
+/**
+ * Each block computes one sample.
+ * In a block:
+ * 1. every thread gets the top MaxLength values;
+ * 2. merge into sh_topk, block-reduce and get the max value;
+ * 3. go back to the second step, until one thread's topk values are all used;
+ * 4. go back to the first step, until all k values have been produced.
+ */
+template
+__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
+ const T* src, int lds, int dim, int k) {
+ __shared__ Pair sh_topk[BlockSize];
+ __shared__ int maxid[BlockSize / 2];
+ const int tid = threadIdx.x;
+ const int warp = threadIdx.x / 32;
+ output += blockIdx.x * output_stride;
+ indices += blockIdx.x * k;
+
+ Pair topk[MaxLength];
+ int beam = MaxLength;
+ Pair max;
+ bool is_empty = false;
+ bool firststep = true;
+
+ for (int k = 0; k < MaxLength; k++) {
+ topk[k].set(-INFINITY, -1);
+ }
+ while (k) {
+ ThreadGetTopK(topk, beam, k,
+ src + blockIdx.x * lds, firststep,
+ is_empty, max, dim, tid);
+
+ sh_topk[tid] = topk[0];
+ BlockReduce(sh_topk, maxid, topk, &output,
+ &indices, beam, k, tid, warp);
+ }
+}
+
+template
+class TopkOpCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "It must use GPUPlace.");
+ auto* input = ctx.Input<Tensor>("X");
+ auto* output = ctx.Output<Tensor>("Out");
+ auto* indices = ctx.Output<Tensor>("Indices");
+ size_t k = static_cast<int>(ctx.Attr<int>("k"));
+
+ const T* input_data = input->data<T>();
+
+ T* output_data = output->mutable_data<T>(ctx.GetPlace());
+ // FIXME(typhoonzero): data is always converted to type T?
+ int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
+
+ size_t input_height = input->dims()[0];
+ size_t input_width = input->dims()[1];
+ if (k > input_width) k = input_width;
+
+ // NOTE: pass lds and dim the same as the input width.
+ // NOTE: old matrix implementation of stride is different to eigen.
+ // TODO(typhoonzero): launch kernel on specified stream.
+ // TODO(typhoonzero): refine this kernel.
+ dim3 threads(256, 1);
+ dim3 grid(input_height, 1);
+
+ KeMatrixTopK<<>>(
+ output_data, output->dims()[1], indices_data, input_data, input_width,
+ input_width, int(k));
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+REGISTER_OP_GPU_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel);
diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef66acc1d569282a42be64b7a5e90f3fbdb20690
--- /dev/null
+++ b/paddle/operators/top_k_op.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+using EigenMatrix = framework::EigenMatrix;
+
+template
+class TopkKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ // Get the top k elements of each row of input tensor
+ // FIXME: only deal with matrix(2d tensor).
+ auto* input = ctx.Input<Tensor>("X");
+ auto* output = ctx.Output<Tensor>("Out");
+ auto* indices = ctx.Output<Tensor>("Indices");
+ // k is determined by Attr
+ const size_t k = static_cast<int>(ctx.Attr<int>("k"));
+
+ T* output_data = output->mutable_data<T>(ctx.GetPlace());
+ T* indices_data = indices->mutable_data<T>(ctx.GetPlace());
+
+ auto eg_input = EigenMatrix<T>::From(*input);
+
+ // reshape input to a flattened matrix (like flat_inner_dims)
+ framework::DDim inputdims = input->dims();
+ const size_t row = framework::product(
+ framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+ const size_t col = inputdims[inputdims.size() - 1];
+ Eigen::DSizes<int, 2> flat2dims(row, col);
+ // NOTE: eigen shape doesn't affect paddle tensor.
+ eg_input.reshape(flat2dims);
+
+ for (size_t i = 0; i < row; i++) {
+ std::vector<std::pair<T, size_t>> vec;
+ for (size_t j = 0; j < col; j++) {
+ vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
+ }
+
+ std::partial_sort(
+ vec.begin(), vec.begin() + k, vec.end(),
+ [](const std::pair<T, size_t>& l, const std::pair<T, size_t>& r) {
+ return l.first > r.first;
+ });
+ for (size_t j = 0; j < k; j++) {
+ output_data[i * k + j] = vec[j].first;
+ indices_data[i * k + j] = vec[j].second;
+ }
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index c21ad3470ba97f533b6c42bc2966be04bc6f7976..53985933ed143c7faba6f7e2a704445697c1f58e 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -49,6 +49,7 @@ USE_OP(minus);
USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
+USE_OP(top_k);
USE_OP(squared_l2_distance);
namespace paddle {
diff --git a/paddle/scripts/docker/build.sh b/paddle/scripts/docker/build.sh
index 17986420220fec173bbf3ecff240d4c504f8adbd..e57f793ac42b19037e9ca43a5e4a3ac5447dc34c 100644
--- a/paddle/scripts/docker/build.sh
+++ b/paddle/scripts/docker/build.sh
@@ -37,7 +37,7 @@ Configuring cmake in /paddle/build ...
-DWITH_PYTHON=${WITH_PYTHON:-ON}
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON}
-DCUDNN_ROOT=/usr/
- -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
+ -DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
-DWITH_TESTING=${WITH_TESTING:-ON}
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
========================================
diff --git a/paddle/utils/Util.cpp b/paddle/utils/Util.cpp
index b18b73e06a6c39c3bf9717280bc6323917c80efb..2755fdd9cd1c2509cad996557c6fb24363d42d8a 100644
--- a/paddle/utils/Util.cpp
+++ b/paddle/utils/Util.cpp
@@ -320,6 +320,9 @@ void loadFileList(const std::string& fileListFileName,
}
double getMemoryUsage() {
+#if defined(__ANDROID__)
+ return 0.0;
+#else
FILE* fp = fopen("/proc/meminfo", "r");
CHECK(fp) << "failed to fopen /proc/meminfo";
size_t bufsize = 256 * sizeof(char);
@@ -357,6 +360,7 @@ double getMemoryUsage() {
delete[] buf;
double usedMem = 1.0 - 1.0 * (freeMem + bufMem + cacheMem) / totalMem;
return usedMem;
+#endif
}
SyncThreadPool* getGlobalSyncThreadPool() {
diff --git a/paddle/utils/Util.h b/paddle/utils/Util.h
index 613844669d2495ada7b8f7a841f47b821b7fdeba..22ce2534d3468ded36221810aa61c15b37f13f3d 100644
--- a/paddle/utils/Util.h
+++ b/paddle/utils/Util.h
@@ -33,6 +33,13 @@ limitations under the License. */
#include "Flags.h"
#include "hl_gpu.h"
+#if defined(__ANDROID__) && (__ANDROID_API__ < 21)
+inline int rand_r(unsigned int* seedp) {
+ (void)seedp;
+ return rand();
+}
+#endif
+
/**
* Loop over the elements in a container
* TODO(yuyang18): It's this foreach useful? Why not use C++ 11 foreach,
diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto
index 7d7fc23a4691646dfce4c162a445864c748501d9..ebf0911d6ea0b39d51447859ae2aef485b50b0e6 100644
--- a/proto/ModelConfig.proto
+++ b/proto/ModelConfig.proto
@@ -271,6 +271,7 @@ message ImageConfig {
// The size of input feature map.
required uint32 img_size = 8;
optional uint32 img_size_y = 9;
+ optional uint32 img_size_z = 10 [ default = 1 ];
}
message PriorBoxConfig {
@@ -519,6 +520,7 @@ message LayerConfig {
// for HuberRegressionLoss
optional double delta = 57 [ default = 1.0 ];
+ // for 3D data
optional uint64 depth = 58 [ default = 1 ];
// for switch order layer
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index 14a1b38079cfc6a387f8aeb5d3e362dd838e041f..356e1d8b6fa9173db33a340744afd8d513a83a96 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -1332,6 +1332,12 @@ def parse_image(image, input_layer_name, image_conf):
get_img_size(input_layer_name, image_conf.channels)
+def parse_image3d(image, input_layer_name, image_conf):
+ image_conf.channels = image.channels
+ image_conf.img_size, image_conf.img_size_y, image_conf.img_size_z = \
+ get_img3d_size(input_layer_name, image_conf.channels)
+
+
def parse_norm(norm, input_layer_name, norm_conf):
norm_conf.norm_type = norm.norm_type
config_assert(
@@ -2365,9 +2371,11 @@ class BatchNormLayer(LayerBase):
name,
inputs,
bias=True,
+ img3D=False,
use_global_stats=True,
moving_average_fraction=0.9,
batch_norm_type=None,
+ mean_var_names=None,
**xargs):
if inputs is None:
inputs = []
@@ -2409,24 +2417,69 @@ class BatchNormLayer(LayerBase):
input_layer = self.get_input_layer(0)
image_conf = self.config.inputs[0].image_conf
- parse_image(self.inputs[0].image, input_layer.name, image_conf)
-
- # Only pass the width and height of input to batch_norm layer
- # when either of it is non-zero.
- if input_layer.width != 0 or input_layer.height != 0:
- self.set_cnn_layer(name, image_conf.img_size_y, image_conf.img_size,
- image_conf.channels, False)
+ if img3D:
+ parse_image3d(self.inputs[0].image, input_layer.name, image_conf)
+ # Only pass the width and height of input to batch_norm layer
+ # when either of it is non-zero.
+ if input_layer.width != 0 or input_layer.height != 0:
+ self.set_cnn_layer(
+ input_layer_name=name,
+ depth=image_conf.img_size_z,
+ height=image_conf.img_size_y,
+ width=image_conf.img_size,
+ channels=image_conf.channels,
+ is_print=True)
+ else:
+ self.set_layer_size(input_layer.size)
else:
- self.set_layer_size(input_layer.size)
+ parse_image(self.inputs[0].image, input_layer.name, image_conf)
+ # Only pass the width and height of input to batch_norm layer
+ # when either of it is non-zero.
+ if input_layer.width != 0 or input_layer.height != 0:
+ self.set_cnn_layer(
+ input_layer_name=name,
+ height=image_conf.img_size_y,
+ width=image_conf.img_size,
+ channels=image_conf.channels,
+ is_print=True)
+ else:
+ self.set_layer_size(input_layer.size)
psize = self.calc_parameter_size(image_conf)
dims = [1, psize]
+ if mean_var_names is not None:
+ assert len(mean_var_names) == 2
+ self.inputs[1].parameter_name = mean_var_names[0]
+ self.inputs[2].parameter_name = mean_var_names[1]
+
self.create_input_parameter(0, psize)
self.create_input_parameter(1, psize, dims)
self.create_input_parameter(2, psize, dims)
self.create_bias_parameter(bias, psize)
+ def set_cnn_layer(self,
+ input_layer_name,
+ depth=None,
+ height=None,
+ width=None,
+ channels=None,
+ is_print=True):
+ depthIsNone = False
+ if depth is None:
+ depth = 1
+ depthIsNone = True
+ size = depth * height * width * channels
+ self.set_layer_size(size)
+ self.set_layer_height_width(height, width)
+ self.set_layer_depth(depth)
+ if is_print and depthIsNone:
+ print("output for %s: c = %d, h = %d, w = %d, size = %d" %
+ (input_layer_name, channels, height, width, size))
+ elif is_print:
+ print("output for %s: c = %d, d = %d, h = %d, w = %d, size = %d" %
+ (input_layer_name, channels, depth, height, width, size))
+
def calc_parameter_size(self, image_conf):
return image_conf.channels
@@ -2688,9 +2741,20 @@ class AddToLayer(LayerBase):
super(AddToLayer, self).__init__(
name, 'addto', 0, inputs=inputs, **xargs)
config_assert(len(inputs) > 0, 'inputs cannot be empty for AddToLayer')
- for input_index in xrange(len(self.inputs)):
- input_layer = self.get_input_layer(input_index)
- self.set_layer_size(input_layer.size)
+
+ if len(self.inputs) > 1:
+ for input_index in xrange(len(self.inputs)):
+ assert self.get_input_layer(0).height == self.get_input_layer(
+ input_index).height
+ assert self.get_input_layer(0).width == self.get_input_layer(
+ input_index).width
+ assert self.get_input_layer(0).depth == self.get_input_layer(
+ input_index).depth
+
+ self.set_layer_size(self.get_input_layer(0).size)
+ self.set_layer_height_width(self.get_input_layer(0).height, \
+ self.get_input_layer(0).width)
+ self.set_layer_depth(self.get_input_layer(0).depth)
self.create_bias_parameter(bias, self.config.size)
@@ -3370,11 +3434,20 @@ class ConcatenateLayer(LayerBase):
name, 'concat', 0, inputs=inputs, **xargs)
size = 0
for input_index in xrange(len(self.inputs)):
+ assert self.get_input_layer(0).height == self.get_input_layer(
+ input_index).height
+ assert self.get_input_layer(0).width == self.get_input_layer(
+ input_index).width
+ assert self.get_input_layer(0).depth == self.get_input_layer(
+ input_index).depth
input_layer = self.get_input_layer(input_index)
input = self.inputs[input_index]
if self.config.size == 0:
size += input_layer.size
+ self.set_layer_height_width(self.get_input_layer(0).height, \
+ self.get_input_layer(0).width)
+ self.set_layer_depth(self.get_input_layer(0).depth)
self.set_layer_size(size)
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index cba45bd3afa178ab4dd3a50f0947b144e7466e53..dc68c213da66ac680e6b14266cb5038a5ba73ec2 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -354,6 +354,10 @@ class LayerOutput(object):
def height(self):
return cp.g_layer_map[self.full_name].height
+ @property
+ def depth(self):
+ return cp.g_layer_map[self.full_name].depth
+
def set_input(self, input):
"""
Set the input for a memory layer. Can only be used for memory layer
@@ -943,7 +947,7 @@ def data_layer(name, size, depth=None, height=None, width=None,
if height is not None and width is not None:
num_filters = size / (width * height * depth)
assert num_filters * width * height * depth == size, \
- "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
+ "size=%s width=%s height=%s depth=%s" % (size, width, height, depth)
return LayerOutput(name, LayerType.DATA, size=size, num_filters=num_filters)
@@ -2953,13 +2957,15 @@ def img_cmrnorm_layer(input,
def batch_norm_layer(input,
act=None,
name=None,
+ img3D=False,
num_channels=None,
bias_attr=None,
param_attr=None,
layer_attr=None,
batch_norm_type=None,
moving_average_fraction=0.9,
- use_global_stats=None):
+ use_global_stats=None,
+ mean_var_names=None):
"""
Batch Normalization Layer. The notation of this layer as follow.
@@ -3026,6 +3032,8 @@ def batch_norm_layer(input,
:math:`runningMean = newMean*(1-factor)
+ runningMean*factor`
:type moving_average_fraction: float.
+ :param mean_var_names: [mean name, variance name]
+ :type mean_var_names: string list
:return: LayerOutput object.
:rtype: LayerOutput
"""
@@ -3039,6 +3047,7 @@ def batch_norm_layer(input,
(batch_norm_type == "cudnn_batch_norm")
l = Layer(
name=name,
+ img3D=img3D,
inputs=Input(
input.name, image=Image(channels=num_channels), **param_attr.attr),
active_type=act.name,
@@ -3047,6 +3056,7 @@ def batch_norm_layer(input,
bias=ParamAttr.to_bias(bias_attr),
moving_average_fraction=moving_average_fraction,
use_global_stats=use_global_stats,
+ mean_var_names=mean_var_names,
**ExtraLayerAttribute.to_kwargs(layer_attr))
return LayerOutput(
@@ -6410,7 +6420,7 @@ def gated_unit_layer(input,
@wrap_name_default('switch_order')
def switch_order_layer(input,
name=None,
- reshape=None,
+ reshape_axis=None,
act=None,
layer_attr=None):
"""
@@ -6421,8 +6431,9 @@ def switch_order_layer(input,
The example usage is:
.. code-block:: python
+ reshape_axis = 3
+ switch = switch_order(input=layer, name='switch', reshape_axis=reshape_axis)
-    reshape = {'height':[ 0, 1, 2], 'width':[3]}
-    switch = switch_order(input=layer, name='switch', reshape=reshape)
:param input: The input layer.
:type input: LayerOutput
@@ -6434,6 +6445,11 @@ def switch_order_layer(input,
:rtype: LayerOutput
"""
assert isinstance(input, LayerOutput)
+    assert reshape_axis is not None and 0 < reshape_axis < 4
+ height = [ele for ele in xrange(reshape_axis)]
+ width = [ele for ele in range(reshape_axis, 4)]
+ reshape = {'height': height, 'width': width}
+
l = Layer(
name=name,
inputs=input.name,
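A short sketch of how the new reshape_axis argument maps to the height/width dictionary that switch_order_layer now builds internally (the value 2 is illustrative; the assertion above requires 0 < reshape_axis < 4):

    reshape_axis = 2
    height = [ele for ele in xrange(reshape_axis)]    # axes [0, 1]
    width = [ele for ele in range(reshape_axis, 4)]   # axes [2, 3]
    reshape = {'height': height, 'width': width}
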
diff --git a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
index df872a90ff388f0d96cef44763dbd076bc768ab9..8a204a96f3ef57673cef65306d0bf8e8c3409751 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/file_list.sh
@@ -10,6 +10,6 @@ test_prelu_layer test_row_conv test_detection_output_layer test_multibox_loss_la
test_recursive_topology test_gated_unit_layer test_clip_layer test_row_l2_norm_layer
test_kmax_seq_socre_layer test_sub_nested_seq_select_layer test_scale_shift_layer
test_seq_slice_layer test_cross_entropy_over_beam test_pooling3D_layer
-test_conv3d_layer test_deconv3d_layer)
+test_conv3d_layer test_deconv3d_layer test_BatchNorm3D)
export whole_configs=(test_split_datasource)
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
index 1a577b8d9b1e1915236ba6afcfa97040d70c707a..5ddf6052df021b055390a42c25ce6c0d650e4aee 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_layers.protostr
@@ -62,6 +62,7 @@ layers {
moving_average_fraction: 0.9
height: 227
width: 227
+ depth: 1
}
layers {
name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
index 2818389b16cca75f5030b75fc4de8c89c06c5e02..c0252b945b4c7fd6b4dad8770e3e1dccb88df28a 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/img_trans_layers.protostr
@@ -62,6 +62,7 @@ layers {
moving_average_fraction: 0.9
height: 256
width: 256
+ depth: 1
}
layers {
name: "__crmnorm_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
new file mode 100644
index 0000000000000000000000000000000000000000..832ed24a31dd2bedba9a4fce77d7a088d1796fdb
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_BatchNorm3D.protostr
@@ -0,0 +1,92 @@
+type: "nn"
+layers {
+ name: "data3D"
+ type: "data"
+ size: 360
+ active_type: ""
+ height: 6
+ width: 20
+ depth: 3
+}
+layers {
+ name: "__batch_norm_0__"
+ type: "batch_norm"
+ size: 360
+ active_type: "relu"
+ inputs {
+ input_layer_name: "data3D"
+ input_parameter_name: "___batch_norm_0__.w0"
+ image_conf {
+ channels: 1
+ img_size: 20
+ img_size_y: 6
+ img_size_z: 3
+ }
+ }
+ inputs {
+ input_layer_name: "data3D"
+ input_parameter_name: "___batch_norm_0__.w1"
+ }
+ inputs {
+ input_layer_name: "data3D"
+ input_parameter_name: "___batch_norm_0__.w2"
+ }
+ bias_parameter_name: "___batch_norm_0__.wbias"
+ moving_average_fraction: 0.9
+ height: 6
+ width: 20
+ depth: 3
+}
+parameters {
+ name: "___batch_norm_0__.w0"
+ size: 1
+ initial_mean: 1.0
+ initial_std: 0.0
+ initial_strategy: 0
+ initial_smart: false
+}
+parameters {
+ name: "___batch_norm_0__.w1"
+ size: 1
+ initial_mean: 0.0
+ initial_std: 0.0
+ dims: 1
+ dims: 1
+ initial_strategy: 0
+ initial_smart: false
+ is_static: true
+ is_shared: true
+}
+parameters {
+ name: "___batch_norm_0__.w2"
+ size: 1
+ initial_mean: 0.0
+ initial_std: 0.0
+ dims: 1
+ dims: 1
+ initial_strategy: 0
+ initial_smart: false
+ is_static: true
+ is_shared: true
+}
+parameters {
+ name: "___batch_norm_0__.wbias"
+ size: 1
+ initial_mean: 0.0
+ initial_std: 0.0
+ dims: 1
+ dims: 1
+ initial_strategy: 0
+ initial_smart: false
+}
+input_layer_names: "data3D"
+output_layer_names: "__batch_norm_0__"
+sub_models {
+ name: "root"
+ layer_names: "data3D"
+ layer_names: "__batch_norm_0__"
+ input_layer_names: "data3D"
+ output_layer_names: "__batch_norm_0__"
+ is_recurrent_layer_group: false
+}
+
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
index b110e91498ce7d112987714bd769868179141c54..8a1399efad0ff339e35f69400ac654a4787a6018 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_bi_grumemory.protostr
@@ -74,6 +74,9 @@ layers {
inputs {
input_layer_name: "__bidirectional_gru_0___bw"
}
+ height: 0
+ width: 0
+ depth: 1
}
parameters {
name: "___bidirectional_gru_0___fw_transform.w0"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
index 8133aa9c8d3e7c6843d1b27b70e87d394a1e0e47..046037936a6d85f54095c65f206e468aa69065d7 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/test_recursive_topology.protostr
@@ -16,6 +16,9 @@ layers {
inputs {
input_layer_name: "data"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_1__"
@@ -28,6 +31,9 @@ layers {
inputs {
input_layer_name: "__addto_0__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_2__"
@@ -40,6 +46,9 @@ layers {
inputs {
input_layer_name: "__addto_1__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_3__"
@@ -52,6 +61,9 @@ layers {
inputs {
input_layer_name: "__addto_2__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_4__"
@@ -64,6 +76,9 @@ layers {
inputs {
input_layer_name: "__addto_3__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_5__"
@@ -76,6 +91,9 @@ layers {
inputs {
input_layer_name: "__addto_4__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_6__"
@@ -88,6 +106,9 @@ layers {
inputs {
input_layer_name: "__addto_5__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_7__"
@@ -100,6 +121,9 @@ layers {
inputs {
input_layer_name: "__addto_6__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_8__"
@@ -112,6 +136,9 @@ layers {
inputs {
input_layer_name: "__addto_7__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_9__"
@@ -124,6 +151,9 @@ layers {
inputs {
input_layer_name: "__addto_8__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_10__"
@@ -136,6 +166,9 @@ layers {
inputs {
input_layer_name: "__addto_9__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_11__"
@@ -148,6 +181,9 @@ layers {
inputs {
input_layer_name: "__addto_10__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_12__"
@@ -160,6 +196,9 @@ layers {
inputs {
input_layer_name: "__addto_11__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_13__"
@@ -172,6 +211,9 @@ layers {
inputs {
input_layer_name: "__addto_12__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_14__"
@@ -184,6 +226,9 @@ layers {
inputs {
input_layer_name: "__addto_13__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_15__"
@@ -196,6 +241,9 @@ layers {
inputs {
input_layer_name: "__addto_14__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_16__"
@@ -208,6 +256,9 @@ layers {
inputs {
input_layer_name: "__addto_15__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_17__"
@@ -220,6 +271,9 @@ layers {
inputs {
input_layer_name: "__addto_16__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_18__"
@@ -232,6 +286,9 @@ layers {
inputs {
input_layer_name: "__addto_17__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_19__"
@@ -244,6 +301,9 @@ layers {
inputs {
input_layer_name: "__addto_18__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_20__"
@@ -256,6 +316,9 @@ layers {
inputs {
input_layer_name: "__addto_19__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_21__"
@@ -268,6 +331,9 @@ layers {
inputs {
input_layer_name: "__addto_20__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_22__"
@@ -280,6 +346,9 @@ layers {
inputs {
input_layer_name: "__addto_21__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_23__"
@@ -292,6 +361,9 @@ layers {
inputs {
input_layer_name: "__addto_22__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_24__"
@@ -304,6 +376,9 @@ layers {
inputs {
input_layer_name: "__addto_23__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_25__"
@@ -316,6 +391,9 @@ layers {
inputs {
input_layer_name: "__addto_24__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_26__"
@@ -328,6 +406,9 @@ layers {
inputs {
input_layer_name: "__addto_25__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_27__"
@@ -340,6 +421,9 @@ layers {
inputs {
input_layer_name: "__addto_26__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_28__"
@@ -352,6 +436,9 @@ layers {
inputs {
input_layer_name: "__addto_27__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_29__"
@@ -364,6 +451,9 @@ layers {
inputs {
input_layer_name: "__addto_28__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_30__"
@@ -376,6 +466,9 @@ layers {
inputs {
input_layer_name: "__addto_29__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__addto_31__"
@@ -388,6 +481,9 @@ layers {
inputs {
input_layer_name: "__addto_30__"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__fc_layer_0__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
index d0ad388165007b8f96f059e5b003c52f756383e5..7a2f3eab38808a031c27cf7ab9d6273952e389eb 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
+++ b/python/paddle/trainer_config_helpers/tests/configs/protostr/util_layers.protostr
@@ -22,6 +22,9 @@ layers {
inputs {
input_layer_name: "b"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__concat_0__"
@@ -34,6 +37,9 @@ layers {
inputs {
input_layer_name: "b"
}
+ height: 0
+ width: 0
+ depth: 1
}
layers {
name: "__concat_1__"
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
new file mode 100644
index 0000000000000000000000000000000000000000..a991b22252ba10eed895efd931108c2d8b0e52f1
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_BatchNorm3D.py
@@ -0,0 +1,11 @@
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=1000, learning_rate=1e-4)
+
+#data = data_layer(name='data', size=180, width=30, height=6)
+#batchNorm = batch_norm_layer(data, num_channels=1)
+#outputs(batchNorm)
+
+data3D = data_layer(name='data3D', size=120 * 3, width=20, height=6, depth=3)
+batchNorm3D = batch_norm_layer(data3D, num_channels=1, img3D=True)
+outputs(batchNorm3D)
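A quick check of the size arithmetic this config relies on (assuming num_channels=1, as in the call above): the data layer size must equal channels * depth * height * width, which is the 360 that appears in the new test_BatchNorm3D.protostr:

    channels, depth, height, width = 1, 3, 6, 20
    size = channels * depth * height * width    # 1 * 3 * 6 * 20 == 360
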
diff --git a/python/paddle/v2/framework/op.py b/python/paddle/v2/framework/op.py
index 0349407a851ebb48f69d7daef7a318cf348aad5d..c1585bcffcceb75292853018179066c9f614261e 100644
--- a/python/paddle/v2/framework/op.py
+++ b/python/paddle/v2/framework/op.py
@@ -4,8 +4,8 @@ import paddle.v2.framework.proto.framework_pb2 as framework_pb2
def get_all_op_protos():
"""
- Get all registered op proto from Paddle C++
- :return: list of OpProto
+    Get all registered op protos from the PaddlePaddle C++ side.
+ :return: A list of registered OpProto.
"""
protostrs = core.get_all_op_protos()
ret_values = []
@@ -21,8 +21,8 @@ def is_str(s):
class OpDescCreationMethod(object):
"""
- A Functor object to convert user input(use key word args) to OpDesc based on
- OpProto.
+    Convert the user's input (only keyword arguments are supported) to OpDesc
+ based on the OpProto.
:param op_proto: The OpProto object.
:type op_proto: op_proto_pb2.OpProto
@@ -30,17 +30,18 @@ class OpDescCreationMethod(object):
def __init__(self, op_proto):
if not isinstance(op_proto, framework_pb2.OpProto):
- raise TypeError("Argument should be OpProto")
+ raise TypeError(
+                "The type of op_proto should be OpProto.")
self.__op_proto__ = op_proto
def __call__(self, *args, **kwargs):
"""
- Convert user input to OpDesc. Only key-word args are supported.
- :return: OpDesc based on user input
+ Convert user's input to OpDesc. Only keyword arguments are supported.
+ :return: The OpDesc based on user input.
:rtype: op_desc_pb2.OpDesc
"""
if len(args) != 0:
- raise ValueError("Only keyword arguments is supported by Paddle")
+ raise ValueError("Only keyword arguments are supported.")
op_desc = framework_pb2.OpDesc()
for input_parameter in self.__op_proto__.inputs:
@@ -49,8 +50,9 @@ class OpDescCreationMethod(object):
input_arguments = [input_arguments]
if not input_parameter.duplicable and len(input_arguments) > 1:
- raise ValueError("Input %s only accepts one input, but give %d"
- % (input_parameter.name, len(input_arguments)))
+ raise ValueError(
+ "Input %s expects only one input, but %d are given." %
+ (input_parameter.name, len(input_arguments)))
ipt = op_desc.inputs.add()
ipt.parameter = input_parameter.name
@@ -63,7 +65,7 @@ class OpDescCreationMethod(object):
if not output_parameter.duplicable and len(output_arguments) > 1:
raise ValueError(
- "Output %s only accepts one output, but give %d" %
+ "Output %s expects only one output, but %d are given." %
(output_parameter.name, len(output_arguments)))
out = op_desc.outputs.add()
@@ -100,15 +102,17 @@ class OpDescCreationMethod(object):
pair.first = p[0]
pair.second = p[1]
else:
- raise NotImplementedError("Not support attribute type " +
- str(attr.type))
+ raise NotImplementedError(
+                "Unsupported attribute type: %s." % (
+ str(attr.type)))
return op_desc
@staticmethod
def any_is_true(generator):
"""
- Reduce a bool array to one. If any of them is True, then return True.
+        Reduce a boolean array to a single boolean value. If any element in
+        the array is True, this function returns True, otherwise False.
"""
for flag in generator:
if flag:
@@ -127,7 +131,7 @@ class OpInfo(object):
def create_op_creation_method(op_proto):
"""
- Generate op creation method for an OpProto
+ Generate op creation method for an OpProto.
"""
method = OpDescCreationMethod(op_proto)
@@ -146,20 +150,23 @@ def create_op_creation_method(op_proto):
class OperatorFactory(object):
def __init__(self):
self.op_methods = dict()
+
for op_proto in get_all_op_protos():
method = create_op_creation_method(op_proto)
self.op_methods[method.name] = method
def __call__(self, *args, **kwargs):
- if 'type' in kwargs:
+ if "type" in kwargs:
if len(args) != 0:
- raise ValueError("All Paddle argument should be key-word "
- "argument except type")
- t = kwargs.pop('type')
+ raise ValueError(
+                "Except for the argument \"type\", "
+                "all other arguments should be keyword arguments.")
+ t = kwargs.pop("type")
else:
if len(args) != 1:
- raise ValueError("All Paddle argument should be key-word "
- "argument except type")
+ raise ValueError(
+                "Except for the argument \"type\", "
+                "all other arguments should be keyword arguments.")
t = args[0]
return self.get_op_info(t).method(**kwargs)
@@ -169,7 +176,7 @@ class OperatorFactory(object):
def get_op_info(self, t):
if t not in self.op_methods:
- raise ValueError("operator %s is not registered", t)
+ raise ValueError("The operator: %s is not registered." % t)
return self.op_methods.get(t)
def get_op_input_names(self, type):
@@ -184,7 +191,7 @@ class OperatorFactory(object):
class __RecurrentOp__(object):
__proto__ = None
- type = 'recurrent'
+ type = "recurrent"
def __init__(self):
# cache recurrent_op's proto
@@ -194,8 +201,8 @@ class __RecurrentOp__(object):
self.__proto__ = op_proto
def __call__(self, *args, **kwargs):
- if self.type not in args and 'type' not in kwargs:
- kwargs['type'] = self.type
+ if self.type not in args and "type" not in kwargs:
+ kwargs["type"] = self.type
# create proto
create_method = OpDescCreationMethod(self.__proto__)
proto = create_method(*args, **kwargs)
@@ -203,5 +210,5 @@ class __RecurrentOp__(object):
return core.RecurrentOp.create(proto.SerializeToString())
-Operator = OperatorFactory() # Default global factory
+Operator = OperatorFactory() # The default global factory
RecurrentOp = __RecurrentOp__()
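A brief usage sketch (operator and argument names borrowed from the tests below) of the two call forms OperatorFactory.__call__ accepts after this change: the type is given either as the single positional argument or via the "type" keyword, and everything else must be a keyword argument:

    # both forms build the same OpDesc
    op1 = Operator("add", X="X", Y="Y", Out="Out")
    op2 = Operator(type="add", X="X", Y="Y", Out="Out")
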
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index a9c33ea1631e8358c41a8566de9db4bd00fc9b74..ef910f939be0b9d3cb5e6d49e69a00daa191b1c6 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -17,6 +17,7 @@ py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_scatter_op SRCS test_scatter_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
+py_test(test_top_k_op SRCS test_top_k_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
diff --git a/python/paddle/v2/framework/tests/mnist.py b/python/paddle/v2/framework/tests/mnist.py
index a68f302f9c344bf6d63e8d9b48836d69338c3d0b..f6f8f49b797fb6e5016a5e309f12f192d5096431 100644
--- a/python/paddle/v2/framework/tests/mnist.py
+++ b/python/paddle/v2/framework/tests/mnist.py
@@ -38,9 +38,9 @@ def feed_data(name, data):
assert isinstance(data, numpy.ndarray)
tensor = scope.find_var(name).get_tensor()
tensor.set_dims(data.shape)
- if data.dtype == numpy.dtype('int32'):
+ if data.dtype == numpy.dtype("int32"):
tensor.alloc_int(place)
- elif data.dtype == numpy.dtype('float32'):
+ elif data.dtype == numpy.dtype("float32"):
tensor.alloc_float(place)
else:
raise ValueError("data type not supported")
@@ -74,22 +74,25 @@ def init_param(net, param_name, dims):
# fc_layer
def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
"""
- Add a fc layer to net
+    Add a fully connected layer to the network.
- :param input: input variable name.
+ :param input: The name of input variable.
:type input: str
- :param size: fully connected layer size.
- :param act: activation name
- :param param: parameter attribute, used for initialize parameters.
- :param bias: bias attribute. False will not have a bias.
- :param name: the name of fc layer. If not set, model will generate a
- readable name
- :return: output variable name.
+ :param size: The size of fully connected layer.
+ :param act: The name of activation.
+ :param param: The attribute of learnable parameter which can be used to
+ modify initialization mean and std of the parameter.
+ :param bias: The attribute of bias. If set False, this layer does not have
+ a bias.
+    :param name: The name of this layer. If it is not set explicitly, a name
+ will be generated automatically.
+ :return: The name of the output variable.
"""
+
if name is None:
- name = 'fc_%d' % uniq_id()
+ name = "fc_%d" % uniq_id()
if not isinstance(name, str):
- raise ValueError("name should be string")
+ raise ValueError("The name of a layer should be a string.")
input_dims = scope.find_var(input).get_tensor().get_dims()
@@ -123,7 +126,7 @@ def fc_layer(net, input, size, act="softmax", bias=True, param=None, name=None):
def cross_entropy_layer(net, input, label):
- cost_name = 'cross_entropy_%d' % uniq_id()
+ cost_name = "cross_entropy_%d" % uniq_id()
cross_entropy_op = Operator(
"onehot_cross_entropy", X=input, label=label, Y=cost_name)
net.append_op(cross_entropy_op)
@@ -177,8 +180,8 @@ def error_rate(predict, label):
return error_num / float(len(label))
-images = data_layer(name='pixel', dims=[BATCH_SIZE, 784])
-labels = data_layer(name='label', dims=[BATCH_SIZE])
+images = data_layer(name="pixel", dims=[BATCH_SIZE, 784])
+labels = data_layer(name="label", dims=[BATCH_SIZE])
fc1 = fc_layer(net=forward_net, input=images, size=100, act="sigmoid")
fc2 = fc_layer(net=forward_net, input=fc1, size=100, act="sigmoid")
predict = fc_layer(net=forward_net, input=fc2, size=10, act="softmax")
diff --git a/python/paddle/v2/framework/tests/test_gradient_checker.py b/python/paddle/v2/framework/tests/test_gradient_checker.py
index 857427cdfbb4374957e249f0faa4cfc46ac0e8c7..e8a7f848dffa0529c8cb0d6599286ce0e228d180 100644
--- a/python/paddle/v2/framework/tests/test_gradient_checker.py
+++ b/python/paddle/v2/framework/tests/test_gradient_checker.py
@@ -7,11 +7,11 @@ from gradient_checker import get_numeric_gradient
class GetNumericGradientTest(unittest.TestCase):
def test_add_op(self):
- add_op = Operator('add', X="X", Y="Y", Out="Z")
+ add_op = Operator("add", X="X", Y="Y", Out="Z")
x = numpy.random.random((10, 1)).astype("float32")
y = numpy.random.random((10, 1)).astype("float32")
- arr = get_numeric_gradient(add_op, {'X': x, "Y": y}, 'Z', 'X')
+ arr = get_numeric_gradient(add_op, {"X": x, "Y": y}, "Z", "X")
self.assertAlmostEqual(arr.mean(), 1.0, delta=1e-4)
def test_softmax_op(self):
@@ -35,9 +35,9 @@ class GetNumericGradientTest(unittest.TestCase):
dY = numpy.ones(Y.shape)
dX = label_softmax_grad(Y, dY)
- arr = get_numeric_gradient(softmax_op, {"X": X}, 'Y', 'X')
+ arr = get_numeric_gradient(softmax_op, {"X": X}, "Y", "X")
numpy.testing.assert_almost_equal(arr, dX, decimal=1e-2)
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_lookup_table.py b/python/paddle/v2/framework/tests/test_lookup_table.py
index 19eb464baa555fb67a994f3cfb4d3ed628367c73..4b7ce92c0f0492a73c158378299933a0b329948b 100644
--- a/python/paddle/v2/framework/tests/test_lookup_table.py
+++ b/python/paddle/v2/framework/tests/test_lookup_table.py
@@ -4,7 +4,7 @@ from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
-class TestSigmoidOp(unittest.TestCase):
+class TestLookupTableOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
@@ -15,7 +15,7 @@ class TestSigmoidOp(unittest.TestCase):
self.outputs = {'Out': table[ids]}
-class TestSigmoidGradOp(GradientChecker):
+class TestLookupTableGradOp(GradientChecker):
def test_grad(self):
op = create_op('lookup_table')
table = np.random.random((17, 31)).astype('float32')
diff --git a/python/paddle/v2/framework/tests/test_mul_op.py b/python/paddle/v2/framework/tests/test_mul_op.py
index b58e4266d1588a4b6151f5f896537ded6ddd3896..8c827e242e866b267e0fc4b73c31bafa0ccc7c48 100644
--- a/python/paddle/v2/framework/tests/test_mul_op.py
+++ b/python/paddle/v2/framework/tests/test_mul_op.py
@@ -2,6 +2,7 @@ import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
+from paddle.v2.framework.op import Operator
class TestMulOp(unittest.TestCase):
@@ -16,6 +17,22 @@ class TestMulOp(unittest.TestCase):
self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
+class TestMulOp2(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "mul"
+ self.inputs = {
+ 'X': np.random.random((15, 4, 12, 10)).astype("float32"),
+ 'Y': np.random.random((4, 30, 8, 2, 9)).astype("float32")
+ }
+ self.attrs = {'x_num_col_dims': 2, 'y_num_col_dims': 2}
+ self.outputs = {
+ 'Out': np.dot(self.inputs['X'].reshape(15 * 4, 12 * 10),
+ self.inputs['Y'].reshape(4 * 30, 8 * 2 * 9))
+ }
+
+
class TestMulGradOp(GradientChecker):
def setUp(self):
self.op = create_op("mul")
@@ -49,7 +66,38 @@ class TestMulGradOp(GradientChecker):
no_grad_set={"Y"})
-# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
+class TestMulGradTest2(GradientChecker):
+ def setUp(self):
+ self.op = Operator(
+ "mul", X="X", Y="Y", Out="Out", x_num_col_dims=2, y_num_col_dims=2)
+ self.inputs = {
+ "X": np.random.random((15, 4, 12, 10)).astype("float32"),
+ "Y": np.random.random((4, 30, 8, 2, 9)).astype("float32")
+ }
+
+ def test_cpu_gpu_compare(self):
+ self.compare_grad(self.op, self.inputs)
+
+ def test_normal(self):
+ self.check_grad(
+ self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
+
+ def test_ignore_x(self):
+ self.check_grad(
+ self.op,
+ self.inputs, ["Y"],
+ "Out",
+ max_relative_error=0.5,
+ no_grad_set={"X"})
+
+ def test_ignore_y(self):
+ self.check_grad(
+ self.op,
+ self.inputs, ["X"],
+ "Out",
+ max_relative_error=0.5,
+ no_grad_set={"Y"})
+
if __name__ == '__main__':
unittest.main()
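A plain NumPy sketch of what TestMulOp2 above expects from the x_num_col_dims/y_num_col_dims attributes (shapes copied from the test): the first x_num_col_dims axes of X are flattened into the row dimension and the remaining axes into the column dimension, and likewise for Y:

    import numpy as np
    X = np.random.random((15, 4, 12, 10)).astype("float32")
    Y = np.random.random((4, 30, 8, 2, 9)).astype("float32")
    out = np.dot(X.reshape(15 * 4, 12 * 10),    # (60, 120)
                 Y.reshape(4 * 30, 8 * 2 * 9))  # (120, 432), so out is (60, 432)
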
diff --git a/python/paddle/v2/framework/tests/test_rowwise_add_op.py b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
index 2ddb85e2e7a98a08bd1d6e24e6f812f6021142e8..8378c1cd21c21fd31da9b82d2cdaaff332f291d7 100644
--- a/python/paddle/v2/framework/tests/test_rowwise_add_op.py
+++ b/python/paddle/v2/framework/tests/test_rowwise_add_op.py
@@ -16,6 +16,18 @@ class TestRowwiseAddOp(unittest.TestCase):
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
+class TestRowwiseAddOp2(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "rowwise_add"
+ self.inputs = {
+ 'X': np.random.random((13, 6, 7, 8)).astype("float32"),
+ 'b': np.random.random((7, 8)).astype("float32")
+ }
+ self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
+
+
class TestRowwiseAddGradOp(GradientChecker):
def setUp(self):
self.op = create_op("rowwise_add")
@@ -34,5 +46,23 @@ class TestRowwiseAddGradOp(GradientChecker):
self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+class TestRowwiseAddGradOp2(GradientChecker):
+ def setUp(self):
+ self.op = create_op("rowwise_add")
+ self.inputs = {
+ "X": np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"),
+ "b": np.random.uniform(0.1, 1, [2, 5]).astype("float32")
+ }
+
+ def test_normal(self):
+ self.check_grad(self.op, self.inputs, ["X", "b"], "Out")
+
+ def test_ignore_b(self):
+ self.check_grad(self.op, self.inputs, ["X"], "Out", no_grad_set={"b"})
+
+ def test_ignore_x(self):
+ self.check_grad(self.op, self.inputs, ["b"], "Out", no_grad_set={"X"})
+
+
if __name__ == '__main__':
unittest.main()
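A quick NumPy check (outside the test harness) of the broadcasting the new multi-dimensional rowwise_add tests rely on: a bias of shape (7, 8) is added to every leading slice of an input of shape (13, 6, 7, 8):

    import numpy as np
    x = np.random.random((13, 6, 7, 8)).astype("float32")
    b = np.random.random((7, 8)).astype("float32")
    out = np.add(x, b)    # b broadcasts over the two leading axes
    assert out.shape == (13, 6, 7, 8)
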
diff --git a/python/paddle/v2/framework/tests/test_softmax_op.py b/python/paddle/v2/framework/tests/test_softmax_op.py
index e670d93653e07d35e5019c9daac45c214eddf367..0d590fa7065bdd2df0e3f2aea5464f0524d70670 100644
--- a/python/paddle/v2/framework/tests/test_softmax_op.py
+++ b/python/paddle/v2/framework/tests/test_softmax_op.py
@@ -18,18 +18,22 @@ class TestSoftmaxOp(unittest.TestCase):
def setUp(self):
self.type = "softmax"
- self.inputs = {'X': np.random.random((32, 100)).astype("float32")}
+ self.inputs = {"X": np.random.random((10, 10)).astype("float32")}
self.outputs = {
- 'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
+ "Y": np.apply_along_axis(stable_softmax, 1, self.inputs["X"])
}
-class SoftmaxGradOpTest(GradientChecker):
- def test_softmax(self):
- op = create_op("softmax")
- inputs = {"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")}
- self.check_grad(op, inputs, set("X"), "Y")
+class TestSoftmaxGradOp(GradientChecker):
+ def setUp(self):
+ self.op = create_op("softmax")
+ self.inputs = {
+ "X": np.random.uniform(0.1, 1, [10, 10]).astype("float32")
+ }
+
+ def test_softmax_grad(self):
+ self.check_grad(self.op, self.inputs, ["X"], "Y")
-if __name__ == '__main__':
+if __name__ == "__main__":
unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_top_k_op.py b/python/paddle/v2/framework/tests/test_top_k_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..e841d96d26bba13b2780c41ea7a209fd470cad3b
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_top_k_op.py
@@ -0,0 +1,52 @@
+import unittest
+import numpy as np
+from gradient_checker import GradientChecker, create_op
+from op_test_util import OpTestMeta
+
+
+class TestTopkOp(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "top_k"
+ k = 1
+ input = np.random.random((32, 84)).astype("float32")
+ output = np.ndarray((32, k))
+ indices = np.ndarray((32, k))
+
+ self.inputs = {'X': input}
+ self.attrs = {'k': k}
+
+ for rowid in xrange(32):
+ row = input[rowid]
+ output[rowid] = np.sort(row)[-k:]
+ indices[rowid] = row.argsort()[-k:]
+
+ self.outputs = {'Out': output, 'Indices': indices}
+
+
+class TestTopkOp3d(unittest.TestCase):
+ __metaclass__ = OpTestMeta
+
+ def setUp(self):
+ self.type = "top_k"
+ k = 1
+ input = np.random.random((32, 2, 84)).astype("float32")
+ input_flat_2d = input.reshape(64, 84)
+ output = np.ndarray((64, k))
+ indices = np.ndarray((64, k)).astype("int")
+
+ # FIXME: should use 'X': input for a 3d input
+ self.inputs = {'X': input_flat_2d}
+ self.attrs = {'k': k}
+
+ for rowid in xrange(64):
+ row = input_flat_2d[rowid]
+ output[rowid] = np.sort(row)[-k:]
+ indices[rowid] = row.argsort()[-k:]
+
+ self.outputs = {'Out': output, 'Indices': indices}
+
+
+if __name__ == '__main__':
+ unittest.main()
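For reference, a standalone NumPy sketch of how the tests above build the expected top-k values and indices for a single row (k=1, as in the tests; the row values here are made up):

    import numpy as np
    row = np.array([0.1, 0.9, 0.4], dtype="float32")
    k = 1
    top_values = np.sort(row)[-k:]      # [0.9]
    top_indices = row.argsort()[-k:]    # [1]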