(lower_bound));
+ return *this;
+ }
+
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 8aa6728a95bc464ab8884986f0cec6c817d3303b..c762811dfc190b255e0a3389885a081ce8315caf 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,20 +2,20 @@
## Motivation
-In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
-
+In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the gradient operators/expressions together with the chain rule. Every forward network needs a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
+
## Backward Operator Registry
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients.
+A backward network is built up with several backward operators. Backward operators take forward operators' inputs outputs, and output gradients and then calculate its input gradients.
| | forward operator | backward operator
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
+For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -27,17 +27,17 @@ REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
## Backward Opeartor Creating
-Given a certain forward operator, we can get its corresponding backward opeartor by calling:
+Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
+```
The function `BuildGradOp` will sequentially execute following processes:
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing.
3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
@@ -49,31 +49,31 @@ A backward network is a series of backward operators. The main idea of building
In our design, the network itself is also a kind of operator. So the operators contained by a big network may be some small network.
-given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
+given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`, `InputGradients`.
1. Op
- when the input forward network is a Op, return its gradient Operator Immediately.
+ when the input forward network is an Op, return its gradient Operator Immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to forward NetOp.
+ when the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp.
- **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwirte their shared input variable.
+ **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwrite their shared input variable.
- 
+ 
- 1. shared variable in two operators.
+ 1. Shared variable in operators.
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator replace the overwirte links.
+ Share variable between operators or same input variable used in multiple operators leads to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively and add a generic add operator replace the overwrite links.
- 
+ 
- 2. replace shared variable gradient with `Add` Operator
+ 2. Replace shared variable's gradient with `Add` operator.
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de79743bb0390d66b8999f2e8342a51d14a9..fc3d508553c0e966978b28d58127bdbff10d45f1 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list init_list) {
*this = make_ddim(init_list);
}
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+ int rank = src.size();
+ return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+ product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c523948b1d437615aa0e9bfecb5e25569296..ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+// Reshape a tensor to a matrix. The matrix's first dimension(column length)
+// will be the product of tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c56e0632414a5bbc754d35bfa9ce6a5..54bbeafcabdeeb1e2c1017c156b3512c83dada3a 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
template
-struct EigenMatrix : public EigenTensor {};
+struct EigenMatrix : public EigenTensor {
+ static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+
+ static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+ int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+};
template
struct EigenVector : public EigenTensor {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b1a202826e10e84c21ac8874df9e378..bc4a2db32cfba66bef2c444e1f822e0d2a57b91e 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
}
}
+TEST(Eigen, MatrixReshape) {
+ Tensor t;
+ float* p = t.mutable_data({2, 3, 6, 4}, platform::CPUPlace());
+ for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+ p[i] = static_cast(i);
+ }
+
+ EigenMatrix::Type em = EigenMatrix::Reshape(t, 2);
+
+ ASSERT_EQ(2 * 3, em.dimension(0));
+ ASSERT_EQ(6 * 4, em.dimension(1));
+
+ for (int i = 0; i < 2 * 3; i++) {
+ for (int j = 0; j < 6 * 4; j++) {
+ ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+ }
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle
index 2b658085d6a55d368c320051ba7f94ec2900f13c..ede3bca30ae17d5af52505fd94dc2f79b23b57e0 100644
Binary files a/paddle/framework/images/duplicate_op2.graffle and b/paddle/framework/images/duplicate_op2.graffle differ
diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png
index c5588015d1450fd8c1bda3580680d884494868bb..4e872dc2caf3b0cbd0d5176f11a14801b538dc86 100644
Binary files a/paddle/framework/images/duplicate_op2.png and b/paddle/framework/images/duplicate_op2.png differ
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746b1d34da413fa3c29a266f962c6dde6..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
CheckAllInputOutputSet();
}
+std::vector OperatorBase::InputVars() const {
+ std::vector ret_val;
+ for (auto& o : outputs_) {
+ ret_val.reserve(ret_val.size() + o.second.size());
+ ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+ }
+ return ret_val;
+}
+
std::vector OperatorBase::OutputVars(bool has_intermediate) const {
std::vector ret_val;
if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be0d1cb875d614b263f1e4365ede4113..4600b06009bcef7d0774d25b816aac4733f30795 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
+
//! Get a input with argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get a input which has multiple variables.
const std::vector& Inputs(const std::string& name) const;
+ std::vector InputVars() const;
+
//! Get a output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
}
template
- std::vector MultiOutput(const std::string& name) const {
+ std::vector MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
- std::vector res;
+ std::vector res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 643f875491724bf443bd7727391734377ee6180c..ce938b21437195fed8c1adad4329fd139f3f96ab 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
template
friend struct EigenTensor;
+ template
+ friend struct EigenMatrix;
+
template
friend struct EigenVector;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 94f436294f350e2a39785a09959efb3b17bd00a5..637f04ae0037bd402d855b8bcde8087bfe8328d1 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -148,5 +148,13 @@ inline Tensor& Tensor::Resize(const DDim& dims) {
inline const DDim& Tensor::dims() const { return dims_; }
+template
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+ Tensor res;
+ res.ShareDataWith(src);
+ res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+ return res;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5caeebccf710334e854faf785ef0f64063..55302ea47120f420e952b26830c8ea4cbcce6435 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
+
+TEST(Tensor, ReshapeToMatrix) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ int* src_ptr = src.mutable_data({2, 3, 4, 9}, CPUPlace());
+ for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+ src_ptr[i] = i;
+ }
+ Tensor res = ReshapeToMatrix(src, 2);
+ ASSERT_EQ(res.dims()[0], 2 * 3);
+ ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206ee3cbc5421238574c7f310011ccaa5..f7a80e23e1bd49549bec57b360587adc6b423794 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
const ImageConfig& conf = config_.inputs(0).image_conf();
imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+ imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+ if (0 == imageD_) imageD_ = conf.img_size_z();
if (imageH_ == 0 && imageW_ == 0) {
imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imageW_ = conf.img_size();
} else {
getOutput().setFrameHeight(imageH_);
getOutput().setFrameWidth(imageW_);
+ getOutput().setFrameDepth(imageD_);
}
- imgPixels_ = imageH_ * imageW_;
+ imgPixels_ = imageH_ * imageW_ * imageD_;
}
} // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d96bbd49481a7ed135be6888688627e..e721d2d267a31cae46407673b8b1281e87055608 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
/// Height or width of input image feature.
/// Both of them are 1 if the input is fully-connected layer.
+ int imageD_;
int imageH_;
int imageW_;
/// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f..49a9540c0b6e36b59ed786287ff5c4569b69a6a5 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
- hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+ hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
- imageH_,
+ imageH_ * imageD_,
imageW_);
}
}
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 6a91042f628920a9986763531fb4c633307b43b8..d7eee6eaf078dab8d48adc4c7ee758a433672ac6 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,19 +24,21 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
+ size_t inD = img_conf.img_size_z();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
+ inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
auto& reshape_conf = config_.reshape_conf();
- for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
- heightAxis_.push_back(reshape_conf.heightaxis(i));
+ for (int i = 0; i < reshape_conf.height_axis_size(); i++) {
+ heightAxis_.push_back(reshape_conf.height_axis(i));
}
- for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
- widthAxis_.push_back(reshape_conf.widthaxis(i));
+ for (int i = 0; i < reshape_conf.width_axis_size(); i++) {
+ widthAxis_.push_back(reshape_conf.width_axis(i));
}
createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
-
+ int d = inputLayers_[0]->getOutput().getFrameDepth();
+ d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight();
- if (h != 0) inDims_.setDim(2, h);
+ if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e0c14ad5b512c7329062a5426ef34844ec268020..0e6be2df9ef5f0fae8ed2b0c65ac6c032fe45ab1 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+ TestConfig config;
+ const int CHANNELS = 10;
+ const int IMG_SIZE = 16;
+ const int IMG_SIZE_Y = 8;
+ const int IMG_SIZE_Z = 8;
+ size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+ config.layerConfig.set_type(type);
+ config.layerConfig.set_size(size);
+ config.layerConfig.set_active_type("sigmoid");
+ config.biasSize = CHANNELS;
+ config.inputDefs.push_back({INPUT_DATA,
+ "layer_0",
+ /* dim= */ size,
+ /* paraSize= */ CHANNELS});
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+ config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+
+ LayerInputConfig* input = config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(CHANNELS);
+ img_conf->set_img_size(IMG_SIZE);
+ img_conf->set_img_size_y(IMG_SIZE_Y);
+ img_conf->set_img_size_z(IMG_SIZE_Z);
+
+ testLayerGrad(config,
+ "batch_norm",
+ 64,
+ /* trans= */ trans,
+ useGpu,
+ /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+ testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+ testBatchNorm3DLayer("batch_norm", false, true);
+ if (hl_get_cudnn_lib_version() >= int(4000)) {
+ testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+ }
+#endif
+}
+
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
@@ -2019,10 +2068,10 @@ TEST(Layer, SwitchOrderLayer) {
img->set_img_size_y(16);
ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
- reshape->add_heightaxis(0);
- reshape->add_heightaxis(1);
- reshape->add_heightaxis(2);
- reshape->add_widthaxis(3);
+ reshape->add_height_axis(0);
+ reshape->add_height_axis(1);
+ reshape->add_height_axis(2);
+ reshape->add_width_axis(3);
// config softmax layer
config.layerConfig.set_type("switch_order");
diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc
index be956bf3b320d6beacdb0d2ca742c3e854194b19..7d9d4fa519d1c690feacbadc5175aeab49082282 100644
--- a/paddle/operators/identity_op.cc
+++ b/paddle/operators/identity_op.cc
@@ -18,17 +18,20 @@
namespace paddle {
namespace operators {
-// identity is a alias of scale op. This is also a example for creating a alias
-// operator.
+// The identity operator is an alias of the scale operator. This is also an
+// example for creating an alias for an existing operator.
template
class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
public:
IdentityOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
- AddInput("X", "input tensor of identity op");
- AddOutput("Out", "output tensor of identity op");
- AddComment("identity operator. Just a alias of scale op which scale = 1.0");
+ AddInput("X", "The input tensor of identity operator.");
+ AddOutput("Out", "The output tensor of identity operator.");
+ AddComment(R"DOC(
+The identity operator is an alias of the scale operator
+with the attribute scale fixed to 1.0.
+)DOC");
}
};
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index f905600bb3ac27ccf6939c59e52bd6d3c3c87a4e..4f380388b108dc173d847f027ba5c9db387a87f8 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -74,7 +74,9 @@ void testIm2col() {
#ifndef PADDLE_ONLY_CPU
context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
-#endif
+#else
+ PADDLE_THROW("no GPU support");
+#endif // PADDLE_ONLY_CPU
}
im2col(input, output_cfo, stride, stride, padding, padding, context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2e9b7a965ff9f99e787bb8315010823..710a56a0e8e2d17162d7d000df226f1537104eb9 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("Y")->dims();
- PADDLE_ENFORCE_EQ(dim0.size(), 2,
- "input X(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("X"));
- PADDLE_ENFORCE_EQ(dim1.size(), 2,
- "input Y(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("Y"));
+ auto x_dims = ctx.Input("X")->dims();
+ auto y_dims = ctx.Input("Y")->dims();
+ int x_num_col_dims = Attr("x_num_col_dims");
+ int y_num_col_dims = Attr("y_num_col_dims");
+
+ PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+ "The rank of input tensor X(%s) should be larger than "
+ "`mul_op`'s `x_num_col_dims`.",
+ ctx.op().Input("X"));
+ PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+ "The rank of input tensor Y(%s) should be larger than "
+ "`mul_op`'s `y_num_col_dims`.",
+ ctx.op().Input("Y"));
+
+ auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+ auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
PADDLE_ENFORCE_EQ(
- dim0[1], dim1[0],
+ x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
- ctx.Output("Out")->Resize({dim0[0], dim1[1]});
+ ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
+ AddAttr(
+ "x_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
+ in that case, tensors will be reshaped to a matrix. The matrix's first
+ dimension(column length) will be the product of tensor's last
+ `num_col_dims` dimensions, and the matrix's second dimension(row length)
+ will be the product of tensor's first `rank - num_col_dims` dimensions.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
+ AddAttr(
+ "y_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
+ in that case, tensors will be reshaped to a matrix. Just like input `X`.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output(framework::GradVarName("X"));
auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- PADDLE_ENFORCE(x_dims[0] == out_dims[0],
- "Out@GRAD M X N must equal to X dims 0, M ");
- PADDLE_ENFORCE(y_dims[1] == out_dims[1],
- "Out@GRAD M X N must equal to Y dims 1, N ");
+
+ auto x_mat_dims =
+ framework::flatten_to_2d(x_dims, Attr("x_num_col_dims"));
+ auto y_mat_dims =
+ framework::flatten_to_2d(y_dims, Attr("y_num_col_dims"));
+
+ PADDLE_ENFORCE_EQ(
+ x_mat_dims[0], out_dims[0],
+ "The first dimension of Out@GRAD must equal to the first dimension of "
+ "the first operand.");
+ PADDLE_ENFORCE_EQ(
+ y_mat_dims[1], out_dims[1],
+ "The second dimension of Out@GRAD must equal to the second "
+ "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3470e39a5ebd0394ba05629553a5075..3c01f868bda8cba488b3403df456d63d6b082fa6 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
+ You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* x = context.Input("X");
- auto* y = context.Input("Y");
- auto* z = context.Output("Out");
+ const Tensor* x = context.Input("X");
+ const Tensor* y = context.Input("Y");
+ Tensor* z = context.Output("Out");
+ const Tensor x_matrix =
+ x->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *x, context.template Attr("x_num_col_dims"))
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *y, context.template Attr("y_num_col_dims"))
+ : *y;
+
z->mutable_data(context.GetPlace());
auto* device_context =
const_cast(context.device_context_);
- math::matmul(*x, false, *y, false, 1, z, 0, device_context);
+ math::matmul(x_matrix, false, y_matrix, false, 1, z, 0,
+ device_context);
}
};
@@ -45,23 +57,39 @@ template
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto* x = ctx.Input("X");
- auto* y = ctx.Input("Y");
- auto* dout = ctx.Input(framework::GradVarName("Out"));
+ int x_num_col_dims = ctx.template Attr("x_num_col_dims");
+ int y_num_col_dims = ctx.template Attr("y_num_col_dims");
+ const Tensor* x = ctx.Input("X");
+ const Tensor* y = ctx.Input("Y");
+ const Tensor x_matrix =
+ x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+ : *y;
+ const Tensor* dout = ctx.Input(framework::GradVarName("Out"));
- auto* dx = ctx.Output(framework::GradVarName("X"));
- auto* dy = ctx.Output(framework::GradVarName("Y"));
+ Tensor* dx = ctx.Output(framework::GradVarName("X"));
+ Tensor* dy = ctx.Output(framework::GradVarName("Y"));
auto* device_context =
const_cast(ctx.device_context_);
if (dx) {
dx->mutable_data(ctx.GetPlace());
+ Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix(
+ *dx, x_num_col_dims)
+ : *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
- math::matmul(*dout, false, *y, true, 1, dx, 0, device_context);
+ math::matmul(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
+ device_context);
}
if (dy) {
dy->mutable_data(ctx.GetPlace());
+ Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix(
+ *dy, y_num_col_dims)
+ : *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
- math::matmul(*x, true, *dout, false, 1, dy, 0, device_context);
+ math::matmul(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
+ device_context);
}
}
};
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b404315a9f041e21d79b75fd06307e33f7f9..fa8f0ff1a858143af427b51025279c726f1628e0 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("b")->dims();
-
- PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
- PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
- PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
- PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
- ctx.Output("Out")->Resize(ctx.Input("X")->dims());
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
+ PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+ ctx.Output("Out")->Resize(x_dims);
}
};
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
- auto dims0 = ctx.Input("X")->dims();
- auto dims1 = ctx.Input("b")->dims();
- PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
auto *dx = ctx.Output(framework::GradVarName("X"));
auto *db = ctx.Output(framework::GradVarName("b"));
- if (dx) dx->Resize(dims0);
- if (db) db->Resize(dims1);
+ if (dx) dx->Resize(x_dims);
+ if (db) db->Resize(b_dims);
}
};
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f2947f37b71e81c0fa592b0c66b19c640..35774b940926f77167b8f19597027e74d3477e5b 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output("Out");
out->mutable_data(context.GetPlace());
-
- auto input = EigenMatrix::From(*context.Input("X"));
- auto bias = EigenVector::From(*context.Input("b"));
- auto output = EigenMatrix::From(*out);
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
+ auto input =
+ EigenMatrix::Reshape(*context.Input("X"), num_col_dims);
+ auto bias = EigenVector::Flatten(*context.Input("b"));
+ auto output = EigenMatrix::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
@@ -54,12 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
auto* dout = context.Input(framework::GradVarName("Out"));
auto* dx = context.Output(framework::GradVarName("X"));
auto* db = context.Output(framework::GradVarName("b"));
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
- auto out_grad = EigenMatrix::From(*dout);
+ auto out_grad = EigenMatrix::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice();
+
if (dx) {
dx->mutable_data(context.GetPlace());
- EigenMatrix::From(*dx).device(place) = out_grad;
+ EigenMatrix::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
index 3d82b345829b0a554a204ada91c807e42b71dc58..ea991f683d841b3dc4624a0d8aa3c88367fd3c6d 100644
--- a/paddle/operators/scale_op.cc
+++ b/paddle/operators/scale_op.cc
@@ -44,11 +44,13 @@ class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
The equation is: Out = scale*X
)DOC");
- AddAttr("scale", "scale of scale operator.").SetDefault(1.0);
+ AddAttr("scale", "The scaling factor of the scale operator.")
+ .SetDefault(1.0);
}
};
-// Scale Op's gradient is scale op, too.
+// The operator to calculate gradients of a scale operator is just the scale
+// operator itself.
// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
template
class ScaleGradOp : public NetOp {
diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc
index 7d062ad67c048bc6bef68121f86334eb3f1efe92..7166b2f60be8a6088ab3a81686f7bed1b7181d97 100644
--- a/paddle/operators/softmax_op.cc
+++ b/paddle/operators/softmax_op.cc
@@ -51,7 +51,7 @@ the other dimensions in the K-dimensional vector input. Then the ratio of the
exponential of the given dimension and the sum of exponential values of all
the other dimensions is the output of the softmax operator.
-For each row `i` and each column `j` in X, we have:
+For each row `i` and each column `j` in input X, we have:
Y[i, j] = exp(X[i, j]) / sum_j(exp(X[i, j]))
)DOC");
@@ -64,14 +64,15 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Y")),
- "Input(Y@GRAD) should not be null");
- PADDLE_ENFORCE(ctx.Input("Y")->dims() ==
- ctx.Input(framework::GradVarName("Y"))->dims(),
- "the shape of Input(0) and Input(1) should be the same");
+ "Input(Y@GRAD) should be not null.");
+ PADDLE_ENFORCE_EQ(ctx.Input("Y")->dims(),
+ ctx.Input(framework::GradVarName("Y"))->dims(),
+ "Input(Y) and its gradients should have a same shape.");
+
ctx.Output(framework::GradVarName("X"))
- ->Resize(ctx.Input("Y")->dims());
+ ->Resize(ctx.Input("X")->dims());
}
};
diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h
index 4fa6b59540498638c3b7df639ae10a66c0fa1c16..8a3a5ab927c0e2937936fcc973f000d4d95c3dbc 100644
--- a/paddle/operators/softmax_op.h
+++ b/paddle/operators/softmax_op.h
@@ -28,12 +28,12 @@ template
class SoftmaxKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto input = context.Input("X");
- auto output = context.Output("Y");
- output->mutable_data(context.GetPlace());
+ auto X = context.Input("X");
+ auto Y = context.Output("Y");
+ Y->mutable_data(context.GetPlace());
- auto logits = EigenMatrix::From(*input);
- auto softmax = EigenMatrix::From(*output);
+ auto logits = EigenMatrix::From(*X);
+ auto softmax = EigenMatrix::From(*Y);
const int kBatchDim = 0;
const int kClassDim = 1;
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5805826ee8a555ca6dfc1ca81feaadffea9e1012
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto *out = ctx.Output("Out");
+ int N = ins.size();
+
+ auto in_dim = ins[0]->dims();
+
+ PADDLE_ENFORCE_GT(N, 1, "Input tensors count should > 1.");
+ for (int i = 1; i < N; i++) {
+ auto dim = ins[i]->dims();
+ PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+ }
+ out->Resize(in_dim);
+ }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+ AddOutput("Out", "the output tensor of sum operator.");
+ AddComment(R"DOC(
+ Sum the input tensors.
+ )DOC");
+ }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto outputs = ctx.MultiOutput(framework::GradVarName("X"));
+ auto dims = ctx.Input(framework::GradVarName("Out"))->dims();
+ for (auto output : outputs) {
+ output->Resize(dims);
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+ ops::SumGradKernel);
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a465cf3659ba7c51338abadfc62962fb6755a39d
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+ ops::SumGradKernel);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b1e9ebaa38d455fb5e3ce8c1a39cbbcdad9a940
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template
+using EigenVector = framework::EigenVector;
+
+template
+class SumKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto ins = context.MultiInput("X");
+ auto* out = context.Output("Out");
+ out->mutable_data(context.GetPlace());
+
+ auto place = context.GetEigenDevice();
+ auto result = EigenVector::Flatten(*out);
+
+ int N = ins.size();
+ auto in = EigenVector::Flatten(*(ins[0]));
+ result.device(place) = in;
+ for (int i = 1; i < N; i++) {
+ auto in = EigenVector::Flatten(*(ins[i]));
+ result.device(place) = result + in;
+ }
+ }
+};
+
+template
+class SumGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+ auto* input = context.Input(framework::GradVarName("Out"));
+ auto outs = context.MultiOutput(framework::GradVarName("X"));
+ for (auto out : outs) {
+ out->mutable_data(context.GetPlace());
+ }
+
+ auto place = context.GetEigenDevice();
+ auto in = EigenVector::Flatten(*input);
+ for (auto out : outs) {
+ auto result = EigenVector::Flatten(*out);
+ result.device(place) = in;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..38d2f0a09aec751734864947a2f3cfa20107e22f
--- /dev/null
+++ b/paddle/operators/top_k_op.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/top_k_op.h"
+
+namespace paddle {
+namespace operators {
+
+class TopkOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+ "Input of TopkOP must be initialized.");
+ auto *input = ctx.Input("X");
+ const int k = static_cast(ctx.Attr("k"));
+
+ PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
+ PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
+ PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+ "input must have >= k columns");
+
+ framework::DDim dims = input->dims();
+ dims[dims.size() - 1] = k;
+ ctx.Output("Out")->Resize(dims);
+ ctx.Output("Indices")->Resize(dims);
+ }
+};
+
+class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "The input of Topk op");
+ AddOutput("Out", "The output tensor of Topk op");
+ AddOutput("Indices", "The indices of Topk elements of input");
+ AddComment(
+ R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
+
+ For matrices, computes the top k entries in each row. )DOC");
+ AddAttr("k",
+ "Number of top elements to look for along the last "
+ "dimension (along each row for matrices).")
+ .SetDefault(1);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
+REGISTER_OP_CPU_KERNEL(top_k,
+ ops::TopkKernel);
diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..afe4d149c53819c45e20353bc9d16393f3f61e0f
--- /dev/null
+++ b/paddle/operators/top_k_op.cu
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+struct Pair {
+ __device__ __forceinline__ Pair() {}
+ __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+
+ __device__ __forceinline__ void set(T value, int id) {
+ v = value;
+ id = id;
+ }
+
+ __device__ __forceinline__ void operator=(const Pair& in) {
+ v = in.v;
+ id = in.id;
+ }
+
+ __device__ __forceinline__ bool operator<(const T value) const {
+ return (v < value);
+ }
+
+ __device__ __forceinline__ bool operator<(const Pair& in) const {
+ return (v < in.v) || ((v == in.v) && (id > in.id));
+ }
+
+ __device__ __forceinline__ bool operator>(const Pair& in) const {
+ return (v > in.v) || ((v == in.v) && (id < in.id));
+ }
+
+ T v;
+ int id;
+};
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p,
+ int beam_size) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void AddTo(Pair topk[], const Pair& p) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* src, int idx,
+ int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+ Pair tmp(src[idx], idx);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void GetTopK(Pair topk[], const T* val, int* col,
+ int idx, int dim, const Pair& max,
+ int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+ Pair tmp(val[idx], col[idx]);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* src,
+ bool& firstStep, bool& is_empty,
+ Pair& max, int dim,
+ const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, src, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, src, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void ThreadGetTopK(Pair topk[], int& beam,
+ int beam_size, const T* val,
+ int* col, bool& firstStep,
+ bool& is_empty, Pair& max,
+ int dim, const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+ GetTopK(topk, val, col, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+ GetTopK(topk + MaxLength - beam, val, col, tid, dim, max,
+ length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
+template
+__device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid,
+ Pair topk[], T** topVal,
+ int** topIds, int& beam, int& k,
+ const int tid, const int warp) {
+ while (true) {
+ __syncthreads();
+ if (tid < BlockSize / 2) {
+ if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
+ maxid[tid] = tid + BlockSize / 2;
+ } else {
+ maxid[tid] = tid;
+ }
+ }
+ __syncthreads();
+ for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
+ if (tid < stride) {
+ if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
+ maxid[tid] = maxid[tid + stride];
+ }
+ }
+ __syncthreads();
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ **topVal = sh_topk[maxid[0]].v;
+ **topIds = sh_topk[maxid[0]].id;
+ (*topVal)++;
+ (*topIds)++;
+ }
+ if (tid == maxid[0]) beam++;
+ if (--k == 0) break;
+ __syncthreads();
+
+ if (tid == maxid[0]) {
+ if (beam < MaxLength) {
+ sh_topk[tid] = topk[beam];
+ }
+ }
+ if (maxid[0] / 32 == warp) {
+ if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
+ }
+ }
+}
+
+/**
+ * Each block compute one sample.
+ * In a block:
+ * 1. every thread get top MaxLength value;
+ * 2. merge to sh_topk, block reduce and get max value;
+ * 3. go to the second setp, until one thread's topk value is null;
+ * 4. go to the first setp, until get the topk value.
+ */
+template
+__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
+ const T* src, int lds, int dim, int k) {
+ __shared__ Pair sh_topk[BlockSize];
+ __shared__ int maxid[BlockSize / 2];
+ const int tid = threadIdx.x;
+ const int warp = threadIdx.x / 32;
+ output += blockIdx.x * output_stride;
+ indices += blockIdx.x * k;
+
+ Pair topk[MaxLength];
+ int beam = MaxLength;
+ Pair max;
+ bool is_empty = false;
+ bool firststep = true;
+
+ for (int k = 0; k < MaxLength; k++) {
+ topk[k].set(-INFINITY, -1);
+ }
+ while (k) {
+ ThreadGetTopK(topk, beam, k,
+ src + blockIdx.x * lds, firststep,
+ is_empty, max, dim, tid);
+
+ sh_topk[tid] = topk[0];
+ BlockReduce(sh_topk, maxid, topk, &output,
+ &indices, beam, k, tid, warp);
+ }
+}
+
+template
+class TopkOpCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "It must use GPUPlace.");
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ size_t k = static_cast(ctx.Attr("k"));
+
+ const T* input_data = input->data();
+
+ T* output_data = output->mutable_data(ctx.GetPlace());
+ // FIXME(typhoonzero): data is always converted to type T?
+ int* indices_data = indices->mutable_data(ctx.GetPlace());
+
+ size_t input_height = input->dims()[0];
+ size_t input_width = input->dims()[1];
+ if (k > input_width) k = input_width;
+
+ // NOTE: pass lds and dim same to input width.
+ // NOTE: old matrix implementation of stride is different to eigen.
+ // TODO(typhoonzero): launch kernel on specified stream.
+ // TODO(typhoonzero): refine this kernel.
+ dim3 threads(256, 1);
+ dim3 grid(input_height, 1);
+
+ KeMatrixTopK<<>>(
+ output_data, output->dims()[1], indices_data, input_data, input_width,
+ input_width, int(k));
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+REGISTER_OP_GPU_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel);
diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef66acc1d569282a42be64b7a5e90f3fbdb20690
--- /dev/null
+++ b/paddle/operators/top_k_op.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include
+#include
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template
+using EigenMatrix = framework::EigenMatrix;
+
+template
+class TopkKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ // Get the top k elements of each row of input tensor
+ // FIXME: only deal with matrix(2d tensor).
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ // k is determined by Attr
+ const size_t k = static_cast(ctx.Attr("k"));
+
+ T* output_data = output->mutable_data(ctx.GetPlace());
+ T* indices_data = indices->mutable_data(ctx.GetPlace());
+
+ auto eg_input = EigenMatrix::From(*input);
+
+ // reshape input to a flattern matrix(like flat_inner_dims)
+ framework::DDim inputdims = input->dims();
+ const size_t row = framework::product(
+ framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+ const size_t col = inputdims[inputdims.size() - 1];
+ Eigen::DSizes flat2dims(row, col);
+ // NOTE: eigen shape doesn't affect paddle tensor.
+ eg_input.reshape(flat2dims);
+
+ for (size_t i = 0; i < row; i++) {
+ std::vector> vec;
+ for (size_t j = 0; j < col; j++) {
+ vec.push_back(std::pair(eg_input(i, j), j));
+ }
+
+ std::partial_sort(
+ vec.begin(), vec.begin() + k, vec.end(),
+ [](const std::pair