(lower_bound));
+ return *this;
+ }
+
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 8aa6728a95bc464ab8884986f0cec6c817d3303b..c762811dfc190b255e0a3389885a081ce8315caf 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,20 +2,20 @@
## Motivation
-In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
-
+In a neural network, the backpropagation algorithm follows the chain rule, so we need to compose the gradient operators/expressions together according to the chain rule. Every forward network needs a backward network to construct the full computation graph; the operator/expression's backward pass will be generated with respect to the forward pass.
+
## Backward Operator Registry
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients.
+A backward network is built up with several backward operators. Backward operators take the forward operators' inputs, outputs, and output gradients, and then calculate the corresponding input gradients.
| | forward operator | backward operator |
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map (`OpInfoMap`). To follow the philosophy of a minimal core and make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
+For example, given a `mul_op`, we can register its information and corresponding backward operator with the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -27,17 +27,17 @@ REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
## Backward Operator Creation
-Given a certain forward operator, we can get its corresponding backward opeartor by calling:
+Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
+```
The function `BuildGradOp` will sequentially execute the following processes:
1. Get the `type_` of the given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` to the map `inputs`, except those that are not necessary for gradient computing.
3. Add the forward inputs' gradient variables into the map `outputs`, and add the forward outputs' gradient variables into the map `inputs`.
@@ -49,31 +49,31 @@ A backward network is a series of backward operators. The main idea of building
In our design, the network itself is also a kind of operator, so the operators contained in a big network may themselves be smaller networks.
-given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
+Given a forward network, it generates the backward network. We only care about the gradients—`OutputGradients` and `InputGradients`.
1. Op
- when the input forward network is a Op, return its gradient Operator Immediately.
+ When the input forward network is an Op, return its gradient operator immediately.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to forward NetOp.
+ When the input forward network is a NetOp, it needs to call the backward functions of its sub NetOps/Operators recursively, as sketched after this section. During this process, we need to collect the `OutputGradients` names according to the forward NetOp.
- **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwirte their shared input variable.
+ **Shared variables**. As illustrated in the pictures, two operators' `Output` `Gradient` will overwrite their shared input variable.
- 
+ 
- 1. shared variable in two operators.
+ 1. A shared variable in two operators.
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator replace the overwirte links.
+ Sharing a variable between operators, or using the same input variable in multiple operators, leads to duplicate gradient variables. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator to replace the overwrite links.
- 
+ 
- 2. replace shared variable gradient with `Add` Operator
+ 2. Replace the shared variable's gradient with an `Add` operator.
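To make the recursive construction concrete, here is a minimal sketch of the traversal described above (illustrative pseudocode only; `IsNetOp`, `CreateGradOp`, `ops()`, `FindDuplicateOutputGradients`, and `RenameAndInsertAddOp` are hypothetical helper names, not the framework's actual API):

```cpp
// Hedged sketch of building a backward network from a forward network.
OperatorBase* Backward(const OperatorBase& fwd_op) {
  if (!IsNetOp(fwd_op)) {
    // Plain Op: return its registered gradient operator immediately.
    return CreateGradOp(fwd_op);
  }
  // NetOp: recursively build the backward pass of each sub-operator,
  // in reverse order, collecting OutputGradients along the way.
  auto* grad_net = new NetOp();
  for (auto it = fwd_op.ops().rbegin(); it != fwd_op.ops().rend(); ++it) {
    grad_net->AppendOp(Backward(**it));
  }
  // Rename duplicated gradient variables and merge them with Add operators.
  for (auto& dup : FindDuplicateOutputGradients(*grad_net)) {
    RenameAndInsertAddOp(grad_net, dup);
  }
  return grad_net;
}
```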
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de79743bb0390d66b8999f2e8342a51d14a9..fc3d508553c0e966978b28d58127bdbff10d45f1 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list<int64_t> init_list) {
*this = make_ddim(init_list);
}
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+ int rank = src.size();
+ return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+ product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c523948b1d437615aa0e9bfecb5e25569296..ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+// Reshape a tensor to a matrix. The matrix's first dimension(column length)
+// will be the product of tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
} // namespace framework
} // namespace paddle
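A quick usage sketch of the two helpers declared above (illustrative, not part of the patch):

```cpp
// flatten_to_2d groups the first `num_col_dims` dimensions into the
// matrix's first dimension and the remaining ones into the second.
paddle::framework::DDim src = paddle::framework::make_ddim({2, 3, 6, 4});
auto mat = paddle::framework::flatten_to_2d(src, 2);  // {2 * 3, 6 * 4} == {6, 24}
auto vec = paddle::framework::flatten_to_1d(src);     // {144}
```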
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c56e0632414a5bbc754d35bfa9ce6a5..54bbeafcabdeeb1e2c1017c156b3512c83dada3a 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
-struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {};
+struct EigenMatrix : public EigenTensor<T, 2, MajorType, IndexType> {
+ static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+
+ static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+ int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+};
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b1a202826e10e84c21ac8874df9e378..bc4a2db32cfba66bef2c444e1f822e0d2a57b91e 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
}
}
+TEST(Eigen, MatrixReshape) {
+ Tensor t;
+ float* p = t.mutable_data({2, 3, 6, 4}, platform::CPUPlace());
+ for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+ p[i] = static_cast(i);
+ }
+
+ EigenMatrix::Type em = EigenMatrix::Reshape(t, 2);
+
+ ASSERT_EQ(2 * 3, em.dimension(0));
+ ASSERT_EQ(6 * 4, em.dimension(1));
+
+ for (int i = 0; i < 2 * 3; i++) {
+ for (int j = 0; j < 6 * 4; j++) {
+ ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+ }
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle
index 2b658085d6a55d368c320051ba7f94ec2900f13c..ede3bca30ae17d5af52505fd94dc2f79b23b57e0 100644
Binary files a/paddle/framework/images/duplicate_op2.graffle and b/paddle/framework/images/duplicate_op2.graffle differ
diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png
index c5588015d1450fd8c1bda3580680d884494868bb..4e872dc2caf3b0cbd0d5176f11a14801b538dc86 100644
Binary files a/paddle/framework/images/duplicate_op2.png and b/paddle/framework/images/duplicate_op2.png differ
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 154068fef69bc96edbd85b731fe8091b3b1ff823..568f4e89819c8345d8908634f6fa56f09483a763 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -18,8 +18,10 @@
#ifndef PADDLE_ONLY_CPU
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
+#include <thrust/system/cuda/experimental/pinned_allocator.h>
#endif
+#include
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
@@ -32,7 +34,8 @@ template <typename T>
using Vector = std::vector<T>;
#else
template <typename T>
-using Vector = thrust::host_vector<T>;
+using Vector = thrust::host_vector<
+ T, thrust::system::cuda::experimental::pinned_allocator<T>>;
#endif
using LoD = std::vector<Vector<size_t>>;
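The switch to a pinned allocator is what enables the GPU test added below: `pinned_allocator` gives page-locked host memory, and with CUDA's unified virtual addressing a kernel can dereference such a host pointer directly, so the LoD data needs no explicit device copy. A small sketch (assuming a CUDA build):

```cpp
// A Vector<size_t> is now backed by pinned (page-locked) host memory,
// so its raw pointer is valid inside a CUDA kernel.
paddle::framework::Vector<size_t> level;
level.push_back(0);
level.push_back(2);
// kernel<<<1, 1>>>(level.data(), level.size());  // see lod_tensor_test.cu
```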
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1079a36a2e7b24f6f8a5bcbb296355567305a765
--- /dev/null
+++ b/paddle/framework/lod_tensor_test.cu
@@ -0,0 +1,52 @@
+/*
+ Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/platform/assert.h"
+
+#include <gtest/gtest.h>
+
+__global__ void test(size_t* a, int size) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
+ i += blockDim.x * gridDim.x) {
+ a[i] *= 2;
+ }
+}
+
+TEST(LoDTensor, LoDInGPU) {
+ paddle::framework::Tensor tensor;
+ paddle::framework::LoDTensor lod_tensor;
+ paddle::platform::GPUPlace place(0);
+
+ paddle::framework::LoD src_lod;
+ src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});
+
+ tensor.Resize({14, 16});
+ tensor.mutable_data<float>(place);
+
+ lod_tensor.set_lod(src_lod);
+ lod_tensor.set_tensor(&tensor);
+ CHECK_EQ(lod_tensor.lod_element(0, 2), 4);
+ CHECK_EQ(lod_tensor.lod_element(0, 4), 8);
+
+ auto lod = lod_tensor.lod();
+
+ test<<<1, 8>>>(lod[0].data(), lod[0].size());
+ cudaDeviceSynchronize();
+
+ for (size_t i = 0; i < src_lod[0].size(); ++i) {
+ CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
+ }
+}
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746b1d34da413fa3c29a266f962c6dde6..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
CheckAllInputOutputSet();
}
+std::vector<std::string> OperatorBase::InputVars() const {
+ std::vector<std::string> ret_val;
+ for (auto& o : inputs_) {
+ ret_val.reserve(ret_val.size() + o.second.size());
+ ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+ }
+ return ret_val;
+}
+
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be0d1cb875d614b263f1e4365ede4113..4600b06009bcef7d0774d25b816aac4733f30795 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
+
//! Get an input with the argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get an input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const;
+ std::vector<std::string> InputVars() const;
+
//! Get an output with the argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
}
template <typename T>
- std::vector<const T*> MultiOutput(const std::string& name) const {
+ std::vector<T*> MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
- std::vector<const T*> res;
+ std::vector<T*> res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 19051db539d3e3ba1ff840800c8ff6dec1d02eed..4b5a2ae523f2f7fde5445f0534cd99969ad9d59e 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
+ template <typename T, int MajorType, typename IndexType>
+ friend struct EigenMatrix;
+
template <typename T, int MajorType, typename IndexType>
friend struct EigenVector;
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 5e32bfcac69370fee65f20df2b153f84ae7b5df5..642b53efc7095d25712ca324638f5fe9b8316c0c 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -151,5 +151,13 @@ inline const DDim& Tensor::dims() const { return dims_; }
inline int64_t Tensor::numel() const { return numel_; }
+template <typename T>
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+ Tensor res;
+ res.ShareDataWith<T>(src);
+ res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+ return res;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5caeebccf710334e854faf785ef0f64063..55302ea47120f420e952b26830c8ea4cbcce6435 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
+
+TEST(Tensor, ReshapeToMatrix) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+ for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+ src_ptr[i] = i;
+ }
+ Tensor res = ReshapeToMatrix<int>(src, 2);
+ ASSERT_EQ(res.dims()[0], 2 * 3);
+ ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206ee3cbc5421238574c7f310011ccaa5..f7a80e23e1bd49549bec57b360587adc6b423794 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
const ImageConfig& conf = config_.inputs(0).image_conf();
imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+ imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+ if (0 == imageD_) imageD_ = conf.img_size_z();
if (imageH_ == 0 && imageW_ == 0) {
imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imageW_ = conf.img_size();
} else {
getOutput().setFrameHeight(imageH_);
getOutput().setFrameWidth(imageW_);
+ getOutput().setFrameDepth(imageD_);
}
- imgPixels_ = imageH_ * imageW_;
+ imgPixels_ = imageH_ * imageW_ * imageD_;
}
} // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d96bbd49481a7ed135be6888688627e..e721d2d267a31cae46407673b8b1281e87055608 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
/// Height, width, or depth of the input image feature.
/// All of them are 1 if the input is a fully-connected layer.
+ int imageD_;
int imageH_;
int imageW_;
/// Height * Width * Depth.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f..49a9540c0b6e36b59ed786287ff5c4569b69a6a5 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
- hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+ hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
- imageH_,
+ imageH_ * imageD_,
imageW_);
}
}
diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp
index 1b59ed60c57fe3bbfa814befa8a63408a2621715..3eea638649e8ebfdd7efa18615977a9e1344c695 100644
--- a/paddle/gserver/layers/DeConv3DLayer.cpp
+++ b/paddle/gserver/layers/DeConv3DLayer.cpp
@@ -53,27 +53,27 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
size_t DeConv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL);
- outputH_.clear();
- outputW_.clear();
- outputD_.clear();
+ imgSizeW_.clear();
+ imgSizeH_.clear();
+ imgSizeD_.clear();
N_.clear();
NOut_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
- outputW_.push_back(
- imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
- outputH_.push_back(imageSize(
- imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
- outputD_.push_back(imageSize(
- imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
- NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
- N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
+ imgSizeW_.push_back(
+ imageSize(outputW_[i], filterSize_[i], padding_[i], stride_[i], true));
+ imgSizeH_.push_back(imageSize(
+ outputH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
+ imgSizeD_.push_back(imageSize(
+ outputD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
+ NOut_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
+ N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += NOut_[i] * numFilters_;
}
- getOutput().setFrameHeight(outputH_[0]);
- getOutput().setFrameWidth(outputW_[0]);
- getOutput().setFrameDepth(outputD_[0]);
+ getOutput().setFrameHeight(imgSizeH_[0]);
+ getOutput().setFrameWidth(imgSizeW_[0]);
+ getOutput().setFrameDepth(imgSizeD_[0]);
return layerSize;
}
@@ -103,9 +103,9 @@ void DeConv3DLayer::forward(PassType passType) {
}
colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
numFilters_,
- outputD_[i],
- outputH_[i],
- outputW_[i],
+ imgSizeD_[i],
+ imgSizeH_[i],
+ imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
@@ -144,9 +144,9 @@ void DeConv3DLayer::backward(const UpdateCallback &callback) {
colBuf_->vol2Col(
getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
numFilters_,
- outputD_[i],
- outputH_[i],
- outputW_[i],
+ imgSizeD_[i],
+ imgSizeH_[i],
+ imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
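The renaming in this file reflects that deconvolution inverts the usual convolution geometry: the layer's output spatial size (`imgSize*`) is obtained from the configured `output*` by running the convolution size formula backwards. Roughly, with `caffeMode = true` (an illustration of the relation, not a new definition of `imageSize()`):

```cpp
// Forward conv:  out = (in - filter + 2 * padding) / stride + 1
// Deconv solves for `in`, which becomes the deconv layer's output size:
int deconv_out = (conv_out - 1) * stride - 2 * padding + filter_size;
```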
diff --git a/paddle/gserver/layers/DetectionOutputLayer.cpp b/paddle/gserver/layers/DetectionOutputLayer.cpp
index 8ab838e191314ab25469631626c0b0564d7fffda..0cf0a92bf4bd8f9b8eba2016b2377d9dfb18c70a 100644
--- a/paddle/gserver/layers/DetectionOutputLayer.cpp
+++ b/paddle/gserver/layers/DetectionOutputLayer.cpp
@@ -139,7 +139,13 @@ void DetectionOutputLayer::forward(PassType passType) {
allDecodedBBoxes,
&allIndices);
- resetOutput(numKept, 7);
+ if (numKept > 0) {
+ resetOutput(numKept, 7);
+ } else {
+ MatrixPtr outV = getOutputValue();
+ outV = NULL;
+ return;
+ }
MatrixPtr outV = getOutputValue();
getDetectionOutput(confBuffer_->getData(),
numKept,
diff --git a/paddle/gserver/layers/DetectionUtil.cpp b/paddle/gserver/layers/DetectionUtil.cpp
index 3e61adc66e60c54250e4f323452aa13045310879..d83674f45a70212a8adc94a31ff58eb0e01baa00 100644
--- a/paddle/gserver/layers/DetectionUtil.cpp
+++ b/paddle/gserver/layers/DetectionUtil.cpp
@@ -469,7 +469,7 @@ size_t getDetectionIndices(
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
- const size_t confThreshold,
+ const real confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
diff --git a/paddle/gserver/layers/DetectionUtil.h b/paddle/gserver/layers/DetectionUtil.h
index fe4f9f075e4cf011c97f68f49598a828d62327b3..641ed873b4c8645b6455e5ef5e63593e3005b770 100644
--- a/paddle/gserver/layers/DetectionUtil.h
+++ b/paddle/gserver/layers/DetectionUtil.h
@@ -275,7 +275,7 @@ size_t getDetectionIndices(
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
- const size_t confThreshold,
+ const real confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp
index 8318c8c519a4cec1610eadd28320ee5ce0b4147d..53433cef35a377a73f87b041fdcfadd848dd2ec9 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp
@@ -77,24 +77,6 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
}
-void MKLDNNFcLayer::convertOutputToOtherDevice() {
- copyOutputInfoToOtherDevice();
- // find other cpu device and reorder output to cpu device
- int cnt = 0;
- for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
- if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
- // fc cpu output value do not need convert
- // just share point
- outputOtherDevice_[i].value = output_.value;
- ++cnt;
- }
- }
-
- if (cnt > 1) {
- LOG(WARNING) << "should not have more than one CPU devie";
- }
-}
-
void MKLDNNFcLayer::reshape() {
const Argument& input = getInput(0, getPrev(0)->getDeviceId());
int batchSize = input.getBatchSize();
@@ -155,7 +137,10 @@ void MKLDNNFcLayer::resetFwd() {
// change original output value to mkldnn output value
output_.value = std::dynamic_pointer_cast(outVal_);
if (!outputIsOnlyMKLDNN()) {
- convertOutputToOtherDevice();
+ copyOutputInfoToOtherDevice();
+ // the fc cpu output value does not need to create a conversion,
+ // just share the data pointer
+ getOutput(CPU_DEVICE).value->setData(output_.value->getData());
}
// create forward handle
@@ -235,13 +220,12 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdWgt_);
/// backward data
- device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
- const MatrixPtr& in = getInputGrad(0, device);
+ const MatrixPtr& in = inputLayers_[0]->getOutput().grad;
if (in == nullptr) {
return;
}
- if (getInput(0, device).getAllCount() > 1) {
- // TODO(TJ): use outputMaps_ ways when merge outgrad done
+ if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
+ // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
} else {
inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
}
@@ -258,13 +242,21 @@ void MKLDNNFcLayer::resetBwd() {
pipelineBwd_.push_back(*bwdData_);
}
+void MKLDNNFcLayer::updateInputData() {
+ if (inputLayers_[0]->getType() != "data") {
+ return;
+ }
+ real* iData = getInputValue(0, CPU_DEVICE)->getData();
+ inVal_->setData(iData);
+}
+
void MKLDNNFcLayer::forward(PassType passType) {
Layer::forward(passType);
reshape();
{
REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
- syncInputValue();
+ updateInputData();
// just submit forward pipeline
stream_->submit(pipelineFwd_);
@@ -286,7 +278,6 @@ void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
resetBwd();
- syncOutputGrad();
// just submit backward pipeline
stream_->submit(pipelineBwd_);
}
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h
index e138a6faf181c412949218458e7ecf800a0d6a07..4ad67a16e056a718c45a28babcf22a7cd571b15c 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.h
+++ b/paddle/gserver/layers/MKLDNNFcLayer.h
@@ -53,6 +53,8 @@ public:
void backward(const UpdateCallback& callback) override;
+ void updateInputData() override;
+
protected:
/**
* reshape the input image sizes
@@ -72,8 +74,6 @@ protected:
* only would be called when needed
*/
void resetBwd();
-
- void convertOutputToOtherDevice() override;
};
} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h
index b983b833d510b823c5d4cff0b9390173e4cefc89..543364edceff684bdcd002a8f4f10e7ce5e6953b 100644
--- a/paddle/gserver/layers/MKLDNNLayer.h
+++ b/paddle/gserver/layers/MKLDNNLayer.h
@@ -114,10 +114,10 @@ public:
virtual void convertWeightsToPaddle() {}
/**
- * convert MKLDNN output to other device.
- * only support CPU device yet
+ * Update the input value data when the input layer is of "data" type,
+ * since the input value data address might change.
*/
- virtual void convertOutputToOtherDevice() {}
+ virtual void updateInputData() {}
/**
* print info about sizes
@@ -155,6 +155,7 @@ protected:
* copy base info and do not copy data value
*/
void copyOutputInfoToOtherDevice() {
+ int cnt = 0;
for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight());
outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth());
@@ -163,6 +164,12 @@ protected:
outputOtherDevice_[i].subSequenceStartPositions =
output_.subSequenceStartPositions;
outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
+ if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
+ ++cnt;
+ }
+ }
+ if (cnt > 1) {
+ LOG(WARNING) << "should not have more than one CPU devie";
}
}
@@ -193,32 +200,6 @@ protected:
return outputOtherDevice_.size() == 0;
}
- /**
- * Sync input value data
- */
- void syncInputValue() {
- if (inputIsOnlyMKLDNN()) {
- return;
- }
- real* iData = getInputValue(0, CPU_DEVICE)->getData();
- // update input data
- // since it might be changed if this is after data layer
- inVal_->updateData(iData);
- }
-
- /**
- * Sync output grad data
- */
- void syncOutputGrad() {
- if (outputIsOnlyMKLDNN()) {
- return;
- }
-
- // update diff
- real* oDiff = getOutput(CPU_DEVICE).grad->getData();
- outGrad_->updateData(oDiff);
- }
-
/**
* Set deviceId of this layer.
*/
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 92cd61cdd515d5c693df086c9575a5f197c00cee..d7eee6eaf078dab8d48adc4c7ee758a433672ac6 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,10 +24,12 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
+ size_t inD = img_conf.img_size_z();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
+ inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
-
+ int d = inputLayers_[0]->getOutput().getFrameDepth();
+ d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight();
- if (h != 0) inDims_.setDim(2, h);
+ if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index d1f3bc241fa621cb0070125980996e8627e40fd6..090bde7b203652e3ffb1662b8f5b8937885d2608 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+ TestConfig config;
+ const int CHANNELS = 10;
+ const int IMG_SIZE = 16;
+ const int IMG_SIZE_Y = 8;
+ const int IMG_SIZE_Z = 8;
+ size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+ config.layerConfig.set_type(type);
+ config.layerConfig.set_size(size);
+ config.layerConfig.set_active_type("sigmoid");
+ config.biasSize = CHANNELS;
+ config.inputDefs.push_back({INPUT_DATA,
+ "layer_0",
+ /* dim= */ size,
+ /* paraSize= */ CHANNELS});
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+ config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+
+ LayerInputConfig* input = config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(CHANNELS);
+ img_conf->set_img_size(IMG_SIZE);
+ img_conf->set_img_size_y(IMG_SIZE_Y);
+ img_conf->set_img_size_z(IMG_SIZE_Z);
+
+ testLayerGrad(config,
+ "batch_norm",
+ 64,
+ /* trans= */ trans,
+ useGpu,
+ /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+ testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+ testBatchNorm3DLayer("batch_norm", false, true);
+ if (hl_get_cudnn_lib_version() >= int(4000)) {
+ testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+ }
+#endif
+}
+
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
@@ -2253,26 +2302,27 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_stride_z(2);
- conv->set_img_size(IMAGE_SIZE);
- conv->set_img_size_y(IMAGE_SIZE_Y);
- conv->set_img_size_z(IMAGE_SIZE_Z);
- conv->set_output_x(imageSize(conv->img_size(),
+ conv->set_output_x(IMAGE_SIZE);
+ conv->set_output_y(IMAGE_SIZE_Y);
+ conv->set_output_z(IMAGE_SIZE_Z);
+
+ conv->set_img_size(imageSize(conv->output_x(),
conv->filter_size(),
conv->padding(),
conv->stride(),
true));
- conv->set_output_y(imageSize(conv->img_size_y(),
- conv->filter_size_y(),
- conv->padding_y(),
- conv->stride_y(),
- true));
- conv->set_output_z(imageSize(conv->img_size_z(),
- conv->filter_size_z(),
- conv->padding_z(),
- conv->stride_z(),
- true));
- config.layerConfig.set_size(conv->output_x() * conv->output_y() *
- conv->output_z() * NUM_FILTERS);
+ conv->set_img_size_y(imageSize(conv->output_y(),
+ conv->filter_size_y(),
+ conv->padding_y(),
+ conv->stride_y(),
+ true));
+ conv->set_img_size_z(imageSize(conv->output_z(),
+ conv->filter_size_z(),
+ conv->padding_z(),
+ conv->stride_z(),
+ true));
+ config.layerConfig.set_size(conv->img_size() * conv->img_size_y() *
+ conv->img_size_z() * NUM_FILTERS);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp
index 0a355e2644cce572ce90ecf5c9d2a5b7b395bc61..c4063e5069854242d9f93886b66580385557ca73 100644
--- a/paddle/math/MKLDNNMatrix.cpp
+++ b/paddle/math/MKLDNNMatrix.cpp
@@ -33,14 +33,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
size_t width = cnts / dims[0];
m = Matrix::create(height, width, false, false);
}
-
CHECK(m) << " Matrix should not be empty";
+
CpuMatrixPtr cpuMatrix = std::dynamic_pointer_cast<CpuMatrix>(m);
CHECK(cpuMatrix) << "Only support create from CPU matrix yet";
-
- CHECK_EQ(cnts, m->getElementCnt()) << "Count size does not match";
- return std::make_shared<MKLDNNMatrix>(
- m->getData(), m->getHeight(), m->getWidth(), pd);
+ CHECK_EQ(cpuMatrix->getElementCnt(), cnts) << "Count size does not match";
+ return std::make_shared<MKLDNNMatrix>(cpuMatrix, pd);
}
MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
@@ -138,7 +136,7 @@ void MKLDNNMatrix::downSpatial() {
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive");
reset(result);
- set_data_handle(getData());
+ set_data_handle(data_);
}
} // namespace paddle
diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h
index e50f698b495713e6f15ab7a12a7ee7487662040f..eef3b429e6fa0087aeac3f5aed9dff983b06e826 100644
--- a/paddle/math/MKLDNNMatrix.h
+++ b/paddle/math/MKLDNNMatrix.h
@@ -30,11 +30,10 @@ typedef std::shared_ptr<MKLDNNMatrix> MKLDNNMatrixPtr;
*/
class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
public:
- MKLDNNMatrix(real* data,
- size_t height,
- size_t width,
- mkldnn::memory::primitive_desc pd)
- : CpuMatrix(data, height, width, false), mkldnn::memory(pd, data) {}
+ MKLDNNMatrix(CpuMatrixPtr m, mkldnn::memory::primitive_desc pd)
+ : CpuMatrix(m->getData(), m->getHeight(), m->getWidth(), false),
+ mkldnn::memory(pd, m->getData()),
+ m_(m) {}
~MKLDNNMatrix() {}
@@ -81,11 +80,29 @@ public:
void downSpatial();
/**
- * Update the memory data handle.
+ * Set the memory data handle.
* Caution: This will not check the buffer size of the data,
* it should be covered by the user.
*/
- void updateData(void* data) { set_data_handle(data); }
+ void setData(real* data) {
+ set_data_handle(data);
+ CpuMatrix::setData(data);
+ m_.reset();
+ }
+
+ /**
+ * override Matrix::getData
+ * check data before return
+ */
+ real* getData() override {
+ CHECK_EQ((void*)data_, get_data_handle());
+ return data_;
+ }
+
+ const real* getData() const override {
+ CHECK_EQ((void*)data_, get_data_handle());
+ return data_;
+ }
/**
* Get primitive descriptor.
@@ -143,6 +160,10 @@ protected:
memory::format srcFmt,
memory::format dstFmt,
memory::dims dm);
+
+private:
+ // save the CpuMatrixPtr in case the buffer is released outside
+ CpuMatrixPtr m_;
};
} // namespace paddle
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ebefbab26ec8fdf316f852fbb7f6d9f3bbc48eb
--- /dev/null
+++ b/paddle/operators/concat_op.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/concat_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class ConcatOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto *out = ctx.Output("Out");
+ size_t axis = static_cast(ctx.Attr("axis"));
+ size_t n = ins.size();
+
+ PADDLE_ENFORCE_GT(n, 1, "Input tensors count should be > 1.");
+
+ auto out_dims = ins[0]->dims();
+ size_t in_zero_dims_size = out_dims.size();
+ for (size_t i = 1; i < n; i++) {
+ for (size_t j = 0; j < in_zero_dims_size; j++) {
+ if (j == axis) {
+ out_dims[axis] += ins[i]->dims()[j];
+ continue;
+ }
+ PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+ "Input tensors should have the same "
+ "dimensions, except along the specified axis.")
+ }
+ }
+ out->Resize(out_dims);
+ }
+};
+
+class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ ConcatOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "the input tensors of concat operator.").AsDuplicable();
+ AddOutput("Out", "the output tensor of concat operator.");
+ AddComment(R"DOC(
+ Join the input tensors along the given axis.
+ Examples:
+ Input[0] = [[1,2],[3,4]]
+ Input[1] = [[5,6]]
+ axis = 0
+ Output = [[1,2],
+ [3,4],
+ [5,6]]
+ )DOC");
+ AddAttr("axis", "The axis which the inputs will be joined with.")
+ .SetDefault(0);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker)
+REGISTER_OP_CPU_KERNEL(concat,
+ ops::ConcatKernel<paddle::platform::CPUPlace, float>)
diff --git a/paddle/operators/concat_op.cu b/paddle/operators/concat_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..38fee7473dbb2ba97fe95b6632db7a1749cf3bbe
--- /dev/null
+++ b/paddle/operators/concat_op.cu
@@ -0,0 +1,19 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/concat_op.h"
+
+namespace ops = paddle::operators;
+// TODO(Yancey1989) Add GPU kernel
diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f977054fdf8aa0164db726b94a21c57f770dd674
--- /dev/null
+++ b/paddle/operators/concat_op.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class ConcatKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto* out = ctx.Output("Out");
+ int64_t axis = static_cast(ctx.Attr("axis"));
+ size_t n = ins.size();
+ size_t output_axis_dim = 0;
+ size_t before = 1, after = 1;
+ for (size_t i = 0; i < n; i++) {
+ output_axis_dim += ins[i]->dims()[axis];
+ }
+ auto& input_zero = ins[0];
+ for (int64_t i = 0; i < input_zero->dims().size(); i++) {
+ if (i == axis) {
+ continue;
+ }
+ if (i < axis) {
+ before *= input_zero->dims()[i];
+ } else {
+ after *= input_zero->dims()[i];
+ }
+ }
+ size_t output_offset = 0;
+ for (size_t i = 0; i < n; i++) {
+ auto& in = ins[i];
+ auto axis_dim = in->dims()[axis];
+ for (size_t j = 0; j < before; j++) {
+ size_t len = axis_dim * after * sizeof(T);
+ const T* src = in->data<T>() + axis_dim * after * j;
+ T* out_data = out->mutable_data<T>(platform::CPUPlace());
+ T* dest = out_data + output_offset + output_axis_dim * after * j;
+ memcpy(dest, src, len);
+ }
+ output_offset += axis_dim * after;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
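A worked example of the kernel's stride bookkeeping (illustrative numbers, not from the patch): concatenating shapes `{2, 3, 4}` and `{2, 5, 4}` along `axis = 1` gives `before = 2`, `after = 4`, and `output_axis_dim = 3 + 5 = 8`, so the output shape is `{2, 8, 4}`:

```cpp
// Stride computation mirroring ConcatKernel, for dims {2, 3, 4}, axis = 1.
int64_t dims[] = {2, 3, 4};
int axis = 1;
size_t before = 1, after = 1;
for (int i = 0; i < 3; ++i) {
  if (i < axis) before *= dims[i];  // before = 2
  if (i > axis) after *= dims[i];   // after = 4
}
// Each input contributes axis_dim * after contiguous elements per outer
// index j, and consecutive j values are output_axis_dim * after apart.
```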
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index 186a33edcec88bd5e51091a524a778eeb27ad526..4f380388b108dc173d847f027ba5c9db387a87f8 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -119,4 +119,4 @@ TEST(math, im2col) {
#ifndef PADDLE_ONLY_CPU
testIm2col();
#endif
-}
\ No newline at end of file
+}
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2e9b7a965ff9f99e787bb8315010823..710a56a0e8e2d17162d7d000df226f1537104eb9 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("Y")->dims();
- PADDLE_ENFORCE_EQ(dim0.size(), 2,
- "input X(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("X"));
- PADDLE_ENFORCE_EQ(dim1.size(), 2,
- "input Y(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("Y"));
+ auto x_dims = ctx.Input("X")->dims();
+ auto y_dims = ctx.Input("Y")->dims();
+ int x_num_col_dims = Attr("x_num_col_dims");
+ int y_num_col_dims = Attr("y_num_col_dims");
+
+ PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+ "The rank of input tensor X(%s) should be larger than "
+ "`mul_op`'s `x_num_col_dims`.",
+ ctx.op().Input("X"));
+ PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+ "The rank of input tensor Y(%s) should be larger than "
+ "`mul_op`'s `y_num_col_dims`.",
+ ctx.op().Input("Y"));
+
+ auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+ auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
PADDLE_ENFORCE_EQ(
- dim0[1], dim1[0],
+ x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
- ctx.Output("Out")->Resize({dim0[0], dim1[1]});
+ ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
+ AddAttr<int>(
+ "x_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
+ in that case, tensors will be reshaped to a matrix. The matrix's first
+ dimension(column length) will be the product of tensor's first
+ `num_col_dims` dimensions, and the matrix's second dimension(row length)
+ will be the product of tensor's last `rank - num_col_dims` dimensions.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
+ AddAttr<int>(
+ "y_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
+ in that case, tensors will be reshaped to a matrix. Just like input `X`.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output(framework::GradVarName("X"));
auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- PADDLE_ENFORCE(x_dims[0] == out_dims[0],
- "Out@GRAD M X N must equal to X dims 0, M ");
- PADDLE_ENFORCE(y_dims[1] == out_dims[1],
- "Out@GRAD M X N must equal to Y dims 1, N ");
+
+ auto x_mat_dims =
+ framework::flatten_to_2d(x_dims, Attr<int>("x_num_col_dims"));
+ auto y_mat_dims =
+ framework::flatten_to_2d(y_dims, Attr<int>("y_num_col_dims"));
+
+ PADDLE_ENFORCE_EQ(
+ x_mat_dims[0], out_dims[0],
+ "The first dimension of Out@GRAD must equal to the first dimension of "
+ "the first operand.");
+ PADDLE_ENFORCE_EQ(
+ y_mat_dims[1], out_dims[1],
+ "The second dimension of Out@GRAD must equal to the second "
+ "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
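A concrete instance of the new attributes (a worked example, not from the patch): with `X` of shape `{2, 3, 4}` and `x_num_col_dims = 2`, `X` is flattened to a `{6, 4}` matrix; a `Y` of shape `{4, 5}` with `y_num_col_dims = 1` stays `{4, 5}`, so the inner dimensions match and `Out` is resized to `{6, 5}`:

```cpp
using paddle::framework::flatten_to_2d;
using paddle::framework::make_ddim;

auto x_mat = flatten_to_2d(make_ddim({2, 3, 4}), /*num_col_dims=*/2);  // {6, 4}
auto y_mat = flatten_to_2d(make_ddim({4, 5}), /*num_col_dims=*/1);     // {4, 5}
// x_mat[1] == y_mat[0], so Out gets shape {x_mat[0], y_mat[1]} == {6, 5}.
```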
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3470e39a5ebd0394ba05629553a5075..3c01f868bda8cba488b3403df456d63d6b082fa6 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
+ You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* x = context.Input("X");
- auto* y = context.Input("Y");
- auto* z = context.Output("Out");
+ const Tensor* x = context.Input("X");
+ const Tensor* y = context.Input("Y");
+ Tensor* z = context.Output("Out");
+ const Tensor x_matrix =
+ x->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *x, context.template Attr("x_num_col_dims"))
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *y, context.template Attr("y_num_col_dims"))
+ : *y;
+
z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast(context.device_context_);
- math::matmul<Place, T>(*x, false, *y, false, 1, z, 0, device_context);
+ math::matmul<Place, T>(x_matrix, false, y_matrix, false, 1, z, 0,
+ device_context);
}
};
@@ -45,23 +57,39 @@ template
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto* x = ctx.Input("X");
- auto* y = ctx.Input("Y");
- auto* dout = ctx.Input(framework::GradVarName("Out"));
+ int x_num_col_dims = ctx.template Attr("x_num_col_dims");
+ int y_num_col_dims = ctx.template Attr("y_num_col_dims");
+ const Tensor* x = ctx.Input("X");
+ const Tensor* y = ctx.Input("Y");
+ const Tensor x_matrix =
+ x->dims().size() > 2 ? framework::ReshapeToMatrix(*x, x_num_col_dims)
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2 ? framework::ReshapeToMatrix(*y, y_num_col_dims)
+ : *y;
+ const Tensor* dout = ctx.Input(framework::GradVarName("Out"));
- auto* dx = ctx.Output(framework::GradVarName("X"));
- auto* dy = ctx.Output(framework::GradVarName("Y"));
+ Tensor* dx = ctx.Output(framework::GradVarName("X"));
+ Tensor* dy = ctx.Output(framework::GradVarName("Y"));
auto* device_context =
const_cast(ctx.device_context_);
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
+ Tensor dx_matrix = dx->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
+ *dx, x_num_col_dims)
+ : *dx;
// dx = dout * y'. dx: M x K, dout : M x N, y : K x N
- math::matmul<Place, T>(*dout, false, *y, true, 1, dx, 0, device_context);
+ math::matmul<Place, T>(*dout, false, y_matrix, true, 1, &dx_matrix, 0,
+ device_context);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
+ Tensor dy_matrix = dy->dims().size() > 2 ? framework::ReshapeToMatrix<T>(
+ *dy, y_num_col_dims)
+ : *dy;
// dy = x' * dout. dy K x N, dout : M x N, x : M x K
- math::matmul<Place, T>(*x, true, *dout, false, 1, dy, 0, device_context);
+ math::matmul<Place, T>(x_matrix, true, *dout, false, 1, &dy_matrix, 0,
+ device_context);
}
}
};
diff --git a/paddle/operators/rowwise_add_op.cc b/paddle/operators/rowwise_add_op.cc
index 30b4b404315a9f041e21d79b75fd06307e33f7f9..fa8f0ff1a858143af427b51025279c726f1628e0 100644
--- a/paddle/operators/rowwise_add_op.cc
+++ b/paddle/operators/rowwise_add_op.cc
@@ -25,14 +25,19 @@ class RowwiseAddOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("b")->dims();
-
- PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
- PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
- PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
- PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
- ctx.Output("Out")->Resize(ctx.Input("X")->dims());
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
+ PADDLE_ENFORCE_EQ(ctx.OutputSize("Out"), 1, "The output size must be 1");
+ ctx.Output("Out")->Resize(x_dims);
}
};
@@ -61,13 +66,20 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("b"), "b should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
- auto dims0 = ctx.Input("X")->dims();
- auto dims1 = ctx.Input("b")->dims();
- PADDLE_ENFORCE_EQ(1, dims1.size(), "b dims should be 1")
+ auto x_dims = ctx.Input("X")->dims();
+ auto b_dims = ctx.Input("b")->dims();
+ PADDLE_ENFORCE_GT(
+ x_dims.size(), b_dims.size(),
+ "The rank of input `X` must be larger than the one of input `b`.");
+
+ int num_col_dims = x_dims.size() - b_dims.size();
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
+ "The width of two operands must be same");
auto *dx = ctx.Output(framework::GradVarName("X"));
auto *db = ctx.Output(framework::GradVarName("b"));
- if (dx) dx->Resize(dims0);
- if (db) db->Resize(dims1);
+ if (dx) dx->Resize(x_dims);
+ if (db) db->Resize(b_dims);
}
};
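The relaxed check lets `b` match the trailing dimensions of `X` (a worked example, not from the patch): for `X` of shape `{5, 3, 4}` and `b` of shape `{3, 4}`, `num_col_dims = 3 - 2 = 1`, and `slice_ddim(x_dims, 1, 3)` is `{3, 4}`, which equals `b_dims`, so the bias is broadcast over the 5 leading rows:

```cpp
using paddle::framework::make_ddim;
using paddle::framework::slice_ddim;

auto x_dims = make_ddim({5, 3, 4});
auto b_dims = make_ddim({3, 4});
int num_col_dims = x_dims.size() - b_dims.size();  // 3 - 2 == 1
bool ok = slice_ddim(x_dims, num_col_dims, x_dims.size()) == b_dims;  // true
```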
diff --git a/paddle/operators/rowwise_add_op.h b/paddle/operators/rowwise_add_op.h
index 4e926d9f2947f37b71e81c0fa592b0c66b19c640..35774b940926f77167b8f19597027e74d3477e5b 100644
--- a/paddle/operators/rowwise_add_op.h
+++ b/paddle/operators/rowwise_add_op.h
@@ -33,10 +33,12 @@ class RowwiseAddKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output("Out");
out->mutable_data(context.GetPlace());
-
- auto input = EigenMatrix::From(*context.Input("X"));
- auto bias = EigenVector::From(*context.Input("b"));
- auto output = EigenMatrix::From(*out);
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
+ auto input =
+ EigenMatrix::Reshape(*context.Input("X"), num_col_dims);
+ auto bias = EigenVector::Flatten(*context.Input("b"));
+ auto output = EigenMatrix::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
@@ -54,12 +56,15 @@ class RowwiseAddGradKernel : public framework::OpKernel {
auto* dout = context.Input(framework::GradVarName("Out"));
auto* dx = context.Output(framework::GradVarName("X"));
auto* db = context.Output(framework::GradVarName("b"));
+ int num_col_dims = context.Input("X")->dims().size() -
+ context.Input("b")->dims().size();
- auto out_grad = EigenMatrix<T>::From(*dout);
+ auto out_grad = EigenMatrix<T>::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice<Place>();
+
if (dx) {
dx->mutable_data<T>(context.GetPlace());
- EigenMatrix<T>::From(*dx).device(place) = out_grad;
+ EigenMatrix<T>::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5805826ee8a555ca6dfc1ca81feaadffea9e1012
--- /dev/null
+++ b/paddle/operators/sum_op.cc
@@ -0,0 +1,73 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/sum_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class SumOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto *out = ctx.Output("Out");
+ int N = ins.size();
+
+ auto in_dim = ins[0]->dims();
+
+ PADDLE_ENFORCE_GT(N, 1, "Input tensors count should be > 1.");
+ for (int i = 1; i < N; i++) {
+ auto dim = ins[i]->dims();
+ PADDLE_ENFORCE(in_dim == dim, "Input tensors must have same shape");
+ }
+ out->Resize(in_dim);
+ }
+};
+
+class SumOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ SumOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "the input tensors of sum operator.").AsDuplicable();
+ AddOutput("Out", "the output tensor of sum operator.");
+ AddComment(R"DOC(
+ Sum the input tensors.
+ )DOC");
+ }
+};
+
+class SumGradOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto outputs = ctx.MultiOutput(framework::GradVarName("X"));
+ auto dims = ctx.Input(framework::GradVarName("Out"))->dims();
+ for (auto output : outputs) {
+ output->Resize(dims);
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(sum, ops::SumOp, ops::SumOpMaker, sum_grad, ops::SumGradOp);
+REGISTER_OP_CPU_KERNEL(sum, ops::SumKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(sum_grad,
+ ops::SumGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/sum_op.cu b/paddle/operators/sum_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..a465cf3659ba7c51338abadfc62962fb6755a39d
--- /dev/null
+++ b/paddle/operators/sum_op.cu
@@ -0,0 +1,18 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/sum_op.h"
+
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(sum, ops::SumKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(sum_grad,
+ ops::SumGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/sum_op.h b/paddle/operators/sum_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..0b1e9ebaa38d455fb5e3ce8c1a39cbbcdad9a940
--- /dev/null
+++ b/paddle/operators/sum_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
+
+template <typename Place, typename T>
+class SumKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+    auto ins = context.MultiInput<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+
+    auto place = context.GetEigenDevice<Place>();
+    auto result = EigenVector<T>::Flatten(*out);
+
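+    // Initialize the result with the first input, then accumulate the
+    // remaining inputs element-wise on the target device.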
+ int N = ins.size();
+    auto in = EigenVector<T>::Flatten(*(ins[0]));
+    result.device(place) = in;
+    for (int i = 1; i < N; i++) {
+      auto in = EigenVector<T>::Flatten(*(ins[i]));
+ result.device(place) = result + in;
+ }
+ }
+};
+
+template <typename Place, typename T>
+class SumGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& context) const override {
+    auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto outs = context.MultiOutput<Tensor>(framework::GradVarName("X"));
+    for (auto out : outs) {
+      out->mutable_data<T>(context.GetPlace());
+ }
+
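+    // d(sum)/d(x_i) = 1, so every input gradient is a plain copy of the
+    // output gradient.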
+    auto place = context.GetEigenDevice<Place>();
+    auto in = EigenVector<T>::Flatten(*input);
+ for (auto out : outs) {
+      auto result = EigenVector<T>::Flatten(*out);
+ result.device(place) = in;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..38d2f0a09aec751734864947a2f3cfa20107e22f
--- /dev/null
+++ b/paddle/operators/top_k_op.cc
@@ -0,0 +1,67 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/top_k_op.h"
+
+namespace paddle {
+namespace operators {
+
+class TopkOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
+                            "Input of TopkOp must be initialized.");
+    auto *input = ctx.Input<Tensor>("X");
+    const int k = static_cast<int>(ctx.Attr<int>("k"));
+
+    PADDLE_ENFORCE_GE(k, 1, "k must be >= 1");
+    PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1-D shape");
+    PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
+                      "input must have >= k columns");
+
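+    // Outputs keep the input shape except for the last dimension, which
+    // becomes k.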
+ framework::DDim dims = input->dims();
+ dims[dims.size() - 1] = k;
+ ctx.Output("Out")->Resize(dims);
+ ctx.Output("Indices")->Resize(dims);
+ }
+};
+
+class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "The input of Topk op");
+ AddOutput("Out", "The output tensor of Topk op");
+ AddOutput("Indices", "The indices of Topk elements of input");
+ AddComment(
+ R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
+
+ For matrices, computes the top k entries in each row. )DOC");
+ AddAttr("k",
+ "Number of top elements to look for along the last "
+ "dimension (along each row for matrices).")
+ .SetDefault(1);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
+REGISTER_OP_CPU_KERNEL(top_k,
+                       ops::TopkKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/top_k_op.cu b/paddle/operators/top_k_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..afe4d149c53819c45e20353bc9d16393f3f61e0f
--- /dev/null
+++ b/paddle/operators/top_k_op.cu
@@ -0,0 +1,318 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/assert.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
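+// A value/index pair; ties on the value are broken in favor of the smaller
+// index.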
+template <typename T>
+struct Pair {
+ __device__ __forceinline__ Pair() {}
+ __device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
+
+ __device__ __forceinline__ void set(T value, int id) {
+ v = value;
+    this->id = id;  // the parameter shadows the member; qualify to avoid a
+                    // self-assignment no-op
+ }
+
+ __device__ __forceinline__ void operator=(const Pair& in) {
+ v = in.v;
+ id = in.id;
+ }
+
+ __device__ __forceinline__ bool operator<(const T value) const {
+ return (v < value);
+ }
+
+ __device__ __forceinline__ bool operator<(const Pair& in) const {
+ return (v < in.v) || ((v == in.v) && (id > in.id));
+ }
+
+ __device__ __forceinline__ bool operator>(const Pair& in) const {
+ return (v > in.v) || ((v == in.v) && (id < in.id));
+ }
+
+ T v;
+ int id;
+};
+
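+// Insert p into the descending top-k buffer, shifting smaller entries down
+// one slot.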
+template <typename T>
+__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
+                                      int beam_size) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
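+// Same insertion, but with beam_size fixed at compile time as a template
+// parameter.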
+template <typename T, int beam_size>
+__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p) {
+ for (int k = beam_size - 2; k >= 0; k--) {
+ if (topk[k] < p) {
+ topk[k + 1] = topk[k];
+ } else {
+ topk[k + 1] = p;
+ return;
+ }
+ }
+ topk[0] = p;
+}
+
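+// Each thread strides over one row and offers every candidate that beats the
+// current minimum of its buffer.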
+template <typename T, int BlockSize>
+__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
+                                        int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+      Pair<T> tmp(src[idx], idx);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
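+// Bounded variant: only candidates strictly below max are accepted, so values
+// already emitted are not re-added.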
+template <typename T, int BlockSize>
+__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
+                                        int dim, const Pair<T>& max,
+                                        int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < src[idx]) {
+      Pair<T> tmp(src[idx], idx);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
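+// (val, col) variants: values and their original column indices live in
+// separate arrays.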
+template <typename T, int BlockSize>
+__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
+                                        int idx, int dim, int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+      Pair<T> tmp(val[idx], col[idx]);
+ AddTo(topk, tmp, beam_size);
+ }
+ idx += BlockSize;
+ }
+}
+
+template <typename T, int BlockSize>
+__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
+                                        int idx, int dim, const Pair<T>& max,
+                                        int beam_size) {
+ while (idx < dim) {
+ if (topk[beam_size - 1] < val[idx]) {
+      Pair<T> tmp(val[idx], col[idx]);
+ if (tmp < max) {
+ AddTo(topk, tmp, beam_size);
+ }
+ }
+ idx += BlockSize;
+ }
+}
+
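+// After the block reduction consumes entries from this thread's buffer, shift
+// the survivors up and refill the tail with fresh candidates bounded by max.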
+template <typename T, int MaxLength, int BlockSize>
+__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
+                                              int beam_size, const T* src,
+                                              bool& firstStep, bool& is_empty,
+                                              Pair<T>& max, int dim,
+                                              const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+      GetTopK<T, BlockSize>(topk, src, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+        GetTopK<T, BlockSize>(topk + MaxLength - beam, src, tid, dim, max,
+                              length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
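+// Same refill logic for the (val, col) input layout.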
+template <typename T, int MaxLength, int BlockSize>
+__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
+                                              int beam_size, const T* val,
+                                              int* col, bool& firstStep,
+                                              bool& is_empty, Pair<T>& max,
+                                              int dim, const int tid) {
+ if (beam > 0) {
+ int length = beam < beam_size ? beam : beam_size;
+ if (firstStep) {
+ firstStep = false;
+      GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
+ } else {
+ for (int k = 0; k < MaxLength; k++) {
+ if (k < MaxLength - beam) {
+ topk[k] = topk[k + beam];
+ } else {
+ topk[k].set(-INFINITY, -1);
+ }
+ }
+ if (!is_empty) {
+        GetTopK<T, BlockSize>(topk + MaxLength - beam, val, col, tid, dim, max,
+                              length);
+ }
+ }
+
+ max = topk[MaxLength - 1];
+ if (max.v == -1) is_empty = true;
+ beam = 0;
+ }
+}
+
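+// Tournament-style reduction over the per-thread candidates in shared memory;
+// each iteration emits one global winner into the output until k values have
+// been written.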
+template <typename T, int MaxLength, int BlockSize>
+__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
+                                            Pair<T> topk[], T** topVal,
+                                            int** topIds, int& beam, int& k,
+                                            const int tid, const int warp) {
+ while (true) {
+ __syncthreads();
+ if (tid < BlockSize / 2) {
+ if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
+ maxid[tid] = tid + BlockSize / 2;
+ } else {
+ maxid[tid] = tid;
+ }
+ }
+ __syncthreads();
+ for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
+ if (tid < stride) {
+ if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
+ maxid[tid] = maxid[tid + stride];
+ }
+ }
+ __syncthreads();
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ **topVal = sh_topk[maxid[0]].v;
+ **topIds = sh_topk[maxid[0]].id;
+ (*topVal)++;
+ (*topIds)++;
+ }
+ if (tid == maxid[0]) beam++;
+ if (--k == 0) break;
+ __syncthreads();
+
+ if (tid == maxid[0]) {
+ if (beam < MaxLength) {
+ sh_topk[tid] = topk[beam];
+ }
+ }
+ if (maxid[0] / 32 == warp) {
+ if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
+ }
+ }
+}
+
+/**
+ * Each block computes one sample.
+ * In a block:
+ * 1. every thread gets its top MaxLength values;
+ * 2. merge them into sh_topk, block-reduce, and take the max value;
+ * 3. go back to the second step until one thread's top-k buffer is empty;
+ * 4. go back to the first step until all k values are obtained.
+ */
+template <typename T, int MaxLength, int BlockSize>
+__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
+                             const T* src, int lds, int dim, int k) {
+  __shared__ Pair<T> sh_topk[BlockSize];
+ __shared__ int maxid[BlockSize / 2];
+ const int tid = threadIdx.x;
+ const int warp = threadIdx.x / 32;
+ output += blockIdx.x * output_stride;
+ indices += blockIdx.x * k;
+
+  Pair<T> topk[MaxLength];
+  int beam = MaxLength;
+  Pair<T> max;
+ bool is_empty = false;
+ bool firststep = true;
+
+ for (int k = 0; k < MaxLength; k++) {
+ topk[k].set(-INFINITY, -1);
+ }
+ while (k) {
+    ThreadGetTopK<T, MaxLength, BlockSize>(topk, beam, k,
+                                           src + blockIdx.x * lds, firststep,
+                                           is_empty, max, dim, tid);
+
+    sh_topk[tid] = topk[0];
+    BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
+                                         &indices, beam, k, tid, warp);
+ }
+}
+
+template <typename T>
+class TopkOpCUDAKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
+ "It must use GPUPlace.");
+ auto* input = ctx.Input("X");
+ auto* output = ctx.Output("Out");
+ auto* indices = ctx.Output("Indices");
+ size_t k = static_cast(ctx.Attr("k"));
+
+    const T* input_data = input->data<T>();
+
+    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+ // FIXME(typhoonzero): data is always converted to type T?
+    int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
+
+ size_t input_height = input->dims()[0];
+ size_t input_width = input->dims()[1];
+ if (k > input_width) k = input_width;
+
+    // NOTE: pass lds and dim the same value as the input width.
+    // NOTE: the old matrix implementation's stride differs from Eigen's.
+ // TODO(typhoonzero): launch kernel on specified stream.
+ // TODO(typhoonzero): refine this kernel.
+ dim3 threads(256, 1);
+ dim3 grid(input_height, 1);
+
+    KeMatrixTopK<T, 5, 256><<<grid, threads>>>(
+ output_data, output->dims()[1], indices_data, input_data, input_width,
+ input_width, int(k));
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+REGISTER_OP_GPU_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>);
diff --git a/paddle/operators/top_k_op.h b/paddle/operators/top_k_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..ef66acc1d569282a42be64b7a5e90f3fbdb20690
--- /dev/null
+++ b/paddle/operators/top_k_op.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <algorithm>
+#include <iostream>
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>