diff --git a/.travis.yml b/.travis.yml
index 3fadf8deb787bfb1d055496c5987798297f80bbf..e217c8f5a740ef5ab7315656ed7839ffa219c805 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -21,7 +21,6 @@ addons:
- python
- python-pip
- python2.7-dev
- - python-numpy
- python-wheel
- libboost-dev
- curl
@@ -35,8 +34,8 @@ before_install:
- if [[ "$JOB" == "check_style" ]]; then sudo ln -s /usr/bin/clang-format-3.8 /usr/bin/clang-format; fi
# Paddle is using protobuf 3.1 currently. Protobuf 3.2 breaks the compatibility. So we specify the python
# protobuf version.
- - pip install -r $TRAVIS_BUILD_DIR/python/requirements.txt
- - pip install wheel sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit LinkChecker
+ - sudo pip install -r $TRAVIS_BUILD_DIR/python/requirements.txt
+ - sudo pip install wheel sphinx==1.5.6 recommonmark sphinx-rtd-theme==0.1.9 virtualenv pre-commit LinkChecker
- curl https://glide.sh/get | bash
- eval "$(GIMME_GO_VERSION=1.8.3 gimme)"
- go get -u github.com/alecthomas/gometalinter
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ad559672ad2f83a3d62cdf332b47c6cf1e730f70..08237cd850ae20c515a39c8783a18deaac431626 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -65,8 +65,8 @@ if(NOT CMAKE_BUILD_TYPE)
endif()
if(ANDROID)
- if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "21")
- message(FATAL_ERROR "Unsupport standalone toolchains with Android API level lower than 21")
+ if(${CMAKE_SYSTEM_VERSION} VERSION_LESS "16")
+ message(FATAL_ERROR "Unsupported standalone toolchains with Android API level lower than 16")
endif()
set(WITH_GPU OFF CACHE STRING
diff --git a/doc/design/functions_operators_layers.md b/doc/design/functions_operators_layers.md
index 7a2e8fd0ace2e3f4462b15215de22c31e944b7cb..d23ba56b5773a36d448a99e4abdebc1475ed789c 100644
--- a/doc/design/functions_operators_layers.md
+++ b/doc/design/functions_operators_layers.md
@@ -86,12 +86,13 @@ def layer.fc(X):
We'd like to have Python bindings to operators in package `paddle.operator`, and Python compositions of operators in package `paddle.layer`. So we have the following concepts in above illustrative example:
-```
+
| C++ functions/functors | mul | add | | |
+|------------------------|--------------|--------------|-------------|----------|
| C++ operator class | mulOp | addOp | FCOp | |
| Python binding | operator.mul | operator.add | operator.fc | |
| Python function | | | | layer.fc |
-```
+
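+The table can be made concrete with a small C++ sketch. This is illustrative only: `Mul`, `MulOperator`, and the toy `Scope` below are made-up names for this document, not PaddlePaddle classes, but they show the difference between the "C++ functions/functors" row (pure computation) and the "C++ operator class" row (a wrapper that binds named variables in a scope):
+
+```cpp
+#include <cassert>
+#include <cstddef>
+#include <map>
+#include <string>
+#include <vector>
+
+using Scope = std::map<std::string, std::vector<float>>;
+
+// Row "C++ functions/functors": pure computation, no framework state.
+struct Mul {
+  std::vector<float> operator()(const std::vector<float>& x,
+                                const std::vector<float>& w) const {
+    assert(x.size() == w.size());
+    std::vector<float> out(x.size());
+    for (std::size_t i = 0; i < x.size(); ++i) out[i] = x[i] * w[i];
+    return out;
+  }
+};
+
+// Row "C++ operator class": binds named inputs/outputs and calls the functor.
+// The Python binding (operator.mul) would wrap a class like this.
+class MulOperator {
+ public:
+  void Run(Scope& scope) const { scope["Out"] = Mul{}(scope["X"], scope["W"]); }
+};
+
+int main() {
+  Scope scope{{"X", {1, 2, 3}}, {"W", {4, 5, 6}}};
+  MulOperator op;
+  op.Run(scope);
+  return scope["Out"][2] == 18 ? 0 : 1;
+}
+```
+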
This is how we differentiate layer and operators in PaddlePaddle:
diff --git a/doc/design/ops/dist_train.md b/doc/design/ops/dist_train.md
new file mode 100644
index 0000000000000000000000000000000000000000..fa3c5d7990213cf2b0d236e66e592dd2699da876
--- /dev/null
+++ b/doc/design/ops/dist_train.md
@@ -0,0 +1,106 @@
+# Design Doc: Operation Graph Based Parameter Server
+
+## Abstract
+
+We propose an approach to implement the parameter server. In this
+approach, there is no fundamental difference between the trainer and
+the parameter server: they both run subgraphs, but subgraphs of
+different purposes.
+
+## Background
+
+The previous implementations of the parameter server do not run a
+subgraph: parameter initialization, optimizer computation, network
+communication and checkpointing are each implemented twice, once on the
+trainer and once on the parameter server.
+
+It would be great if we could write the code once and use it on both the
+trainer and the parameter server: it reduces code duplication and
+improves extensibility. Given that after the current refactor we are
+representing everything as a computation graph on the
+trainer, representing everything as a computation graph on the parameter
+server becomes a natural extension.
+
+## Design
+
+### Graph Converter
+
+The *graph converter* converts the user-defined operation (OP) graph
+into subgraphs to be scheduled on different nodes with the following
+steps:
+
+1. OP placement: the OPs will be placed on different nodes according
+   to a heuristic that minimizes the estimated total computation
+   time. Currently we will use a simple heuristic that puts parameter
+   variables on parameter server workers and everything else on trainer
+   workers.
+
+1. Add communication OPs to enable the communication between nodes.
+
+We will need these OPs: *Send*, *Recv*, *Enqueue*, *Dequeue*.
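+To make these two steps concrete, below is a small illustrative sketch. It is not the actual converter; `Op`, `Place`, and the plain index-based edge list are made up for this document. It places OPs with the simple parameter-variable heuristic and reports where a *Send*/*Recv* pair has to be inserted, namely on every edge that crosses node boundaries:
+
+```cpp
+#include <cstddef>
+#include <iostream>
+#include <string>
+#include <vector>
+
+struct Op {
+  std::string name;
+  bool touches_parameter;      // input to the placement heuristic
+  std::vector<int> consumers;  // indices of downstream OPs
+};
+
+// Step 1: placement. Parameter OPs go to the pserver, everything else to the trainer.
+std::string Place(const Op& op) { return op.touches_parameter ? "pserver" : "trainer"; }
+
+int main() {
+  std::vector<Op> graph = {{"read_minibatch", false, {1}},
+                           {"fc_forward", false, {2}},
+                           {"sgd_update_W", true, {}}};
+  // Step 2: insert communication OPs on every edge that crosses node boundaries.
+  for (std::size_t i = 0; i < graph.size(); ++i) {
+    for (int c : graph[i].consumers) {
+      if (Place(graph[i]) != Place(graph[c])) {
+        std::cout << "insert Send on " << Place(graph[i]) << " and Recv on "
+                  << Place(graph[c]) << " for edge " << graph[i].name
+                  << " -> " << graph[c].name << "\n";
+      }
+    }
+  }
+  return 0;
+}
+```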
+
+Below is an example of converting the user defined graph to the
+subgraphs for the trainer and the parameter server:
+
+
+<img src="src/local-graph.png"/>
+After converting:
+
+<img src="src/dist-graph.png"/>
+
+1. The parameter variable W and its optimizer subgraph are placed on the parameter server.
+1. Operators are added to the subgraphs.
+   - *Send* sends data to the connected *Recv* operator. The
+     scheduler on the receiving node will only schedule the *Recv* operator
+     to run when the *Send* operator has run (the *Send* OP will mark
+     the *Recv* OP runnable automatically).
+   - *Enqueue* enqueues the input variable; it can block until space
+     becomes available in the queue.
+   - *Dequeue* outputs a configurable number of tensors from the
+     queue. It will block until the queue has the required number of
+     tensors.
+
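+The queueing behavior described above can be sketched with a generic bounded blocking queue. This is only an illustration of the intended semantics, not Paddle code; the `min_count` name follows the *Dequeue* attribute mentioned later in this document:
+
+```cpp
+#include <condition_variable>
+#include <cstddef>
+#include <deque>
+#include <mutex>
+#include <vector>
+
+// Enqueue blocks while the queue is full; Dequeue blocks until at least
+// `min_count` items (tensors) are available, then pops exactly that many.
+template <typename T>
+class BlockingQueue {
+ public:
+  explicit BlockingQueue(std::size_t capacity) : capacity_(capacity) {}
+
+  void Enqueue(T item) {
+    std::unique_lock<std::mutex> lock(mu_);
+    not_full_.wait(lock, [&] { return buf_.size() < capacity_; });
+    buf_.push_back(std::move(item));
+    not_empty_.notify_all();
+  }
+
+  std::vector<T> Dequeue(std::size_t min_count) {
+    std::unique_lock<std::mutex> lock(mu_);
+    not_empty_.wait(lock, [&] { return buf_.size() >= min_count; });
+    std::vector<T> out(buf_.begin(), buf_.begin() + min_count);
+    buf_.erase(buf_.begin(), buf_.begin() + min_count);
+    not_full_.notify_all();
+    return out;
+  }
+
+ private:
+  std::size_t capacity_;
+  std::deque<T> buf_;
+  std::mutex mu_;
+  std::condition_variable not_full_, not_empty_;
+};
+
+int main() {
+  BlockingQueue<int> q(/*capacity=*/4);
+  q.Enqueue(1);
+  q.Enqueue(2);
+  return q.Dequeue(/*min_count=*/2).size() == 2 ? 0 : 1;
+}
+```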
+
+### Benefits
+
+- Model parallelism becomes easier to implement: it is an extension of
+  the trainer - parameter server approach. We already have the
+  communication OPs, but need to extend the graph converter's
+  placement functionality.
+
+- A user-defined optimizer is easier to add - the user can now express it as
+  a subgraph.
+
+- No more duplicated logic between the trainer and the parameter
+  server, as mentioned in the background section.
+
+### Challenges
+
+- It might be hard for the graph converter to cut a general graph
+ (without any hint for which subgraph is the optimizer). We may need
+ to label which subgraph inside the OP graph is the optimizer.
+
+- It is important to balance the parameter shards across multiple
+  parameter servers. If a single parameter is very big (e.g., a
+  word-embedding, fully connected, or softmax layer), we need to
+  automatically partition that single parameter onto different
+  parameter servers when possible (which only works when the optimizer
+  depends on the parameter variable element-wise), as sketched below.
+
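+A minimal sketch of one balancing strategy, splitting the rows of a single large parameter into contiguous, nearly equal shards (the function name and the even row split are assumptions for this example, not the actual partitioning code):
+
+```cpp
+#include <cstdio>
+#include <utility>
+#include <vector>
+
+// Returns [begin, end) row ranges, one per parameter server.
+std::vector<std::pair<int, int>> ShardRows(int rows, int num_pservers) {
+  std::vector<std::pair<int, int>> shards;
+  int base = rows / num_pservers, extra = rows % num_pservers, begin = 0;
+  for (int i = 0; i < num_pservers; ++i) {
+    int len = base + (i < extra ? 1 : 0);
+    shards.emplace_back(begin, begin + len);
+    begin += len;
+  }
+  return shards;
+}
+
+int main() {
+  for (auto& s : ShardRows(/*rows=*/10000, /*num_pservers=*/3))
+    std::printf("[%d, %d)\n", s.first, s.second);
+  return 0;
+}
+```
+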
+### Discussion
+
+- In the "Aync SGD" figure, the "W" variable on the parameter server
+ could be read and wrote concurrently, what is our locking strategy?
+ E.g., each variable have a lock cpp method to be invoked by every
+ OP, or, have a lock OP.
+
+- Can the Enqueue OP be implemented under our current tensor design
+ (puts the input tensor into the queue tensor)?
+
+- The *Dequeue* OP will have a variable number of outputs (depending on the
+  `min_count` attribute); does our current design support it? (A similar
+  question applies to the *Add* OP.)
+
+
+### References:
+[1] [TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/45166.pdf)
diff --git a/doc/design/ops/src/dist-graph.graffle b/doc/design/ops/src/dist-graph.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..941399c6ced8d5f65b6c595522b770c88259df4b
Binary files /dev/null and b/doc/design/ops/src/dist-graph.graffle differ
diff --git a/doc/design/ops/src/dist-graph.png b/doc/design/ops/src/dist-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..3546b09f1c2ee3e4f60f519d5e47f823f08051a7
Binary files /dev/null and b/doc/design/ops/src/dist-graph.png differ
diff --git a/doc/design/ops/src/local-graph.graffle b/doc/design/ops/src/local-graph.graffle
new file mode 100644
index 0000000000000000000000000000000000000000..19e509bd9af3c1e9a3f5e0f16ddd281457a339c5
Binary files /dev/null and b/doc/design/ops/src/local-graph.graffle differ
diff --git a/doc/design/ops/src/local-graph.png b/doc/design/ops/src/local-graph.png
new file mode 100644
index 0000000000000000000000000000000000000000..ada51200f793a9bb18911e7d63cfdb3244b967d7
Binary files /dev/null and b/doc/design/ops/src/local-graph.png differ
diff --git a/doc/howto/dev/new_op_cn.md b/doc/howto/dev/new_op_cn.md
index 58665e9f2b6299ec3959ed6858ab01d459f64dd8..e3892849abe21fc207d2fcbe4adc65184ba771f4 100644
--- a/doc/howto/dev/new_op_cn.md
+++ b/doc/howto/dev/new_op_cn.md
@@ -262,7 +262,7 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs,
- Generating the library
- There is no need to modify the [`paddle/pybind/CMakeLists.txt`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/pybind/CMakeLists.txt) file; new `*_op.cc` files added under the `paddle/operators` directory are automatically linked into the generated library.
+ New `*_op.cc` files added under the `paddle/operators` directory are automatically linked into the generated library.
## Implementing Unit Tests
@@ -354,11 +354,7 @@ class TestMulGradOp(GradientChecker):
### Compiling and Running Unit Tests
-After the unit tests are written, add the following to [`python/paddle/v2/framework/tests/CMakeLists.txt`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/CMakeLists.txt) to add them to the project:
-
-```
-py_test(test_mul_op SRCS test_mul_op.py)
-```
+New `test_*.py` unit tests added under the `python/paddle/v2/framework/tests` directory are automatically added to the project and compiled.
Note that, **unlike building an Op, running the unit tests requires compiling the whole project** with `WITH_TESTING` turned on, i.e. `cmake paddle_dir -DWITH_TESTING=ON`. After the build succeeds, run the following command to execute the unit tests:
diff --git a/doc/howto/dev/write_docs_cn.rst b/doc/howto/dev/write_docs_cn.rst
index 36e5d420c986fc8d88eefee4aa221dba0a0480f2..731a63f945c29ba78538b3d71289b234e569354d 100644
--- a/doc/howto/dev/write_docs_cn.rst
+++ b/doc/howto/dev/write_docs_cn.rst
@@ -5,15 +5,13 @@
PaddlePaddle's documentation consists of two parts: the English documentation ``doc`` and the Chinese documentation ``doc_cn``. Both are generated with `sphinx`_ driven by `cmake`_, and the generated documents are stored in the ``doc`` and ``doc_cn`` subdirectories of the build directory.
-How to build PaddlePaddle's documentation
-=========================================
+How to build the documentation
+==============================
-There are two ways to build PaddlePaddle's documentation, building directly and building with Docker, and we provide a build script build_docs.sh for this.
-The environment needed to build the PaddlePaddle documentation is relatively complex, so we recommend building it with Docker.
+There are two ways to build the PaddlePaddle documentation.
-
-Building PaddlePaddle's documentation with Docker
--------------------------------------------------
+Building with Docker
+--------------------
To build the PaddlePaddle documentation with Docker, you first need to install the Docker toolkit on your system. For Docker installation please refer to `Docker's official site `_ . After Docker is installed, you can use the script in the source directory to build the documentation, i.e.
@@ -21,58 +19,46 @@ PaddlePaddle文档需要准备的环境相对较复杂,所以我们推荐使
cd TO_YOUR_PADDLE_CLONE_PATH
cd paddle/scripts/tools/build_docs
- bash build_docs.sh with_docker
-
-After the build finishes, two subdirectories are generated in the current directory\:
-
-* doc: the English documentation directory
-* doc_cn: the Chinese documentation directory
+ sh build_docs.sh
+After the build finishes, two subdirectories are generated in the current directory\: doc (the English documentation) and doc_cn (the Chinese documentation).
Open the index.html in the corresponding directory with a browser to view the local documentation.
-
-
-Building PaddlePaddle's documentation directly
-----------------------------------------------
-
-Because generating PaddlePaddle's v2 api documentation depends on the py_paddle Python package, users need to first make sure that the py_paddle package is installed.
-
-.. code-block:: bash
-
- python -c "import py_paddle"
-
-If this reports an error, the user needs to build and install PaddlePaddle locally; please refer to the `build from source documentation `_ .
-Note that when building and installing PaddlePaddle for the first time, please turn the WITH_DOC option off. After building and installing correctly, confirm again that the py_paddle package is installed, and then proceed to the next step.
+Building directly
+-----------------
If the check passes, you can run the following commands to build and generate the documentation, i.e.
.. code-block:: bash
cd TO_YOUR_PADDLE_CLONE_PATH
- cd paddle/scripts/tools/build_docs
- bash build_docs.sh local
-
-After the build finishes, two subdirectories are generated in the current directory\:
-
-* doc: the English documentation directory
-* doc_cn: the Chinese documentation directory
+ mkdir -p build
+ cd build
+ cmake .. -DCMAKE_BUILD_TYPE=Debug -DWITH_GPU=OFF -DWITH_MKLDNN=OFF -DWITH_MKLML=OFF -DWITH_DOC=ON
+ make gen_proto_py
+ make paddle_docs paddle_docs_cn
+After the build finishes, two subdirectories are generated in the current directory\: doc (the English documentation) and doc_cn (the Chinese documentation).
Open the index.html in the corresponding directory with a browser to view the local documentation.
-How to write PaddlePaddle's documentation
-=========================================
+How to write documentation
+==========================
The PaddlePaddle documentation is generated automatically with `sphinx`_ ; users can refer to the sphinx tutorial for writing it.
-How to update the www.paddlepaddle.org documentation
-=====================================================
+How to update the documentation theme
+======================================
+
+The PaddlePaddle documentation theme is in the `TO_YOUR_PADDLE_CLONE_PATH/doc_theme` folder and contains all files related to the front-end web design.
-Comments that developers add to the PaddlePaddle code are submitted to github in the form of PRs; see the `contribute documentation `_ for how to submit them.
+How to update doc.paddlepaddle.org
+==================================
+
+Updated documentation is submitted to github in the form of a PR; see the `contribute documentation `_ for how to submit it.
Currently the documentation for PaddlePaddle's develop branch is updated automatically, and users can view the latest `Chinese documentation `_ and
`English documentation `_ .
-
.. _cmake: https://cmake.org/
.. _sphinx: http://www.sphinx-doc.org/en/1.4.8/
diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index c0838d9b759110fd706577386d2c81bda6876223..3371962c635c3731f00a6af2a6e287ece33397cd 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -9,6 +9,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor)
+nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(variable_test SRCS variable_test.cc)
diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h
index cde3dfa1d3d19b1bee9fd23dad52ecbbe628c3a9..2b788a76cafe198abb9aed8ba842e37cc6ff73a6 100644
--- a/paddle/framework/attribute.h
+++ b/paddle/framework/attribute.h
@@ -45,7 +45,19 @@ class GreaterThanChecker {
public:
explicit GreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
void operator()(T& value) const {
- PADDLE_ENFORCE(value > lower_bound_, "larger_than check fail");
+ PADDLE_ENFORCE(value > lower_bound_, "larger_than check fails.");
+ }
+
+ private:
+ T lower_bound_;
+};
+
+template <typename T>
+class EqualGreaterThanChecker {
+ public:
+ explicit EqualGreaterThanChecker(T lower_bound) : lower_bound_(lower_bound) {}
+ void operator()(T& value) const {
+ PADDLE_ENFORCE_GE(value, lower_bound_, "equal_larger_than check fails.");
}
private:
@@ -115,6 +127,11 @@ class TypedAttrChecker {
return *this;
}
+ TypedAttrChecker& EqualGreaterThan(const T& lower_bound) {
+ value_checkers_.push_back(EqualGreaterThanChecker<T>(lower_bound));
+ return *this;
+ }
+
// we can add more common limits, like LessThan(), Between()...
TypedAttrChecker& SetDefault(const T& default_value) {
diff --git a/paddle/framework/backward.md b/paddle/framework/backward.md
index 8aa6728a95bc464ab8884986f0cec6c817d3303b..0a6d762bc8be5201ac196b4bc6107c06d07a31d7 100644
--- a/paddle/framework/backward.md
+++ b/paddle/framework/backward.md
@@ -2,20 +2,31 @@
## Motivation
-In Neural Network, the backpropagation algorithm follows the chain rule, so we need to compound the fundmental gradient operators/expressions together with chain rule . Every forward network need a backward network to construct the full computation graph, the operator/expression's backward pass will be generated respect to forward pass.
-
-## Backward Operator Registry
+In neural networks, many models are currently solved by the backpropagation algorithm (known as BP). Technically, it calculates the gradient of the loss function and then distributes it back through the network following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph; the operator/expression's backward pass will be generated with respect to the forward pass.
-A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs and output gradients and then calculate its input gradients.
+## Implementation
+
+In this design doc, we export only one API for generating the backward pass.
+
+```c++
+std::unique_ptr<OperatorBase> Backward(const OperatorBase& forwardOp,
+    const std::unordered_set<std::string>& no_grad_vars);
+```
+
+The implementation behind it can be divided into two parts, **Backward Operator Creating** and **Backward Network Building**.
+
+### Backward Operator Registry
+
+A backward network is built up with several backward operators. Backward operators take the forward operators' inputs, outputs, and output gradients, and then calculate the forward operators' input gradients.
| | forward operator | backward operator
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
- In most cases, there is a one-to-one correspondence between forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
+ In most cases, there is a one-to-one correspondence between the forward and backward operators. These correspondences are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and make operators pluggable, the registry mechanism is introduced.
-For example, we have got a `mul_op`, and we can register it's information and corresponding backward operator by the following macro:
+For example, we have got a `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
@@ -25,58 +36,65 @@ REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name.
-## Backward Opeartor Creating
+### Backward Operator Creating
-Given a certain forward operator, we can get its corresponding backward opeartor by calling:
+Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
-```
+```
The function `BuildGradOp` will sequentially execute the following processes (a simplified sketch follows the list):
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
-2. Build two maps named `inputs` and `outputs` to temporary storage backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these are not necessary for gradient computing.
+2. Build two maps named `inputs` and `outputs` to temporarily store the backward operator's inputs and outputs. Copy the forward operator's `inputs_` and `outputs_` to the map `inputs`, except those that are not necessary for gradient computing.
3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
4. Building backward operator with `inputs`, `outputs` and forward operator's attributes.
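+
+The four steps roughly correspond to the following simplified sketch. It is illustrative only: the real implementation works on `OperatorBase` and `OpInfoMap`, while here the operator and its variable-name maps are reduced to plain structs and the attributes step is only marked by a comment.
+
+```cpp
+#include <map>
+#include <string>
+#include <vector>
+
+using VarMap = std::map<std::string, std::vector<std::string>>;
+
+struct SimpleOp {
+  std::string type;
+  VarMap inputs, outputs;
+};
+
+std::string Grad(const std::string& name) { return name + "@GRAD"; }
+
+SimpleOp BuildGradOpSketch(const SimpleOp& fwd) {
+  SimpleOp grad;
+  grad.type = fwd.type + "_grad";            // 1. look up the backward type
+  grad.inputs = fwd.inputs;                  // 2. copy forward inputs ...
+  for (const auto& kv : fwd.outputs) grad.inputs[kv.first] = kv.second;  // ... and outputs
+  for (const auto& kv : fwd.outputs) {       // 3. forward outputs' gradients become inputs
+    std::vector<std::string> g;
+    for (const auto& v : kv.second) g.push_back(Grad(v));
+    grad.inputs[Grad(kv.first)] = g;
+  }
+  for (const auto& kv : fwd.inputs) {        // 3. forward inputs' gradients become outputs
+    std::vector<std::string> g;
+    for (const auto& v : kv.second) g.push_back(Grad(v));
+    grad.outputs[Grad(kv.first)] = g;
+  }
+  return grad;                               // 4. build with the forward attributes
+}
+
+int main() {
+  SimpleOp mul{"mul", {{"X", {"x"}}, {"Y", {"w"}}}, {{"Out", {"out"}}}};
+  SimpleOp g = BuildGradOpSketch(mul);
+  return g.type == "mul_grad" ? 0 : 1;
+}
+```
+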
-## Backward Network Building
+### Backward Network Building
-A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and put them together.
-
-In our design, the network itself is also a kind of operator. So the operators contained by a big network may be some small network.
-
-given a forward network, it generates the backward network. We only care about the Gradients—`OutputGradients`,`InputGradients`.
+A backward network is a series of backward operators. The main idea of building a backward network is to create the backward operators in the inverted sequence and append them together one by one. There are some corner cases that need to be processed specially.
1. Op
- when the input forward network is a Op, return its gradient Operator Immediately.
+ When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in the no-gradient set, then return a special `NOP`.
2. NetOp
- when the input forward network is a NetOp, it need to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to forward NetOp.
+ In our design, the network itself is also a kind of operator (**NetOp**). So the operators contained in a big network may themselves be small networks. When the input forward network is a NetOp, it needs to call the backward functions of its sub NetOps/Operators recursively. During the process, we need to collect the `OutputGradients` names according to the forward NetOp.
+
+3. RnnOp
+
+ RnnOp is a nested stepnet operator. The backward module needs to recursively call `Backward` for every stepnet.
+
+4. Sharing Variables
+
+ **Sharing variables**. As illustrated in the pictures, two operators share the same variable name W@GRAD, which will overwrite their shared input variable.
+
+
+   <img src="./images/duplicate_op.png"/>
- **shared variable**. As illustrated in the pictures, two operator's `Output` `Gradient` will overwirte their shared input variable.
+ pic 1. Sharing variables in operators.
-
-
+
- 1. shared variable in two operators.
+ Sharing a variable between operators, or using the same input variable in multiple operators, leads to a duplicate gradient variable. As the demo above shows, we need to rename the gradient names recursively and add a generic add operator to replace the overwriting links.
-
+
+   <img src="./images/duplicate_op2.png"/>
- Share variable between operators or same input variable used in multiple operators lead to a duplicate gradient variable. As demo show above, we need to rename gradient name recursively, and add a generic add operator replace the overwirte links.
+ pic 2. Replace the shared variable's gradient with an `Add` operator.
-
-
+
- 2. replace shared variable gradient with `Add` Operator
+ Because our framework finds variables according to their names, we need to rename the output links. We add a numeric suffix to represent each output's position, in clockwise order.
-
+5. Part of the Gradient is Zero.
+ In the whole graph, there are cases where one operator's gradient is not needed, but its input's gradient is a dependency link of another operator; we need to fill a gradient matrix of the same shape at that position. In our implementation, we insert a special `fillZeroLike` operator.
- Then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it.
+Following the rules above, we then collect the subgraph's `OutputGradients`/`InputGradients` as the NetOp's and return it.
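+
+The renaming rule from item 4 can be sketched as follows. This is illustrative only: the `@RENAME@` suffix format and the printed `Add` call are made up for this document, not the framework's real naming scheme.
+
+```cpp
+#include <cstddef>
+#include <iostream>
+#include <string>
+#include <vector>
+
+int main() {
+  // Two backward operators both want to write W@GRAD.
+  std::vector<std::string> writers = {"fc1_grad", "fc2_grad"};
+  std::vector<std::string> renamed;
+  for (std::size_t i = 0; i < writers.size(); ++i) {
+    // Give each writer a private, numbered output instead of W@GRAD.
+    renamed.push_back("W@GRAD@RENAME@" + std::to_string(i));
+  }
+  // One generic Add operator sums the renamed copies back into W@GRAD.
+  std::cout << "Add(";
+  for (std::size_t i = 0; i < renamed.size(); ++i)
+    std::cout << renamed[i] << (i + 1 < renamed.size() ? ", " : "");
+  std::cout << ") -> W@GRAD\n";
+  return 0;
+}
+```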
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index 85b7de79743bb0390d66b8999f2e8342a51d14a9..fc3d508553c0e966978b28d58127bdbff10d45f1 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -283,5 +283,14 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDim::DDim(std::initializer_list init_list) {
*this = make_ddim(init_list);
}
+
+DDim flatten_to_2d(const DDim& src, int num_col_dims) {
+ int rank = src.size();
+ return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
+ product(slice_ddim(src, num_col_dims, rank))});
+}
+
+DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index db30c523948b1d437615aa0e9bfecb5e25569296..ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -115,6 +115,12 @@ int arity(const DDim& ddim);
std::ostream& operator<<(std::ostream&, const DDim&);
+// Reshape a tensor to a matrix. The matrix's first dimension(column length)
+// will be the product of tensor's first `num_col_dims` dimensions.
+DDim flatten_to_2d(const DDim& src, int num_col_dims);
+
+DDim flatten_to_1d(const DDim& src);
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/eigen.h b/paddle/framework/eigen.h
index 2d8d9ae10c56e0632414a5bbc754d35bfa9ce6a5..54bbeafcabdeeb1e2c1017c156b3512c83dada3a 100644
--- a/paddle/framework/eigen.h
+++ b/paddle/framework/eigen.h
@@ -63,20 +63,35 @@ struct EigenTensor {
template
-struct EigenMatrix : public EigenTensor {};
+struct EigenMatrix : public EigenTensor {
+ static typename EigenMatrix::Type Reshape(Tensor& tensor, int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+
+ static typename EigenMatrix::ConstType Reshape(const Tensor& tensor,
+ int num_col_dims) {
+ int rank = tensor.dims_.size();
+ PADDLE_ENFORCE(num_col_dims > 0 && num_col_dims < rank,
+ "`num_col_dims` must be between (0, rank_of_tensor).");
+ return EigenMatrix::From(tensor,
+ flatten_to_2d(tensor.dims(), num_col_dims));
+ }
+};
template
struct EigenVector : public EigenTensor {
// Flatten reshapes a Tensor into an EigenVector.
static typename EigenVector::Type Flatten(Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
static typename EigenVector::ConstType Flatten(const Tensor& tensor) {
- return EigenVector::From(
- tensor, make_ddim({static_cast(product(tensor.dims_))}));
+ return EigenVector::From(tensor, {product(tensor.dims_)});
}
};
diff --git a/paddle/framework/eigen_test.cc b/paddle/framework/eigen_test.cc
index dc1957691b1a202826e10e84c21ac8874df9e378..bc4a2db32cfba66bef2c444e1f822e0d2a57b91e 100644
--- a/paddle/framework/eigen_test.cc
+++ b/paddle/framework/eigen_test.cc
@@ -108,5 +108,24 @@ TEST(Eigen, Matrix) {
}
}
+TEST(Eigen, MatrixReshape) {
+ Tensor t;
+ float* p = t.mutable_data<float>({2, 3, 6, 4}, platform::CPUPlace());
+ for (int i = 0; i < 2 * 3 * 6 * 4; ++i) {
+ p[i] = static_cast<float>(i);
+ }
+
+ EigenMatrix<float>::Type em = EigenMatrix<float>::Reshape(t, 2);
+
+ ASSERT_EQ(2 * 3, em.dimension(0));
+ ASSERT_EQ(6 * 4, em.dimension(1));
+
+ for (int i = 0; i < 2 * 3; i++) {
+ for (int j = 0; j < 6 * 4; j++) {
+ ASSERT_NEAR(i * 6 * 4 + j, em(i, j), 1e-6f);
+ }
+ }
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/images/duplicate_op2.graffle b/paddle/framework/images/duplicate_op2.graffle
index 2b658085d6a55d368c320051ba7f94ec2900f13c..5cec3bc64dbd44dc99e348485969f29bd128ceb1 100644
Binary files a/paddle/framework/images/duplicate_op2.graffle and b/paddle/framework/images/duplicate_op2.graffle differ
diff --git a/paddle/framework/images/duplicate_op2.png b/paddle/framework/images/duplicate_op2.png
index c5588015d1450fd8c1bda3580680d884494868bb..21cdd5cabf1b5203e1435a75b57770d2f702fa92 100644
Binary files a/paddle/framework/images/duplicate_op2.png and b/paddle/framework/images/duplicate_op2.png differ
diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h
index 154068fef69bc96edbd85b731fe8091b3b1ff823..568f4e89819c8345d8908634f6fa56f09483a763 100644
--- a/paddle/framework/lod_tensor.h
+++ b/paddle/framework/lod_tensor.h
@@ -18,8 +18,10 @@
#ifndef PADDLE_ONLY_CPU
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
+#include <thrust/system/cuda/experimental/pinned_allocator.h>
#endif
+#include
#include "paddle/framework/ddim.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/enforce.h"
@@ -32,7 +34,8 @@ template
using Vector = std::vector<T>;
#else
template
-using Vector = thrust::host_vector;
+using Vector = thrust::host_vector<
+ T, thrust::system::cuda::experimental::pinned_allocator<T>>;
#endif
using LoD = std::vector<Vector<size_t>>;
diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..1079a36a2e7b24f6f8a5bcbb296355567305a765
--- /dev/null
+++ b/paddle/framework/lod_tensor_test.cu
@@ -0,0 +1,52 @@
+/*
+ Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+#include "paddle/framework/lod_tensor.h"
+#include "paddle/platform/assert.h"
+
+#include <gtest/gtest.h>
+
+__global__ void test(size_t* a, int size) {
+ for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
+ i += blockDim.x * gridDim.x) {
+ a[i] *= 2;
+ }
+}
+
+TEST(LoDTensor, LoDInGPU) {
+ paddle::framework::Tensor tensor;
+ paddle::framework::LoDTensor lod_tensor;
+ paddle::platform::GPUPlace place(0);
+
+ paddle::framework::LoD src_lod;
+ src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});
+
+ tensor.Resize({14, 16});
+ tensor.mutable_data<float>(place);
+
+ lod_tensor.set_lod(src_lod);
+ lod_tensor.set_tensor(&tensor);
+ CHECK_EQ(lod_tensor.lod_element(0, 2), 4);
+ CHECK_EQ(lod_tensor.lod_element(0, 4), 8);
+
+ auto lod = lod_tensor.lod();
+
+ test<<<1, 8>>>(lod[0].data(), lod[0].size());
+ cudaDeviceSynchronize();
+
+ for (size_t i = 0; i < src_lod[0].size(); ++i) {
+ CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
+ }
+}
diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc
index 790cfc4746b1d34da413fa3c29a266f962c6dde6..e1e122091f7759b1a68f1f982bc2a35e8241f9f0 100644
--- a/paddle/framework/operator.cc
+++ b/paddle/framework/operator.cc
@@ -123,6 +123,15 @@ OperatorBase::OperatorBase(const std::string& type,
CheckAllInputOutputSet();
}
+std::vector<std::string> OperatorBase::InputVars() const {
+ std::vector<std::string> ret_val;
+ for (auto& o : outputs_) {
+ ret_val.reserve(ret_val.size() + o.second.size());
+ ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
+ }
+ return ret_val;
+}
+
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
std::vector<std::string> ret_val;
if (has_intermediate) {
diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h
index 9a98d4d3be0d1cb875d614b263f1e4365ede4113..4600b06009bcef7d0774d25b816aac4733f30795 100644
--- a/paddle/framework/operator.h
+++ b/paddle/framework/operator.h
@@ -94,11 +94,14 @@ class OperatorBase {
const VariableNameMap& Inputs() const { return inputs_; }
const VariableNameMap& Outputs() const { return outputs_; }
+
//! Get an input with argument's name described in `op_proto`
std::string Input(const std::string& name) const;
//! Get an input which has multiple variables.
const std::vector<std::string>& Inputs(const std::string& name) const;
+ std::vector<std::string> InputVars() const;
+
//! Get an output with argument's name described in `op_proto`
std::string Output(const std::string& name) const;
//! Get an output which has multiple variables.
@@ -311,9 +314,9 @@ class InferShapeContext {
}
template <typename T>
- std::vector MultiOutput(const std::string& name) const {
+ std::vector MultiOutput(const std::string& name) const {
auto names = op_.Outputs(name);
- std::vector res;
+ std::vector res;
res.reserve(names.size());
std::transform(names.begin(), names.end(), std::back_inserter(res),
[&](const std::string& sub_name) {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 643f875491724bf443bd7727391734377ee6180c..4b5a2ae523f2f7fde5445f0534cd99969ad9d59e 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -43,6 +43,9 @@ class Tensor {
template
friend struct EigenTensor;
+ template
+ friend struct EigenMatrix;
+
template
friend struct EigenVector;
@@ -78,6 +81,9 @@ class Tensor {
/*! Return the dimensions of the memory block. */
inline const DDim& dims() const;
+ /*! Return the numel of the memory block. */
+ inline int64_t numel() const;
+
/*! Resize the dimensions of the memory block. */
inline Tensor& Resize(const DDim& dims);
@@ -159,6 +165,12 @@ class Tensor {
/*! points to dimensions of memory block. */
DDim dims_;
+ /**
+ * A cache of the number of elements in a tensor.
+ * Would be 0 for an uninitialized tensor.
+ */
+ int64_t numel_;
+
/**
* @brief A PlaceHolder may be shared by more than one tensor.
*
diff --git a/paddle/framework/tensor_impl.h b/paddle/framework/tensor_impl.h
index 94f436294f350e2a39785a09959efb3b17bd00a5..642b53efc7095d25712ca324638f5fe9b8316c0c 100644
--- a/paddle/framework/tensor_impl.h
+++ b/paddle/framework/tensor_impl.h
@@ -24,7 +24,7 @@ inline void Tensor::check_memory_size() const {
PADDLE_ENFORCE_NOT_NULL(
holder_, "Tenosr holds no memory. Call Tensor::mutable_data first.");
PADDLE_ENFORCE_GE(
- holder_->size(), product(dims_) * sizeof(T) + offset_,
+ holder_->size(), numel() * sizeof(T) + offset_,
"Tensor's dims_ is out of bound. Call Tensor::mutable_data "
"first to re-allocate memory.\n"
"or maybe the required data-type mismatches the data already stored.");
@@ -54,11 +54,11 @@ inline T* Tensor::mutable_data(DDim dims, platform::Place place) {
template <typename T>
inline T* Tensor::mutable_data(platform::Place place) {
static_assert(std::is_pod<T>::value, "T must be POD");
- PADDLE_ENFORCE_GT(product(dims_), 0,
+ PADDLE_ENFORCE_GT(numel(), 0,
"Tensor's numel must be larger than zero to call "
"Tensor::mutable_data. Call Tensor::set_dim first.");
/* some versions of boost::variant don't have operator!= */
- int64_t size = product(dims_) * sizeof(T);
+ int64_t size = numel() * sizeof(T);
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
if (platform::is_cpu_place(place)) {
@@ -97,7 +97,7 @@ inline void Tensor::CopyFrom(const Tensor& src,
auto dst_ptr = static_cast(mutable_data(dst_place));
- auto size = product(src.dims_) * sizeof(T);
+ auto size = src.numel() * sizeof(T);
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get(dst_place), dst_ptr,
@@ -131,7 +131,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
PADDLE_ENFORCE_LT(begin_idx, end_idx,
"Begin index must be less than end index.");
PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
- size_t base = product(dims_) / dims_[0];
+ size_t base = numel() / dims_[0];
Tensor dst;
dst.holder_ = holder_;
DDim dst_dims = dims_;
@@ -143,10 +143,21 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
inline Tensor& Tensor::Resize(const DDim& dims) {
dims_ = dims;
+ numel_ = product(dims_);
return *this;
}
inline const DDim& Tensor::dims() const { return dims_; }
+inline int64_t Tensor::numel() const { return numel_; }
+
+template <typename T>
+inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
+ Tensor res;
+ res.ShareDataWith<T>(src);
+ res.Resize(flatten_to_2d(src.dims(), num_col_dims));
+ return res;
+}
+
} // namespace framework
} // namespace paddle
diff --git a/paddle/framework/tensor_test.cc b/paddle/framework/tensor_test.cc
index 7db38d5caeebccf710334e854faf785ef0f64063..55302ea47120f420e952b26830c8ea4cbcce6435 100644
--- a/paddle/framework/tensor_test.cc
+++ b/paddle/framework/tensor_test.cc
@@ -262,3 +262,16 @@ TEST(Tensor, CopyFrom) {
}
#endif
}
+
+TEST(Tensor, ReshapeToMatrix) {
+ using namespace paddle::framework;
+ using namespace paddle::platform;
+ Tensor src;
+ int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
+ for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
+ src_ptr[i] = i;
+ }
+ Tensor res = ReshapeToMatrix<int>(src, 2);
+ ASSERT_EQ(res.dims()[0], 2 * 3);
+ ASSERT_EQ(res.dims()[1], 4 * 9);
+}
\ No newline at end of file
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.cpp b/paddle/gserver/layers/BatchNormBaseLayer.cpp
index 1ceaaaa206ee3cbc5421238574c7f310011ccaa5..f7a80e23e1bd49549bec57b360587adc6b423794 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.cpp
+++ b/paddle/gserver/layers/BatchNormBaseLayer.cpp
@@ -62,14 +62,18 @@ void BatchNormBaseLayer::calFeatureMapSize() {
const ImageConfig& conf = config_.inputs(0).image_conf();
imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
+ imageD_ = inputLayers_[0]->getOutput().getFrameDepth();
+
+ if (0 == imageD_) imageD_ = conf.img_size_z();
if (imageH_ == 0 && imageW_ == 0) {
imageH_ = conf.has_img_size_y() ? conf.img_size_y() : conf.img_size();
imageW_ = conf.img_size();
} else {
getOutput().setFrameHeight(imageH_);
getOutput().setFrameWidth(imageW_);
+ getOutput().setFrameDepth(imageD_);
}
- imgPixels_ = imageH_ * imageW_;
+ imgPixels_ = imageH_ * imageW_ * imageD_;
}
} // namespace paddle
diff --git a/paddle/gserver/layers/BatchNormBaseLayer.h b/paddle/gserver/layers/BatchNormBaseLayer.h
index 230bafc31d96bbd49481a7ed135be6888688627e..e721d2d267a31cae46407673b8b1281e87055608 100644
--- a/paddle/gserver/layers/BatchNormBaseLayer.h
+++ b/paddle/gserver/layers/BatchNormBaseLayer.h
@@ -80,6 +80,7 @@ protected:
/// Height or width of input image feature.
/// Both of them are 1 if the input is fully-connected layer.
+ int imageD_;
int imageH_;
int imageW_;
/// Height * Width.
diff --git a/paddle/gserver/layers/CudnnBatchNormLayer.cpp b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
index 44ba2c4b7d1562d2ce839b5f4b4de1af35e6925f..49a9540c0b6e36b59ed786287ff5c4569b69a6a5 100644
--- a/paddle/gserver/layers/CudnnBatchNormLayer.cpp
+++ b/paddle/gserver/layers/CudnnBatchNormLayer.cpp
@@ -37,7 +37,7 @@ bool CudnnBatchNormLayer::init(const LayerMap& layerMap,
}
void CudnnBatchNormLayer::reshape(int batchSize) {
- hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_, imageW_);
+ hl_tensor_reshape(ioDesc_, batchSize, channels_, imageH_ * imageD_, imageW_);
}
void CudnnBatchNormLayer::forward(PassType passType) {
@@ -104,7 +104,7 @@ void CudnnBatchNormLayer::forward(PassType passType) {
EPS,
batchSize,
channels_,
- imageH_,
+ imageH_ * imageD_,
imageW_);
}
}
diff --git a/paddle/gserver/layers/DeConv3DLayer.cpp b/paddle/gserver/layers/DeConv3DLayer.cpp
index 1b59ed60c57fe3bbfa814befa8a63408a2621715..3eea638649e8ebfdd7efa18615977a9e1344c695 100644
--- a/paddle/gserver/layers/DeConv3DLayer.cpp
+++ b/paddle/gserver/layers/DeConv3DLayer.cpp
@@ -53,27 +53,27 @@ bool DeConv3DLayer::init(const LayerMap &layerMap,
size_t DeConv3DLayer::getSize() {
CHECK_NE(inputLayers_.size(), 0UL);
- outputH_.clear();
- outputW_.clear();
- outputD_.clear();
+ imgSizeW_.clear();
+ imgSizeH_.clear();
+ imgSizeD_.clear();
N_.clear();
NOut_.clear();
size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); ++i) {
- outputW_.push_back(
- imageSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i], true));
- outputH_.push_back(imageSize(
- imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
- outputD_.push_back(imageSize(
- imgSizeD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
- NOut_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
- N_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
+ imgSizeW_.push_back(
+ imageSize(outputW_[i], filterSize_[i], padding_[i], stride_[i], true));
+ imgSizeH_.push_back(imageSize(
+ outputH_[i], filterSizeY_[i], paddingY_[i], strideY_[i], true));
+ imgSizeD_.push_back(imageSize(
+ outputD_[i], filterSizeZ_[i], paddingZ_[i], strideZ_[i], true));
+ NOut_.push_back(imgSizeD_[i] * imgSizeH_[i] * imgSizeW_[i]);
+ N_.push_back(outputD_[i] * outputH_[i] * outputW_[i]);
CHECK(layerSize == 0 || N_[i] * size_t(numFilters_) == layerSize);
layerSize += NOut_[i] * numFilters_;
}
- getOutput().setFrameHeight(outputH_[0]);
- getOutput().setFrameWidth(outputW_[0]);
- getOutput().setFrameDepth(outputD_[0]);
+ getOutput().setFrameHeight(imgSizeH_[0]);
+ getOutput().setFrameWidth(imgSizeW_[0]);
+ getOutput().setFrameDepth(imgSizeD_[0]);
return layerSize;
}
@@ -103,9 +103,9 @@ void DeConv3DLayer::forward(PassType passType) {
}
colBuf_->col2Vol(outMat->getData() + n * outMat->getStride(),
numFilters_,
- outputD_[i],
- outputH_[i],
- outputW_[i],
+ imgSizeD_[i],
+ imgSizeH_[i],
+ imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
@@ -144,9 +144,9 @@ void DeConv3DLayer::backward(const UpdateCallback &callback) {
colBuf_->vol2Col(
getOutputGrad()->getData() + n * getOutputGrad()->getStride(),
numFilters_,
- outputD_[i],
- outputH_[i],
- outputW_[i],
+ imgSizeD_[i],
+ imgSizeH_[i],
+ imgSizeW_[i],
filterSizeZ_[i],
filterSizeY_[i],
filterSize_[i],
diff --git a/paddle/gserver/layers/DetectionOutputLayer.cpp b/paddle/gserver/layers/DetectionOutputLayer.cpp
index 8ab838e191314ab25469631626c0b0564d7fffda..0cf0a92bf4bd8f9b8eba2016b2377d9dfb18c70a 100644
--- a/paddle/gserver/layers/DetectionOutputLayer.cpp
+++ b/paddle/gserver/layers/DetectionOutputLayer.cpp
@@ -139,7 +139,13 @@ void DetectionOutputLayer::forward(PassType passType) {
allDecodedBBoxes,
&allIndices);
- resetOutput(numKept, 7);
+ if (numKept > 0) {
+ resetOutput(numKept, 7);
+ } else {
+ MatrixPtr outV = getOutputValue();
+ outV = NULL;
+ return;
+ }
MatrixPtr outV = getOutputValue();
getDetectionOutput(confBuffer_->getData(),
numKept,
diff --git a/paddle/gserver/layers/DetectionUtil.cpp b/paddle/gserver/layers/DetectionUtil.cpp
index 3e61adc66e60c54250e4f323452aa13045310879..d83674f45a70212a8adc94a31ff58eb0e01baa00 100644
--- a/paddle/gserver/layers/DetectionUtil.cpp
+++ b/paddle/gserver/layers/DetectionUtil.cpp
@@ -469,7 +469,7 @@ size_t getDetectionIndices(
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
- const size_t confThreshold,
+ const real confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
diff --git a/paddle/gserver/layers/DetectionUtil.h b/paddle/gserver/layers/DetectionUtil.h
index fe4f9f075e4cf011c97f68f49598a828d62327b3..641ed873b4c8645b6455e5ef5e63593e3005b770 100644
--- a/paddle/gserver/layers/DetectionUtil.h
+++ b/paddle/gserver/layers/DetectionUtil.h
@@ -275,7 +275,7 @@ size_t getDetectionIndices(
const size_t numClasses,
const size_t backgroundId,
const size_t batchSize,
- const size_t confThreshold,
+ const real confThreshold,
const size_t nmsTopK,
const real nmsThreshold,
const size_t keepTopK,
diff --git a/paddle/gserver/layers/Layer.h b/paddle/gserver/layers/Layer.h
index edef36194aabdb9c122ec3423deb036169a34d7c..4002a3d0747a86ab7b495ffe52247521831b71b8 100644
--- a/paddle/gserver/layers/Layer.h
+++ b/paddle/gserver/layers/Layer.h
@@ -49,6 +49,12 @@ struct LayerState {
};
typedef std::shared_ptr<LayerState> LayerStatePtr;
+/// Paddle device ID, MKLDNN is -2, CPU is -1
+enum PADDLE_DEVICE_ID {
+ MKLDNN_DEVICE = -2,
+ CPU_DEVICE = -1,
+};
+
/**
* @brief Base class for layer.
* Define necessary variables and functions for every layer.
@@ -59,11 +65,6 @@ protected:
LayerConfig config_;
/// whether to use GPU
bool useGpu_;
- /// Paddle device ID, MKLDNN is -2, CPU is -1
- enum PADDLE_DEVICE_ID {
- MKLDNN_DEVICE = -2,
- CPU_DEVICE = -1,
- };
/// Device Id. MKLDNN is -2, CPU is -1, and GPU is 0, 1, 2 ...
int deviceId_;
/// Input layers
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp
index 8318c8c519a4cec1610eadd28320ee5ce0b4147d..f70343251ad4fbb99f9614618f6d1bff1174f15e 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.cpp
+++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp
@@ -14,7 +14,6 @@ limitations under the License. */
#include "MKLDNNFcLayer.h"
#include "paddle/utils/Logging.h"
-#include "paddle/utils/Stat.h"
using namespace mkldnn; // NOLINT
typedef memory::format format;
@@ -40,6 +39,8 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
oc_ = getSize();
oh_ = 1;
ow_ = 1;
+ ih_ = 1;
+ iw_ = 1;
// input size can not change in FC
iLayerSize_ = inputLayers_[0]->getSize();
@@ -77,111 +78,86 @@ void MKLDNNFcLayer::convertWeightsToPaddle() {
wgtVal_->reorderDataTo(wgtVal_, dstFmt, targetDim);
}
-void MKLDNNFcLayer::convertOutputToOtherDevice() {
- copyOutputInfoToOtherDevice();
- // find other cpu device and reorder output to cpu device
- int cnt = 0;
- for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
- if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
- // fc cpu output value do not need convert
- // just share point
- outputOtherDevice_[i].value = output_.value;
- ++cnt;
- }
- }
-
- if (cnt > 1) {
- LOG(WARNING) << "should not have more than one CPU devie";
- }
-}
+void MKLDNNFcLayer::reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) {
+ reshapeInput(bs, ih, iw);
-void MKLDNNFcLayer::reshape() {
- const Argument& input = getInput(0, getPrev(0)->getDeviceId());
- int batchSize = input.getBatchSize();
- if (bs_ == batchSize) {
- return;
- }
- bs_ = batchSize;
- ih_ = input.getFrameHeight();
- iw_ = input.getFrameWidth();
- if (ih_ == 0) {
- ih_ = 1;
- }
- if (iw_ == 0) {
- iw_ = 1;
- }
CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize());
- ic_ = iLayerSize_ / (ih_ * iw_);
- CHECK_EQ(size_t(ic_ * ih_ * iw_), iLayerSize_) << "not divisible";
- CHECK_EQ(size_t(oc_), getSize());
- printSizeInfo();
+ ic = iLayerSize_ / (ih * iw);
+ CHECK_EQ(size_t(ic * ih * iw), iLayerSize_) << "not divisible";
+ CHECK_EQ(size_t(oc), getSize());
- // reset output
- output_.setFrameHeight(oh_);
- output_.setFrameWidth(ow_);
- resetOutput(bs_, oc_);
+ reshapeOutput(oh, ow);
+ resizeOutput(bs, oc);
- // reset mkldnn forward
- resetFwd();
- needResetBwd_ = true;
-
- convertWeightsFromPaddle();
+ printSizeInfo();
}
-void MKLDNNFcLayer::resetFwd() {
+void MKLDNNFcLayer::resetFwd(std::vector<mkldnn::primitive>& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) {
+ pipeline.clear();
bool hasBias = biases_ && biases_->getW();
- const MatrixPtr& wgt = weight_->getW();
- const MatrixPtr& bias = hasBias ? biases_->getW() : nullptr;
- const MatrixPtr& out = output_.value;
+ const MatrixPtr& wgtVal = weight_->getW();
+ const MatrixPtr& biasVal = hasBias ? biases_->getW() : nullptr;
+ const MatrixPtr& outVal = output_.value;
if (inputIsOnlyMKLDNN()) {
- const MatrixPtr& in = getInputValue(0);
- inVal_ = std::dynamic_pointer_cast(in);
- CHECK(inVal_) << "Input should be MKLDNNMatrix";
+ const MatrixPtr& inVal = getInputValue(0);
+ in = std::dynamic_pointer_cast(inVal);
+ CHECK(in) << "Input should be MKLDNNMatrix";
} else {
CHECK_EQ(getPrev(0)->getDeviceId(), CPU_DEVICE) << "Only support CPU yet";
- const MatrixPtr& in = getInputValue(0, CPU_DEVICE);
- inVal_ = MKLDNNMatrix::create(
- in, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_);
- }
- inVal_->downSpatial();
- wgtVal_ = MKLDNNMatrix::create(
- wgt, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_);
- wgtVal_->downSpatial();
- biasVal_ =
- hasBias ? MKLDNNMatrix::create(bias, {oc_}, format::x, engine_) : nullptr;
- outVal_ = MKLDNNMatrix::create(out, {bs_, oc_}, format::nc, engine_);
+ const MatrixPtr& inVal = getInputValue(0, CPU_DEVICE);
+ in = MKLDNNMatrix::create(
+ inVal, memory::dims{bs_, ic_, ih_, iw_}, format::nchw, engine_);
+ }
+ in->downSpatial();
+ wgt = MKLDNNMatrix::create(
+ wgtVal, memory::dims{oc_, ic_, ih_, iw_}, format::oihw, engine_);
+ wgt->downSpatial();
+ bias = hasBias ? MKLDNNMatrix::create(biasVal, {oc_}, format::x, engine_)
+ : nullptr;
+ out = MKLDNNMatrix::create(outVal, {bs_, oc_}, format::nc, engine_);
// change original output value to mkldnn output value
- output_.value = std::dynamic_pointer_cast(outVal_);
+ output_.value = std::dynamic_pointer_cast(out);
if (!outputIsOnlyMKLDNN()) {
- convertOutputToOtherDevice();
+ // fc cpu output value does not need to create a convert,
+ // just share the data pointer
+ getOutput(CPU_DEVICE).value->setData(output_.value->getData());
}
// create forward handle
prop_kind pk = prop_kind::forward;
fc_fwd::desc fwdDesc = hasBias ? fc_fwd::desc(pk,
- inVal_->getMemoryDesc(),
- wgtVal_->getMemoryDesc(),
- biasVal_->getMemoryDesc(),
- outVal_->getMemoryDesc())
+ in->getMemoryDesc(),
+ wgt->getMemoryDesc(),
+ bias->getMemoryDesc(),
+ out->getMemoryDesc())
: fc_fwd::desc(pk,
- inVal_->getMemoryDesc(),
- wgtVal_->getMemoryDesc(),
- outVal_->getMemoryDesc());
+ in->getMemoryDesc(),
+ wgt->getMemoryDesc(),
+ out->getMemoryDesc());
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
if (hasBias) {
- fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *biasVal_, *outVal_));
+ fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *bias, *out));
} else {
- fwd_.reset(new fc_fwd(fwdPD, *inVal_, *wgtVal_, *outVal_));
+ fwd_.reset(new fc_fwd(fwdPD, *in, *wgt, *out));
}
printValueFormatFlow();
- pipelineFwd_.clear();
- pipelineFwd_.push_back(*fwd_);
+ pipeline.push_back(*fwd_);
}
-void MKLDNNFcLayer::resetBwd() {
+void MKLDNNFcLayer::resetBwd(std::vector<mkldnn::primitive>& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) {
+ pipeline.clear();
if (!needResetBwd_) {
return;
}
@@ -190,8 +166,8 @@ void MKLDNNFcLayer::resetBwd() {
/// backward weight
CHECK(inVal_) << "Should have input value";
- const MatrixPtr& wgt = weight_->getWGrad();
- const MatrixPtr& bias = hasBias ? biases_->getWGrad() : nullptr;
+ const MatrixPtr& wgtGrad = weight_->getWGrad();
+ const MatrixPtr& biasGrad = hasBias ? biases_->getWGrad() : nullptr;
// TODO(TJ): merge outgrad
int device = outputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
@@ -202,101 +178,66 @@ void MKLDNNFcLayer::resetBwd() {
// for CPU device:
// fc do not need to convert from cpu device since output is always nc format
// only need create from cpu device
- const MatrixPtr& out = getOutput(device).grad;
- outGrad_ = MKLDNNMatrix::create(out, outVal_->getPrimitiveDesc());
- wgtGrad_ = MKLDNNMatrix::create(wgt, wgtVal_->getPrimitiveDesc());
- biasGrad_ = hasBias ? MKLDNNMatrix::create(bias, biasVal_->getPrimitiveDesc())
- : nullptr;
+ const MatrixPtr& outGrad = getOutput(device).grad;
+ out = MKLDNNMatrix::create(outGrad, outVal_->getPrimitiveDesc());
+ wgt = MKLDNNMatrix::create(wgtGrad, wgtVal_->getPrimitiveDesc());
+ bias = hasBias ? MKLDNNMatrix::create(biasGrad, biasVal_->getPrimitiveDesc())
+ : nullptr;
// create memory primitive desc
fc_fwd::desc fwdDesc = fc_fwd::desc(prop_kind::forward,
inVal_->getMemoryDesc(),
- wgtGrad_->getMemoryDesc(),
- outGrad_->getMemoryDesc());
+ wgt->getMemoryDesc(),
+ out->getMemoryDesc());
fc_fwd::primitive_desc fwdPD = fc_fwd::primitive_desc(fwdDesc, engine_);
fc_bwdWgt::desc bwdWgtDesc = hasBias
? fc_bwdWgt::desc(inVal_->getMemoryDesc(),
- wgtGrad_->getMemoryDesc(),
- biasGrad_->getMemoryDesc(),
- outGrad_->getMemoryDesc())
+ wgt->getMemoryDesc(),
+ bias->getMemoryDesc(),
+ out->getMemoryDesc())
: fc_bwdWgt::desc(inVal_->getMemoryDesc(),
- wgtGrad_->getMemoryDesc(),
- outGrad_->getMemoryDesc());
+ wgt->getMemoryDesc(),
+ out->getMemoryDesc());
fc_bwdWgt::primitive_desc bwdWgtPD =
fc_bwdWgt::primitive_desc(bwdWgtDesc, engine_, fwdPD);
if (hasBias) {
- bwdWgt_.reset(
- new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_, *biasGrad_));
+ bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt, *bias));
} else {
- bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *outGrad_, *wgtGrad_));
+ bwdWgt_.reset(new fc_bwdWgt(bwdWgtPD, *inVal_, *out, *wgt));
}
- pipelineBwd_.clear();
- pipelineBwd_.push_back(*bwdWgt_);
+ pipeline.push_back(*bwdWgt_);
/// backward data
- device = inputIsOnlyMKLDNN() ? MKLDNN_DEVICE : CPU_DEVICE;
- const MatrixPtr& in = getInputGrad(0, device);
- if (in == nullptr) {
+ const MatrixPtr& inGrad = inputLayers_[0]->getOutput().grad;
+ if (inGrad == nullptr) {
return;
}
- if (getInput(0, device).getAllCount() > 1) {
- // TODO(TJ): use outputMaps_ ways when merge outgrad done
+ if (getInput(0, MKLDNN_DEVICE).getAllCount() > 1) {
+ // TODO(TJ): use outputMaps_ ways to get the inGrad_ when merge outgrad done
} else {
- inGrad_ = MKLDNNMatrix::create(in, inVal_->getPrimitiveDesc());
+ in = MKLDNNMatrix::create(inGrad, inVal_->getPrimitiveDesc());
}
- fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(inVal_->getMemoryDesc(),
- wgtGrad_->getMemoryDesc(),
- outGrad_->getMemoryDesc());
+ fc_bwdData::desc bwdDataDesc = fc_bwdData::desc(
+ inVal_->getMemoryDesc(), wgt->getMemoryDesc(), out->getMemoryDesc());
fc_bwdData::primitive_desc bwdDataPD =
fc_bwdData::primitive_desc(bwdDataDesc, engine_, fwdPD);
CHECK(wgtVal_) << "Should have weight memory";
- bwdData_.reset(new fc_bwdData(bwdDataPD, *outGrad_, *wgtVal_, *inGrad_));
+ bwdData_.reset(new fc_bwdData(bwdDataPD, *out, *wgtVal_, *in));
printGradFormatFlow();
- pipelineBwd_.push_back(*bwdData_);
+ pipeline.push_back(*bwdData_);
}
-void MKLDNNFcLayer::forward(PassType passType) {
- Layer::forward(passType);
- reshape();
-
- {
- REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
- syncInputValue();
-
- // just submit forward pipeline
- stream_->submit(pipelineFwd_);
- }
-
- /* activation */ {
- REGISTER_TIMER_INFO("FwActTimer", getName().c_str());
- forwardActivation();
- }
+void MKLDNNFcLayer::updateInputData() {
+ inVal_->setData(getInputValue(0, CPU_DEVICE)->getData());
}
-void MKLDNNFcLayer::backward(const UpdateCallback& callback) {
- /* Do derivation */ {
- REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
- backwardActivation();
- }
-
- {
- REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
- resetBwd();
-
- syncOutputGrad();
- // just sumbmit backward pipeline
- stream_->submit(pipelineBwd_);
- }
-
- {
- REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
- weight_->getParameterPtr()->incUpdate(callback);
- if (biases_ && biases_->getWGrad()) {
- biases_->getParameterPtr()->incUpdate(callback);
- }
+void MKLDNNFcLayer::updateWeights(const UpdateCallback& callback) {
+ weight_->getParameterPtr()->incUpdate(callback);
+ if (biases_ && biases_->getWGrad()) {
+ biases_->getParameterPtr()->incUpdate(callback);
}
}
} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h
index e138a6faf181c412949218458e7ecf800a0d6a07..3119f863496df092da13c08bf733f13c42e53780 100644
--- a/paddle/gserver/layers/MKLDNNFcLayer.h
+++ b/paddle/gserver/layers/MKLDNNFcLayer.h
@@ -45,35 +45,28 @@ public:
bool init(const LayerMap& layerMap,
const ParameterMap& parameterMap) override;
- void convertWeightsFromPaddle() override;
+ void reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) override;
- void convertWeightsToPaddle() override;
+ void resetFwd(std::vector<mkldnn::primitive>& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) override;
- void forward(PassType passType) override;
+ void resetBwd(std::vector<mkldnn::primitive>& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) override;
- void backward(const UpdateCallback& callback) override;
+ void updateInputData() override;
-protected:
- /**
- * reshape the input image sizes
- * and reset output buffer size
- * and reset mkldnn forward
- */
- void reshape();
-
- /**
- * reset the forward primitve and memory
- * only would be called when input size changes
- */
- void resetFwd();
-
- /**
- * reset the backward primitve and memory for mkldnn fc
- * only would be called when needed
- */
- void resetBwd();
-
- void convertOutputToOtherDevice() override;
+ void updateWeights(const UpdateCallback& callback) override;
+
+ void convertWeightsFromPaddle() override;
+
+ void convertWeightsToPaddle() override;
};
} // namespace paddle
diff --git a/paddle/gserver/layers/MKLDNNLayer.h b/paddle/gserver/layers/MKLDNNLayer.h
index b983b833d510b823c5d4cff0b9390173e4cefc89..169679c8297542cac4a43f5a8e1af311ad9282df 100644
--- a/paddle/gserver/layers/MKLDNNLayer.h
+++ b/paddle/gserver/layers/MKLDNNLayer.h
@@ -19,6 +19,7 @@ limitations under the License. */
#include "MKLDNNBase.h"
#include "mkldnn.hpp"
#include "paddle/math/MKLDNNMatrix.h"
+#include "paddle/utils/Stat.h"
DECLARE_bool(use_mkldnn);
@@ -33,6 +34,8 @@ typedef std::shared_ptr MKLDNNLayerPtr;
*/
class MKLDNNLayer : public Layer {
protected:
+ // input value element count
+ size_t inputElemenCnt_;
// batch size
int bs_;
// input image channel, height and width
@@ -52,7 +55,7 @@ protected:
std::vector<mkldnn::primitive> pipelineFwd_;
std::vector<mkldnn::primitive> pipelineBwd_;
- // MKLDNNMatrixPtr
+ // MKLDNNMatrixPtr with internal format
MKLDNNMatrixPtr inVal_;
MKLDNNMatrixPtr inGrad_;
MKLDNNMatrixPtr outVal_;
@@ -65,6 +68,7 @@ protected:
public:
explicit MKLDNNLayer(const LayerConfig& config)
: Layer(config),
+ inputElemenCnt_(0),
bs_(0),
ic_(0),
ih_(0),
@@ -95,12 +99,104 @@ public:
if (!Layer::init(layerMap, parameterMap)) {
return false;
}
+ checkCPUOutputsNumber();
stream_.reset(new MKLDNNStream());
engine_ = CPUEngine::Instance().getEngine();
return true;
}
+ void forward(PassType passType) override {
+ passType_ = passType;
+
+ {
+ REGISTER_TIMER_INFO("mkldnn_FwdTimer", getName().c_str());
+ CHECK(!inputLayers_.empty());
+ copySeqInfoToOutputs();
+ size_t elemenCnt = inputLayers_[0]->getOutput().value->getElementCnt();
+ if (inputElemenCnt_ != elemenCnt) {
+ // reset when input total sizes changed, not only the batchsize
+ inputElemenCnt_ = elemenCnt;
+ reshape(bs_, ic_, ih_, iw_, oc_, oh_, ow_);
+ resetFwd(pipelineFwd_, inVal_, wgtVal_, biasVal_, outVal_);
+ convertWeightsFromPaddle();
+ needResetBwd_ = true;
+ }
+
+ if (inputLayers_[0]->getType() == "data") {
+ updateInputData();
+ }
+
+ stream_->submit(pipelineFwd_);
+ }
+
+ /* activation */ {
+ REGISTER_TIMER_INFO("FwActTimer", getName().c_str());
+ forwardActivation();
+ }
+ }
+
+ void backward(const UpdateCallback& callback) override {
+ /* Do derivation */ {
+ REGISTER_TIMER_INFO("BpActTimer", getName().c_str());
+ backwardActivation();
+ }
+
+ {
+ REGISTER_TIMER_INFO("mkldnn_bwdTimer", getName().c_str());
+ if (needResetBwd_) {
+ resetBwd(pipelineBwd_, inGrad_, wgtGrad_, biasGrad_, outGrad_);
+ needResetBwd_ = false;
+ }
+
+ stream_->submit(pipelineBwd_);
+ }
+
+ {
+ REGISTER_TIMER_INFO("WeightUpdate", getName().c_str());
+ updateWeights(callback);
+ }
+ }
+
+ /**
+ * reshape the input image sizes
+ * and reset output image and buffer size
+ * output channel can not be changed
+ */
+ virtual void reshape(
+ int& bs, int& ic, int& ih, int& iw, int oc, int& oh, int& ow) = 0;
+
+ /**
+ * reset the mkldnn forward primitive and memory
+ * only would be called when input size changes
+ */
+ virtual void resetFwd(std::vector<mkldnn::primitive>& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) = 0;
+
+ /**
+ * reset the mkldnn backward primitive and memory for mkldnn fc
+ * only would be called when needed
+ */
+ virtual void resetBwd(std::vector<mkldnn::primitive>& pipeline,
+ MKLDNNMatrixPtr& in,
+ MKLDNNMatrixPtr& wgt,
+ MKLDNNMatrixPtr& bias,
+ MKLDNNMatrixPtr& out) = 0;
+
+ /**
+ * Update the input value data when the input layer is of "data" type,
+ * since the input value data address might change.
+ */
+ virtual void updateInputData() {}
+
+ /**
+ * Update weights and biases if necessary.
+ */
+ virtual void updateWeights(const UpdateCallback& callback) {}
+
/**
* convert weight from paddle format to mkldnn format
* weight_ will be override
@@ -114,10 +210,38 @@ public:
virtual void convertWeightsToPaddle() {}
/**
- * convert MKLDNN output to other device.
- * only support CPU device yet
+ * expose this interface as public for unit tests
+ */
+ void addOutputArgument(int deviceId) { Layer::addOutputArgument(deviceId); }
+
+protected:
+ /**
+ * reshape the input image sizes and input batchsize
*/
- virtual void convertOutputToOtherDevice() {}
+ virtual void reshapeInput(int& batchsize, int& height, int& width) {
+ const Argument& input = inputLayers_[0]->getOutput();
+ batchsize = input.getBatchSize();
+ int h = input.getFrameHeight();
+ int w = input.getFrameWidth();
+ if (h != 0) {
+ height = h;
+ }
+ if (w != 0) {
+ width = w;
+ }
+ }
+
+ /**
+ * reshape output image sizes
+ */
+ virtual void reshapeOutput(size_t height, size_t width) {
+ output_.setFrameHeight(height);
+ output_.setFrameWidth(width);
+ for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
+ outputOtherDevice_[i].setFrameHeight(height);
+ outputOtherDevice_[i].setFrameWidth(width);
+ }
+ }
/**
* print info about sizes
@@ -133,8 +257,8 @@ public:
*/
virtual void printValueFormatFlow() {
if (inVal_ && outVal_) {
- VLOG(MKLDNN_FMTS) << "value format flow --- " << inVal_->getFormat()
- << " >>> " << outVal_->getFormat();
+ VLOG(MKLDNN_FMTS) << inVal_->getFormat() << " >>> "
+ << outVal_->getFormat();
}
}
@@ -143,29 +267,12 @@ public:
*/
virtual void printGradFormatFlow() {
if (inGrad_ && outGrad_) {
- VLOG(MKLDNN_FMTS) << "grad format flow --- " << inGrad_->getFormat()
- << " <<< " << outGrad_->getFormat();
+ VLOG(MKLDNN_FMTS) << inGrad_->getFormat() << " <<< "
+ << outGrad_->getFormat();
}
}
protected:
- /**
- * copy image size and sequence info to other device
- * @note: can not directly use Layer::copyOutputToOtherDevice since here only
- * copy base info and do not copy data value
- */
- void copyOutputInfoToOtherDevice() {
- for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
- outputOtherDevice_[i].setFrameHeight(output_.getFrameHeight());
- outputOtherDevice_[i].setFrameWidth(output_.getFrameWidth());
- outputOtherDevice_[i].sequenceStartPositions =
- output_.sequenceStartPositions;
- outputOtherDevice_[i].subSequenceStartPositions =
- output_.subSequenceStartPositions;
- outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
- }
- }
-
/**
* If input only has MKLDNN device.
* Otherwise, only support the previous layer using CPU device.
@@ -193,37 +300,12 @@ protected:
return outputOtherDevice_.size() == 0;
}
- /**
- * Sync input value data
- */
- void syncInputValue() {
- if (inputIsOnlyMKLDNN()) {
- return;
- }
- real* iData = getInputValue(0, CPU_DEVICE)->getData();
- // update input data
- // since it might be changed if this is after data layer
- inVal_->updateData(iData);
- }
-
- /**
- * Sync output grad data
- */
- void syncOutputGrad() {
- if (outputIsOnlyMKLDNN()) {
- return;
- }
-
- // update diff
- real* oDiff = getOutput(CPU_DEVICE).grad->getData();
- outGrad_->updateData(oDiff);
- }
-
/**
* Set deviceId of this layer.
*/
void setDevice(int id) { deviceId_ = id; }
+private:
/**
* Set deviceId of the params used in this layer.
*/
@@ -247,6 +329,42 @@ protected:
parameter->setDevice(id);
}
}
+
+ /**
+ * Check the number of CPU devices in outputOtherDevice_;
+ * there should be at most one.
+ */
+ void checkCPUOutputsNumber(int max = 1) {
+ int cnt = 0;
+ for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
+ if (outputOtherDevice_[i].deviceId == CPU_DEVICE) {
+ ++cnt;
+ }
+ }
+ CHECK_LE(cnt, max) << "too many CPU devices";
+ }
+
+ /**
+ * copy SeqInfo from the input layer to this output and the other output devices.
+ * @note: do not use getInput(0) since it uses this layer's deviceId_;
+ * use "inputLayers_[0]->getOutput()" instead.
+ */
+ void copySeqInfoToOutputs() {
+ if (inputLayers_.empty() || !needSequenceInfo_) {
+ return;
+ }
+ const Argument& input = inputLayers_[0]->getOutput();
+ output_.sequenceStartPositions = input.sequenceStartPositions;
+ output_.subSequenceStartPositions = input.subSequenceStartPositions;
+ output_.cpuSequenceDims = input.cpuSequenceDims;
+ for (size_t i = 0; i < outputOtherDevice_.size(); i++) {
+ outputOtherDevice_[i].sequenceStartPositions =
+ output_.sequenceStartPositions;
+ outputOtherDevice_[i].subSequenceStartPositions =
+ output_.subSequenceStartPositions;
+ outputOtherDevice_[i].cpuSequenceDims = output_.cpuSequenceDims;
+ }
+ }
};
} // namespace paddle
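For readers skimming the new MKLDNNLayer flow: forward() rebuilds the MKL-DNN primitives only when the total input element count changes (not merely the batch size) and otherwise replays the cached pipeline. Below is a minimal, self-contained C++ sketch of that rebuild-on-size-change pattern; the class and member names are hypothetical and the code does not touch the MKL-DNN API.

#include <cstddef>
#include <cstdio>
#include <vector>

// Illustration only: rebuild expensive state when the input size changes,
// otherwise reuse the cached pipeline (mirrors the inputElemenCnt_ check).
class CachedPipeline {
public:
  void forward(const std::vector<float>& input) {
    if (input.size() != cachedElemCnt_) {  // reset only when the size changes
      cachedElemCnt_ = input.size();
      rebuild();                           // analogous to reshape + resetFwd
      needResetBwd_ = true;                // backward primitives must be rebuilt too
    }
    run(input);                            // analogous to stream_->submit(pipelineFwd_)
  }

private:
  void rebuild() { std::printf("rebuild primitives for %zu elements\n", cachedElemCnt_); }
  void run(const std::vector<float>& in) { std::printf("run on %zu elements\n", in.size()); }

  std::size_t cachedElemCnt_ = 0;
  bool needResetBwd_ = false;
};

int main() {
  CachedPipeline layer;
  layer.forward(std::vector<float>(8));   // rebuilds
  layer.forward(std::vector<float>(8));   // reuses the cached primitives
  layer.forward(std::vector<float>(16));  // total size changed, rebuilds again
  return 0;
}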
diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp
index 6a91042f628920a9986763531fb4c633307b43b8..e97809141a93106f9e6ebaf40c7e8aa9c6010557 100644
--- a/paddle/gserver/layers/SwitchOrderLayer.cpp
+++ b/paddle/gserver/layers/SwitchOrderLayer.cpp
@@ -24,19 +24,21 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
Layer::init(layerMap, parameterMap);
auto& img_conf = config_.inputs(0).image_conf();
+ size_t inD = img_conf.img_size_z();
size_t inH =
img_conf.has_img_size_y() ? img_conf.img_size_y() : img_conf.img_size();
size_t inW = img_conf.img_size();
size_t inC = img_conf.channels();
+ inH = inH * inD;
inDims_ = TensorShape({0, inC, inH, inW});
outDims_ = TensorShape(4);
auto& reshape_conf = config_.reshape_conf();
- for (size_t i = 0; i < reshape_conf.heightaxis_size(); i++) {
- heightAxis_.push_back(reshape_conf.heightaxis(i));
+ for (int i = 0; i < reshape_conf.height_axis_size(); i++) {
+ heightAxis_.push_back(reshape_conf.height_axis(i));
}
- for (size_t i = 0; i < reshape_conf.widthaxis_size(); i++) {
- widthAxis_.push_back(reshape_conf.widthaxis(i));
+ for (int i = 0; i < reshape_conf.width_axis_size(); i++) {
+ widthAxis_.push_back(reshape_conf.width_axis(i));
}
createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig());
createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig());
@@ -64,9 +66,10 @@ void SwitchOrderLayer::setInDims() {
MatrixPtr input = inputLayers_[0]->getOutputValue();
size_t batchSize = input->getHeight();
inDims_.setDim(0, batchSize);
-
+ int d = inputLayers_[0]->getOutput().getFrameDepth();
+ d = (d == 0 ? 1 : d);
int h = inputLayers_[0]->getOutput().getFrameHeight();
- if (h != 0) inDims_.setDim(2, h);
+ if (h != 0) inDims_.setDim(2, h * d);
int w = inputLayers_[0]->getOutput().getFrameWidth();
if (w != 0) inDims_.setDim(3, w);
int totalCount = input->getElementCnt();
@@ -80,8 +83,7 @@ void SwitchOrderLayer::forward(PassType passType) {
setOutDims();
resetOutput(outDims_[0], outDims_[1] * outDims_[2] * outDims_[3]);
if (heightAxis_.size() > 0) {
- getOutputValue()->reshape(reshapeHeight_, reshapeWidth_);
- getOutputGrad()->reshape(reshapeHeight_, reshapeWidth_);
+ resetOutput(reshapeHeight_, reshapeWidth_);
}
// switch NCHW to NHWC
diff --git a/paddle/gserver/tests/MKLDNNTester.cpp b/paddle/gserver/tests/MKLDNNTester.cpp
index de1635be2af37cd0ba49010199a417090865b0e4..2f48e5b2d3ffc9337ed1314f6db6549e56263fdd 100644
--- a/paddle/gserver/tests/MKLDNNTester.cpp
+++ b/paddle/gserver/tests/MKLDNNTester.cpp
@@ -63,8 +63,12 @@ void MKLDNNTester::reset(const TestConfig& dnn,
initTestLayer(
configs_[i], &(layerMaps_[i]), &(parameters_[i]), &(testLayers_[i]));
}
- dnnLayer_ = testLayers_[DNN];
refLayer_ = testLayers_[REF];
+ dnnLayer_ = std::dynamic_pointer_cast<MKLDNNLayer>(testLayers_[DNN]);
+ CHECK(dnnLayer_);
+ // for comparison with the Paddle reference results,
+ // manually add a CPU device output for the test
+ dnnLayer_->addOutputArgument(CPU_DEVICE);
EXPECT_EQ(dataLayers_[DNN].size(), dataLayers_[REF].size());
EXPECT_EQ(parameters_[DNN].size(), parameters_[REF].size());
@@ -109,20 +113,22 @@ void MKLDNNTester::randomBotDatas() {
void MKLDNNTester::randomTopDiffs() {
refLayer_->getOutputGrad()->randomizeUniform();
- dnnLayer_->getOutputGrad()->copyFrom(*(refLayer_->getOutputGrad()));
- VLOG(lvl_) << "Random dom Backward Input, TopDiff: ";
+ dnnLayer_->getOutput(CPU_DEVICE)
+ .grad->copyFrom(*(refLayer_->getOutputGrad()));
+ VLOG(lvl_) << "Random Backward Input, TopDiff: ";
printMatrix(refLayer_->getOutputGrad());
}
void MKLDNNTester::checkForward() {
- printTopDatas();
- double delta = compareMatrix(testLayers_[DNN]->getOutputValue(),
- testLayers_[REF]->getOutputValue());
VLOG(MKLDNN_ALL) << "Check Forward";
+ printTopDatas();
+ double delta = compareMatrix(dnnLayer_->getOutput(-1).value,
+ refLayer_->getOutputValue());
EXPECT_LE(fabs(delta), eps_);
}
void MKLDNNTester::checkBackwardData() {
+ VLOG(MKLDNN_ALL) << "Check Backward Data";
// TODO(TJ): uncomment me when batch norm ready
// const bool isBN = dnnLayer_->getType() == "mkldnn_batch_norm";
for (size_t i = 0; i < dataLayers_[DNN].size(); ++i) {
@@ -144,14 +150,12 @@ void MKLDNNTester::checkBackwardData() {
}
void MKLDNNTester::checkBackwardWgts() {
+ VLOG(MKLDNN_ALL) << "Check Backward Weight";
CHECK_EQ(parameters_[DNN].size(), parameters_[REF].size());
vector<VectorPtr> dnnWgts;  // used to temporarily save mkldnn weights
saveWgt(parameters_[DNN], dnnWgts);
- const MKLDNNLayerPtr dnnlayer =
- std::dynamic_pointer_cast(dnnLayer_);
- CHECK(dnnlayer);
- dnnlayer->convertWeightsToPaddle();
+ dnnLayer_->convertWeightsToPaddle();
for (size_t i = 0; i < parameters_[DNN].size(); ++i) {
const VectorPtr& dnn = parameters_[DNN][i]->getBuf(PARAMETER_VALUE);
const VectorPtr& ref = parameters_[REF][i]->getBuf(PARAMETER_VALUE);
@@ -189,38 +193,38 @@ void MKLDNNTester::restoreWgt(const vector& from,
}
// clear parameters grad
-void MKLDNNTester::clearWgtDiffs() {
+void MKLDNNTester::clearWgtDiffs(size_t id) {
+ CHECK_LE(id, parameters_.size());
for (size_t n = 0; n < parameters_.size(); ++n) {
- for (size_t i = 0; i < parameters_[n].size(); ++i) {
- const VectorPtr& grad = parameters_[n][i]->getBuf(PARAMETER_GRADIENT);
- if (grad) {
- grad->zeroMem();
+ if (id == n || id == parameters_.size()) {
+ for (size_t i = 0; i < parameters_[n].size(); ++i) {
+ const VectorPtr& grad = parameters_[n][i]->getBuf(PARAMETER_GRADIENT);
+ if (grad) {
+ grad->zeroMem();
+ }
}
}
}
}
-void MKLDNNTester::clearBotDiffs() {
- // dnn and ref
+void MKLDNNTester::clearBotDiffs(size_t id) {
+ CHECK_LE(id, dataLayers_.size());
for (size_t n = 0; n < dataLayers_.size(); ++n) {
- // all inputs layers
- for (size_t i = 0; i < dataLayers_[n].size(); ++i) {
- dataLayers_[n][i]->getOutputGrad()->zeroMem();
+ if (id == n || id == dataLayers_.size()) {
+ // clear the input layers of this specific layer
+ for (size_t i = 0; i < dataLayers_[n].size(); ++i) {
+ dataLayers_[n][i]->getOutputGrad()->zeroMem();
+ }
}
}
}
-void MKLDNNTester::clearBotDiffs(int n) {
- CHECK_LT(n, NUM);
- // all inputs layers
- for (size_t i = 0; i < dataLayers_[n].size(); ++i) {
- dataLayers_[n][i]->getOutputGrad()->zeroMem();
- }
-}
-
-void MKLDNNTester::clearTopDatas() {
+void MKLDNNTester::clearTopDatas(size_t id) {
+ CHECK_LE(id, testLayers_.size());
for (size_t i = 0; i < testLayers_.size(); ++i) {
- testLayers_[i]->getOutputValue()->zeroMem();
+ if (id == i || id == testLayers_.size()) {
+ testLayers_[i]->getOutputValue()->zeroMem();
+ }
}
}
@@ -300,16 +304,24 @@ void MKLDNNTester::runOnce() {
checkForward();
// test backward
+ // simple updater
+ UpdateCallback updateCallback = [](Parameter* para) {
+ auto& grad = para->getBuf(PARAMETER_GRADIENT);
+ auto& value = para->getBuf(PARAMETER_VALUE);
+ real lr = 1e-3;
+ value->add(*grad, lr);
+ };
randomTopDiffs();
- dnnLayer_->backward(nullptr);
- refLayer_->backward(nullptr);
+ dnnLayer_->backward(updateCallback);
+ refLayer_->backward(updateCallback);
checkBackwardData();
checkBackwardWgts();
// clear buffers
// ref code will addto the diff, dnn code will writeto it
- // and clearTopDatas() and clearWgtDiffs() should be coverd by test layers
+ // and clearTopDatas(REF) should be covered by the ref layers
clearBotDiffs(REF);
+ clearWgtDiffs(REF);
}
void MKLDNNTester::run(const TestConfig& dnn,
diff --git a/paddle/gserver/tests/MKLDNNTester.h b/paddle/gserver/tests/MKLDNNTester.h
index e55e4493ffdfe45b8cfdee423febd1878b8b3d8a..5ac885638cde7693a0c847733e7a6149c1b7e6c2 100644
--- a/paddle/gserver/tests/MKLDNNTester.h
+++ b/paddle/gserver/tests/MKLDNNTester.h
@@ -18,6 +18,7 @@ limitations under the License. */
#include
#include "LayerGradUtil.h"
#include "paddle/gserver/layers/MKLDNNBase.h"
+#include "paddle/gserver/layers/MKLDNNLayer.h"
namespace paddle {
@@ -40,7 +41,8 @@ protected:
vector<LayerMap> layerMaps_;
vector<vector<ParameterPtr>> parameters_;
vector<LayerPtr> testLayers_;
- LayerPtr dnnLayer_, refLayer_;
+ LayerPtr refLayer_;
+ MKLDNNLayerPtr dnnLayer_;
/// run some iterations, all the result should pass
size_t iter_;
@@ -88,10 +90,10 @@ private:
void checkBackwardData();
void checkBackwardWgts();
- void clearWgtDiffs();
- void clearBotDiffs();
- void clearBotDiffs(int n); // clear specific layer
- void clearTopDatas();
+ // clear a specific layer; clear all when id equals NUM
+ void clearWgtDiffs(size_t id = NUM);
+ void clearBotDiffs(size_t id = NUM);
+ void clearTopDatas(size_t id = NUM);
void printTopDatas();
void printMatrix(const MatrixPtr& m);
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index e0c14ad5b512c7329062a5426ef34844ec268020..090bde7b203652e3ffb1662b8f5b8937885d2608 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -1703,6 +1703,55 @@ TEST(Layer, BatchNormalizationLayer) {
#endif
}
+void testBatchNorm3DLayer(const string& type, bool trans, bool useGpu) {
+ TestConfig config;
+ const int CHANNELS = 10;
+ const int IMG_SIZE = 16;
+ const int IMG_SIZE_Y = 8;
+ const int IMG_SIZE_Z = 8;
+ size_t size = CHANNELS * IMG_SIZE * IMG_SIZE_Y * IMG_SIZE_Z;
+ config.layerConfig.set_type(type);
+ config.layerConfig.set_size(size);
+ config.layerConfig.set_active_type("sigmoid");
+ config.biasSize = CHANNELS;
+ config.inputDefs.push_back({INPUT_DATA,
+ "layer_0",
+ /* dim= */ size,
+ /* paraSize= */ CHANNELS});
+
+ config.inputDefs.push_back({INPUT_DATA, "layer_1_running_mean", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+ config.inputDefs.push_back({INPUT_DATA, "layer_2_running_var", 1, CHANNELS});
+ config.inputDefs.back().isStatic = true;
+
+ LayerInputConfig* input = config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+ config.layerConfig.add_inputs();
+
+ ImageConfig* img_conf = input->mutable_image_conf();
+ img_conf->set_channels(CHANNELS);
+ img_conf->set_img_size(IMG_SIZE);
+ img_conf->set_img_size_y(IMG_SIZE_Y);
+ img_conf->set_img_size_z(IMG_SIZE_Z);
+
+ testLayerGrad(config,
+ "batch_norm",
+ 64,
+ /* trans= */ trans,
+ useGpu,
+ /* useWeight */ true);
+}
+
+TEST(Layer, testBatchNorm3DLayer) {
+ testBatchNorm3DLayer("batch_norm", false, false);
+#ifndef PADDLE_ONLY_CPU
+ testBatchNorm3DLayer("batch_norm", false, true);
+ if (hl_get_cudnn_lib_version() >= int(4000)) {
+ testBatchNorm3DLayer("cudnn_batch_norm", false, true);
+ }
+#endif
+}
+
void testConvOperator(bool isDeconv) {
TestConfig config;
const int NUM_FILTERS = 16;
@@ -2019,10 +2068,10 @@ TEST(Layer, SwitchOrderLayer) {
img->set_img_size_y(16);
ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf();
- reshape->add_heightaxis(0);
- reshape->add_heightaxis(1);
- reshape->add_heightaxis(2);
- reshape->add_widthaxis(3);
+ reshape->add_height_axis(0);
+ reshape->add_height_axis(1);
+ reshape->add_height_axis(2);
+ reshape->add_width_axis(3);
// config softmax layer
config.layerConfig.set_type("switch_order");
@@ -2253,26 +2302,27 @@ void test3DDeConvLayer(const string& type, bool trans, bool useGpu) {
conv->set_stride(2);
conv->set_stride_y(2);
conv->set_stride_z(2);
- conv->set_img_size(IMAGE_SIZE);
- conv->set_img_size_y(IMAGE_SIZE_Y);
- conv->set_img_size_z(IMAGE_SIZE_Z);
- conv->set_output_x(imageSize(conv->img_size(),
+ conv->set_output_x(IMAGE_SIZE);
+ conv->set_output_y(IMAGE_SIZE_Y);
+ conv->set_output_z(IMAGE_SIZE_Z);
+
+ conv->set_img_size(imageSize(conv->output_x(),
conv->filter_size(),
conv->padding(),
conv->stride(),
true));
- conv->set_output_y(imageSize(conv->img_size_y(),
- conv->filter_size_y(),
- conv->padding_y(),
- conv->stride_y(),
- true));
- conv->set_output_z(imageSize(conv->img_size_z(),
- conv->filter_size_z(),
- conv->padding_z(),
- conv->stride_z(),
- true));
- config.layerConfig.set_size(conv->output_x() * conv->output_y() *
- conv->output_z() * NUM_FILTERS);
+ conv->set_img_size_y(imageSize(conv->output_y(),
+ conv->filter_size_y(),
+ conv->padding_y(),
+ conv->stride_y(),
+ true));
+ conv->set_img_size_z(imageSize(conv->output_z(),
+ conv->filter_size_z(),
+ conv->padding_z(),
+ conv->stride_z(),
+ true));
+ config.layerConfig.set_size(conv->img_size() * conv->img_size_y() *
+ conv->img_size_z() * NUM_FILTERS);
conv->set_groups(1);
conv->set_filter_channels(conv->channels() / conv->groups());
config.inputDefs.push_back(
diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp
index 0a355e2644cce572ce90ecf5c9d2a5b7b395bc61..c4063e5069854242d9f93886b66580385557ca73 100644
--- a/paddle/math/MKLDNNMatrix.cpp
+++ b/paddle/math/MKLDNNMatrix.cpp
@@ -33,14 +33,12 @@ MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m, memory::primitive_desc pd) {
size_t width = cnts / dims[0];
m = Matrix::create(height, width, false, false);
}
-
CHECK(m) << " Matrix should not be empty";
+
CpuMatrixPtr cpuMatrix = std::dynamic_pointer_cast<CpuMatrix>(m);
CHECK(cpuMatrix) << "Only support create from CPU matrix yet";
-
- CHECK_EQ(cnts, m->getElementCnt()) << "Count size does not match";
- return std::make_shared(
- m->getData(), m->getHeight(), m->getWidth(), pd);
+ CHECK_EQ(cpuMatrix->getElementCnt(), cnts) << "Count size does not match";
+ return std::make_shared<MKLDNNMatrix>(cpuMatrix, pd);
}
MKLDNNMatrixPtr MKLDNNMatrix::create(MatrixPtr m,
@@ -138,7 +136,7 @@ void MKLDNNMatrix::downSpatial() {
mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr),
"could not create a memory primitive");
reset(result);
- set_data_handle(getData());
+ set_data_handle(data_);
}
} // namespace paddle
diff --git a/paddle/math/MKLDNNMatrix.h b/paddle/math/MKLDNNMatrix.h
index e50f698b495713e6f15ab7a12a7ee7487662040f..eef3b429e6fa0087aeac3f5aed9dff983b06e826 100644
--- a/paddle/math/MKLDNNMatrix.h
+++ b/paddle/math/MKLDNNMatrix.h
@@ -30,11 +30,10 @@ typedef std::shared_ptr<MKLDNNMatrix> MKLDNNMatrixPtr;
*/
class MKLDNNMatrix : public CpuMatrix, public mkldnn::memory {
public:
- MKLDNNMatrix(real* data,
- size_t height,
- size_t width,
- mkldnn::memory::primitive_desc pd)
- : CpuMatrix(data, height, width, false), mkldnn::memory(pd, data) {}
+ MKLDNNMatrix(CpuMatrixPtr m, mkldnn::memory::primitive_desc pd)
+ : CpuMatrix(m->getData(), m->getHeight(), m->getWidth(), false),
+ mkldnn::memory(pd, m->getData()),
+ m_(m) {}
~MKLDNNMatrix() {}
@@ -81,11 +80,29 @@ public:
void downSpatial();
/**
- * Update the memory data handle.
+ * set the memory data handle.
* Caution: This will not check the buffer size of the data,
* it should be coverd by user.
*/
- void updateData(void* data) { set_data_handle(data); }
+ void setData(real* data) {
+ set_data_handle(data);
+ CpuMatrix::setData(data);
+ m_.reset();
+ }
+
+ /**
+ * override Matrix::getData
+ * check data before return
+ */
+ real* getData() override {
+ CHECK_EQ((void*)data_, get_data_handle());
+ return data_;
+ }
+
+ const real* getData() const override {
+ CHECK_EQ((void*)data_, get_data_handle());
+ return data_;
+ }
/**
* Get primitive descriptor.
@@ -143,6 +160,10 @@ protected:
memory::format srcFmt,
memory::format dstFmt,
memory::dims dm);
+
+private:
+ // save the CpuMatrixPtr in case the buffer is released outside
+ CpuMatrixPtr m_;
};
} // namespace paddle
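The new m_ member exists so the MKLDNNMatrix keeps the wrapped CPU buffer alive while it aliases that memory, and setData() drops the reference once an externally owned buffer is substituted. A small self-contained sketch of that ownership pattern follows; names are hypothetical and there is no MKL-DNN dependency.

#include <cstdio>
#include <memory>
#include <vector>

// Illustration only: a view that pins its owning buffer through a shared_ptr
// (analogous to m_) and releases that pin when an external buffer is set.
class BufferView {
public:
  explicit BufferView(std::shared_ptr<std::vector<float>> owner)
      : owner_(std::move(owner)), data_(owner_->data()) {}

  // analogous to MKLDNNMatrix::setData: point at external memory, drop the pin
  void setData(float* external) {
    data_ = external;
    owner_.reset();
  }

  float* getData() const { return data_; }

private:
  std::shared_ptr<std::vector<float>> owner_;  // keeps the wrapped buffer alive
  float* data_;
};

int main() {
  auto buf = std::make_shared<std::vector<float>>(4, 1.0f);
  BufferView view(buf);
  buf.reset();  // the view still holds a reference, so the data stays valid
  std::printf("%f\n", view.getData()[0]);
  return 0;
}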
diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..0ebefbab26ec8fdf316f852fbb7f6d9f3bbc48eb
--- /dev/null
+++ b/paddle/operators/concat_op.cc
@@ -0,0 +1,79 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/operators/concat_op.h"
+#include
+
+namespace paddle {
+namespace operators {
+using framework::Tensor;
+
+class ConcatOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto *out = ctx.Output("Out");
+ size_t axis = static_cast(ctx.Attr("axis"));
+ size_t n = ins.size();
+
+ PADDLE_ENFORCE_GT(n, 1, "The number of input tensors should be greater than 1.");
+
+ auto out_dims = ins[0]->dims();
+ size_t in_zero_dims_size = out_dims.size();
+ for (size_t i = 1; i < n; i++) {
+ for (size_t j = 0; j < in_zero_dims_size; j++) {
+ if (j == axis) {
+ out_dims[axis] += ins[i]->dims()[j];
+ continue;
+ }
+ PADDLE_ENFORCE_EQ(out_dims[j], ins[i]->dims()[j],
+ "Input tensors should have the same "
+ "elements except the specify axis.")
+ }
+ }
+ out->Resize(out_dims);
+ }
+};
+
+class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ ConcatOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "the input tensors of concat operator.").AsDuplicable();
+ AddOutput("Out", "the output tensor of concat operator.");
+ AddComment(R"DOC(
+ Join the input tensors along with the axis.
+ Examples:
+ Input[0] = [[1,2],[3,4]]
+ Input[1] = [[5,6]]
+ axis = 0
+ Output = [[1,2],
+ [3,4],
+ [5,6]]
+ )DOC");
+ AddAttr("axis", "The axis which the inputs will be joined with.")
+ .SetDefault(0);
+ }
+};
+
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker)
+REGISTER_OP_CPU_KERNEL(concat,
+ ops::ConcatKernel)
diff --git a/paddle/operators/concat_op.cu b/paddle/operators/concat_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..38fee7473dbb2ba97fe95b6632db7a1749cf3bbe
--- /dev/null
+++ b/paddle/operators/concat_op.cu
@@ -0,0 +1,19 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/concat_op.h"
+
+namespace ops = paddle::operators;
+// TODO(Yancey1989) Add GPU kernel
diff --git a/paddle/operators/concat_op.h b/paddle/operators/concat_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f977054fdf8aa0164db726b94a21c57f770dd674
--- /dev/null
+++ b/paddle/operators/concat_op.h
@@ -0,0 +1,64 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename Place, typename T>
+class ConcatKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ auto ins = ctx.MultiInput("X");
+ auto* out = ctx.Output("Out");
+ int64_t axis = static_cast(ctx.Attr("axis"));
+ size_t n = ins.size();
+ size_t output_axis_dim = 0;
+ size_t before = 1, after = 1;
+ for (size_t i = 0; i < n; i++) {
+ output_axis_dim += ins[i]->dims()[axis];
+ }
+ auto& input_zero = ins[0];
+ for (int64_t i = 0; i < input_zero->dims().size(); i++) {
+ if (i == axis) {
+ continue;
+ }
+ if (i < axis) {
+ before *= input_zero->dims()[i];
+ } else {
+ after *= input_zero->dims()[i];
+ }
+ }
+ size_t output_offset = 0;
+ for (size_t i = 0; i < n; i++) {
+ auto& in = ins[i];
+ auto axis_dim = in->dims()[axis];
+ for (size_t j = 0; j < before; j++) {
+ size_t len = axis_dim * after * sizeof(T);
+ const T* src = in->data() + axis_dim * after * j;
+ T* out_data = out->mutable_data(platform::CPUPlace());
+ T* dest = out_data + output_offset + output_axis_dim * after * j;
+ memcpy(dest, src, len);
+ }
+ output_offset += axis_dim * after;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
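ConcatKernel views each input as (before, axis_dim, after) and copies axis_dim * after contiguous elements per `before` slice into the proper column offset of the output. The following self-contained snippet reproduces the same index arithmetic for two small inputs concatenated along axis 1; it is an illustration only, with no Paddle dependency.

#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

int main() {
  // X0 has shape (2, 2), X1 has shape (2, 1); concat along axis = 1 gives (2, 3).
  std::vector<float> x0 = {1, 2, 3, 4};  // [[1, 2], [3, 4]]
  std::vector<float> x1 = {5, 6};        // [[5], [6]]
  const std::size_t before = 2, after = 1;  // products of dims before/after the axis
  const std::size_t axis_dims[] = {2, 1};   // each input's size along the concat axis
  const std::size_t out_axis_dim = axis_dims[0] + axis_dims[1];

  std::vector<float> out(before * out_axis_dim * after);
  const float* srcs[] = {x0.data(), x1.data()};

  std::size_t output_offset = 0;
  for (std::size_t i = 0; i < 2; ++i) {  // same loop structure as ConcatKernel::Compute
    const std::size_t axis_dim = axis_dims[i];
    for (std::size_t j = 0; j < before; ++j) {
      const float* src = srcs[i] + axis_dim * after * j;
      float* dst = out.data() + output_offset + out_axis_dim * after * j;
      std::memcpy(dst, src, axis_dim * after * sizeof(float));
    }
    output_offset += axis_dim * after;
  }

  for (float v : out) std::printf("%g ", v);  // prints: 1 2 5 3 4 6
  std::printf("\n");
  return 0;
}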
diff --git a/paddle/operators/cos_sim_op.h b/paddle/operators/cos_sim_op.h
index 9e2bcebe3b5432c157fac895a9bbab5164193dbb..0dc509952578497671a128374f77ce616a520909 100644
--- a/paddle/operators/cos_sim_op.h
+++ b/paddle/operators/cos_sim_op.h
@@ -42,7 +42,7 @@ class CosSimKernel : public framework::OpKernel {
output_y_norm->mutable_data(context.GetPlace());
auto dims = input_x->dims();
- int size = static_cast(framework::product(dims));
+ int64_t size = input_x->numel();
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto x = EigenMatrix::From(*input_x, new_dims);
auto y = EigenMatrix::From(*input_y, new_dims);
@@ -72,7 +72,7 @@ class CosSimGradKernel : public framework::OpKernel {
auto* input_grad_z = context.Input(framework::GradVarName("Out"));
auto dims = input_x->dims();
- int size = static_cast(framework::product(dims));
+ int64_t size = input_x->numel();
auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
auto x = EigenMatrix::From(*input_x, new_dims);
auto y = EigenMatrix::From(*input_y, new_dims);
diff --git a/paddle/operators/elementwise_mul_op.cc b/paddle/operators/elementwise_mul_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1742925545d29df5d7df719faaea3b754680ab61
--- /dev/null
+++ b/paddle/operators/elementwise_mul_op.cc
@@ -0,0 +1,109 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#include "paddle/operators/elementwise_mul_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class ElementWiseMulOp : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
+ auto x_dim = ctx.Input("X")->dims();
+ auto y_dim = ctx.Input("Y")->dims();
+ PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(),
+ "Rank of first input must >= rank of second input.")
+ ctx.Output("Out")->Resize(x_dim);
+ }
+};
+
+class ElementWiseMulOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+ ElementWiseMulOpMaker(framework::OpProto *proto,
+ framework::OpAttrChecker *op_checker)
+ : OpProtoAndCheckerMaker(proto, op_checker) {
+ AddInput("X", "The first input of elementwise mul op");
+ AddInput("Y", "The second input of elementwise mul op");
+ AddAttr("axis",
+ R"DOC(
+When shape(Y) does not equal shape(X),Y will be broadcasted
+to match the shape of X and axis should be dimension index Y in X
+ )DOC")
+ .SetDefault(-1)
+ .EqualGreaterThan(-1);
+
+ AddOutput("Out", "The output of elementwise mul op");
+ AddComment(R"DOC(
+Limited elementwise multiplication operator. The equation is: Out = X ⊙ Y.
+1. The shape of Y should be the same as X, or
+2. Y's shape is a subset of X's.
+ Y will be broadcast to match the shape of X, and axis is the starting dimension index of Y in X.
+ example:
+ shape(X) = (2, 3, 4, 5), shape(Y) = (,)
+ shape(X) = (2, 3, 4, 5), shape(Y) = (5,)
+ shape(X) = (2, 3, 4, 5), shape(Y) = (4, 5)
+ shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
+ shape(X) = (2, 3, 4, 5), shape(Y) = (2), with axis=0
+)DOC");
+ }
+};
+
+class ElementWiseMulOpGrad : public framework::OperatorWithKernel {
+ public:
+ using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+ void InferShape(const framework::InferShapeContext &ctx) const override {
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Y"), "Input(Y) should not be null");
+ PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+ "Input(Out@GRAD) should not be null");
+
+ auto x_dims = ctx.Input("X")->dims();
+ auto y_dims = ctx.Input("Y")->dims();
+ auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
+ auto *x_grad = ctx.Output(framework::GradVarName("X"));
+ auto *y_grad = ctx.Output(framework::GradVarName("Y"));
+
+ PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
+ "Rank of first input must >= rank of second input.")
+
+ if (x_grad) {
+ x_grad->Resize(x_dims);
+ }
+
+ if (y_grad) {
+ y_grad->Resize(y_dims);
+ }
+ }
+};
+} // namespace operators
+} // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(elementwise_mul, ops::ElementWiseMulOp, ops::ElementWiseMulOpMaker,
+ elementwise_mul_grad, ops::ElementWiseMulOpGrad);
+REGISTER_OP_CPU_KERNEL(
+ elementwise_mul,
+ ops::ElementWiseMulKernel);
+REGISTER_OP_CPU_KERNEL(
+ elementwise_mul_grad,
+ ops::ElementWiseMulGradKernel);
diff --git a/paddle/operators/elementwise_mul_op.cu b/paddle/operators/elementwise_mul_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..56f2087c22c6c599a3c5aef36eb0fe3eac295bef
--- /dev/null
+++ b/paddle/operators/elementwise_mul_op.cu
@@ -0,0 +1,25 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/elementwise_mul_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_GPU_KERNEL(
+ elementwise_mul,
+ ops::ElementWiseMulKernel);
+REGISTER_OP_GPU_KERNEL(
+ elementwise_mul_grad,
+ ops::ElementWiseMulGradKernel);
diff --git a/paddle/operators/elementwise_mul_op.h b/paddle/operators/elementwise_mul_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..e9ed6791799240039f9af42c1a4339be7126ee65
--- /dev/null
+++ b/paddle/operators/elementwise_mul_op.h
@@ -0,0 +1,185 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License. */
+
+#pragma once
+#include
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+#include "paddle/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+/*
+ * Out = X ⊙ Y
+ * 1. shape(X) = (2, 3, 4, 5), shape(Y) = (3, 4), with axis=1
+ * pre=2, n=3*4, post=5
+ * 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
+ * pre=2*3, n=4*5, post=1
+ */
+
+inline void get_mid_dims(const framework::DDim& x_dims,
+ const framework::DDim& y_dims, const int axis,
+ int& pre, int& n, int& post) {
+ pre = 1;
+ n = 1;
+ post = 1;
+ for (int i = 0; i < axis; ++i) {
+ pre *= x_dims[i];
+ }
+
+ for (int i = 0; i < y_dims.size(); ++i) {
+ PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
+ "Broadcast dimension mismatch.");
+ n *= y_dims[i];
+ }
+
+ for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+ post *= x_dims[i];
+ }
+}
+
+template <typename Place, typename T>
+class ElementWiseMulKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ using Tensor = framework::Tensor;
+
+ auto* x = ctx.Input("X");
+ auto* y = ctx.Input("Y");
+ auto* z = ctx.Output("Out");
+ z->mutable_data(ctx.GetPlace());
+
+ auto x_e = framework::EigenVector::Flatten(*x);
+ auto y_e = framework::EigenVector::Flatten(*y);
+ auto z_e = framework::EigenVector::Flatten(*z);
+
+ auto x_dims = x->dims();
+ auto y_dims = y->dims();
+ PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
+ "Rank of first input must >= rank of second input.")
+
+ if (x_dims == y_dims || product(y_dims) == 1) {
+ z_e.device(ctx.GetEigenDevice()) = x_e * y_e;
+ return;
+ }
+
+ int axis = ctx.Attr("axis");
+ axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+ PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),
+ "Axis should be in range [0, x_dims)");
+
+ int pre, n, post;
+ get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+ if (post == 1) {
+ auto y_bcast = y_e.reshape(Eigen::DSizes(1, n))
+ .broadcast(Eigen::DSizes(pre, 1))
+ .reshape(Eigen::DSizes(x_e.size()));
+ z_e.device(ctx.GetEigenDevice()) = x_e * y_bcast;
+ return;
+ } else {
+ auto y_bcast = y_e.reshape(Eigen::DSizes(1, n, 1))
+ .broadcast(Eigen::DSizes(pre, 1, post))
+ .reshape(Eigen::DSizes(x_e.size()));
+ z_e.device(ctx.GetEigenDevice()) = x_e * y_bcast;
+ return;
+ }
+ }
+};
+
+template <typename Place, typename T>
+class ElementWiseMulGradKernel : public framework::OpKernel {
+ public:
+ void Compute(const framework::ExecutionContext& ctx) const override {
+ using Tensor = framework::Tensor;
+
+ auto* x = ctx.Input("X");
+ auto* y = ctx.Input("Y");
+ auto* dout = ctx.Input(framework::GradVarName("Out"));
+
+ auto x_e = framework::EigenVector::Flatten(*x);
+ auto y_e = framework::EigenVector::Flatten(*y);
+ auto dout_e = framework::EigenVector::Flatten(*dout);
+
+ auto x_dims = x->dims();
+ auto y_dims = y->dims();
+
+ auto* dx = ctx.Output(framework::GradVarName("X"));
+ auto* dy = ctx.Output(framework::GradVarName("Y"));
+ if (dx) {
+ dx->mutable_data(ctx.GetPlace());
+ }
+ if (dy) {
+ dy->mutable_data(ctx.GetPlace());
+ }
+
+ if (x_dims == y_dims || product(y_dims) == 1) {
+ if (dx) {
+ auto dx_e = framework::EigenVector::Flatten(*dx);
+ dx_e.device(ctx.GetEigenDevice()) = dout_e * y_e;
+ }
+
+ if (dy) {
+ auto dy_e = framework::EigenVector::Flatten(*dy);
+ dy_e.device(ctx.GetEigenDevice()) = x_e * dout_e;
+ }
+ return;
+ }
+
+ int axis = ctx.Attr("axis");
+ axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
+
+ int pre, n, post;
+ get_mid_dims(x_dims, y_dims, axis, pre, n, post);
+
+ // TODO(gongweibao): wrap reshape to a function.
+ if (post == 1) {
+ auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n))
+ .broadcast(Eigen::DSizes(pre, 1))
+ .reshape(Eigen::DSizes(x_e.size()));
+ if (dx) {
+ auto dx_e = framework::EigenVector::Flatten(*dx);
+ dx_e.device(ctx.GetEigenDevice()) = dout_e * y_e_bcast;
+ }
+
+ if (dy) {
+ auto dy_e = framework::EigenVector::Flatten(*dy);
+ dy_e.device(ctx.GetEigenDevice()) =
+ (x_e * dout_e)
+ .reshape(Eigen::DSizes(pre, n))
+ .sum(Eigen::array{{0}});
+ }
+ return;
+ } else {
+ auto y_e_bcast = y_e.reshape(Eigen::DSizes(1, n, 1))
+ .broadcast(Eigen::DSizes(pre, 1, post))
+ .reshape(Eigen::DSizes(x_e.size()));
+ if (dx) {
+ auto dx_e = framework::EigenVector::Flatten(*dx);
+ dx_e.device(ctx.GetEigenDevice()) = dout_e * y_e_bcast;
+ }
+
+ if (dy) {
+ auto dy_e = framework::EigenVector::Flatten(*dy);
+ dy_e.device(ctx.GetEigenDevice()) =
+ (x_e * dout_e)
+ .reshape(Eigen::DSizes(pre, n, post))
+ .sum(Eigen::array{{0, 2}});
+ }
+ return;
+ }
+ }
+};
+
+} // namespace operators
+} // namespace paddle
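get_mid_dims collapses X's shape into (pre, n, post) around the span covered by Y so the kernel can broadcast Y via reshape/broadcast/reshape. The self-contained snippet below re-implements that computation against plain vectors and checks the two cases listed in the comment above; it is illustrative only.

#include <cstddef>
#include <cstdio>
#include <vector>

// Same logic as get_mid_dims, written against plain vectors for illustration.
void mid_dims(const std::vector<int>& x, const std::vector<int>& y, int axis,
              int& pre, int& n, int& post) {
  pre = n = post = 1;
  for (int i = 0; i < axis; ++i) pre *= x[i];
  for (std::size_t i = 0; i < y.size(); ++i) n *= y[i];  // assumes x[axis + i] == y[i]
  for (std::size_t i = axis + y.size(); i < x.size(); ++i) post *= x[i];
}

int main() {
  int pre, n, post;
  mid_dims({2, 3, 4, 5}, {3, 4}, /*axis=*/1, pre, n, post);
  std::printf("pre=%d n=%d post=%d\n", pre, n, post);  // pre=2 n=12 post=5
  mid_dims({2, 3, 4, 5}, {4, 5}, /*axis=*/2, pre, n, post);
  std::printf("pre=%d n=%d post=%d\n", pre, n, post);  // pre=6 n=20 post=1
  return 0;
}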
diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc
index 6574880c0eb6324b2dd175e39a364d2ef46e735e..3d76516405960c502a46997108049b2db5cab6bf 100644
--- a/paddle/operators/gaussian_random_op.cc
+++ b/paddle/operators/gaussian_random_op.cc
@@ -31,7 +31,7 @@ class CPUGaussianRandomKernel : public framework::OpKernel {
}
engine.seed(seed);
std::normal_distribution dist(mean, std);
- int64_t size = framework::product(tensor->dims());
+ int64_t size = tensor->numel();
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine);
}
diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu
index d9dbc1dcfe6a6676938d64be93c879ea69148018..2d63b3049988cfc3135a87a57dad56b970df3eab 100644
--- a/paddle/operators/gaussian_random_op.cu
+++ b/paddle/operators/gaussian_random_op.cu
@@ -50,8 +50,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
T mean = static_cast(context.Attr("mean"));
T std = static_cast(context.Attr("std"));
thrust::counting_iterator index_sequence_begin(0);
- ssize_t N = framework::product(tensor->dims());
- thrust::transform(index_sequence_begin, index_sequence_begin + N,
+ int64_t size = tensor->numel();
+ thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr(data),
GaussianGenerator(mean, std, seed));
}
diff --git a/paddle/operators/lookup_table_op.cu b/paddle/operators/lookup_table_op.cu
index 27eee3436af8107cef2aa3577ea238be49edf1af..708344046760691aa2da562eb1ee3d8b130c5f18 100644
--- a/paddle/operators/lookup_table_op.cu
+++ b/paddle/operators/lookup_table_op.cu
@@ -70,7 +70,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
size_t N = table_t->dims()[0];
size_t D = table_t->dims()[1];
- size_t K = product(ids_t->dims());
+ size_t K = ids_t->numel();
auto ids = ids_t->data();
auto table = table_t->data();
auto output = output_t->mutable_data(context.GetPlace());
@@ -91,7 +91,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel {
int N = d_table_t->dims()[0];
int D = d_table_t->dims()[1];
- int K = product(ids_t->dims());
+ int K = ids_t->numel();
const int32_t* ids = ids_t->data();
const T* d_output = d_output_t->data();
T* d_table = d_table_t->mutable_data(context.GetPlace());
diff --git a/paddle/operators/lookup_table_op.h b/paddle/operators/lookup_table_op.h
index 877b36cef4ea9cdaaaf37c97d5e5bfce55b91436..a1298906dd4b4209644fe06584f70169519de01c 100644
--- a/paddle/operators/lookup_table_op.h
+++ b/paddle/operators/lookup_table_op.h
@@ -35,7 +35,7 @@ class LookupTableKernel : public framework::OpKernel {
auto ids = ids_t->data();
auto table = table_t->data();
auto output = output_t->mutable_data(context.GetPlace());
- for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
+ for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
memcpy(output + i * D, table + ids[i] * D, D * sizeof(T));
@@ -61,7 +61,7 @@ class LookupTableGradKernel : public framework::OpKernel {
t.device(context.GetEigenDevice()) =
t.constant(static_cast(0));
- for (ssize_t i = 0; i < product(ids_t->dims()); ++i) {
+ for (int64_t i = 0; i < ids_t->numel(); ++i) {
PADDLE_ENFORCE_LT(ids[i], N);
PADDLE_ENFORCE_GE(ids[i], 0);
for (int j = 0; j < D; ++j) {
diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc
index 186a33edcec88bd5e51091a524a778eeb27ad526..4f380388b108dc173d847f027ba5c9db387a87f8 100644
--- a/paddle/operators/math/im2col_test.cc
+++ b/paddle/operators/math/im2col_test.cc
@@ -119,4 +119,4 @@ TEST(math, im2col) {
#ifndef PADDLE_ONLY_CPU
testIm2col();
#endif
-}
\ No newline at end of file
+}
diff --git a/paddle/operators/mean_op.h b/paddle/operators/mean_op.h
index 9848af280b62729bef9243052ceae0b7d8f4c6f5..ce31e178d8e375dc59be80a6c05133201308da70 100644
--- a/paddle/operators/mean_op.h
+++ b/paddle/operators/mean_op.h
@@ -49,12 +49,11 @@ class MeanGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input(framework::GradVarName("Out"));
- PADDLE_ENFORCE(framework::product(OG->dims()) == 1,
- "Mean Gradient should be scalar");
+ PADDLE_ENFORCE(OG->numel() == 1, "Mean Gradient should be scalar");
auto IG = context.Output(framework::GradVarName("X"));
IG->mutable_data(context.GetPlace());
- T ig_size = (T)framework::product(IG->dims());
+ T ig_size = static_cast(IG->numel());
Eigen::DSizes bcast(ig_size);
EigenVector::Flatten(*IG).device(context.GetEigenDevice()) =
diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc
index 069fb5e1abc657aa02a50fde352ce88d078c36e1..a4876feb2edf77bd422fa2a7687b0fa7d55dae47 100644
--- a/paddle/operators/minus_op.cc
+++ b/paddle/operators/minus_op.cc
@@ -31,8 +31,7 @@ class MinusOp : public framework::OperatorWithKernel {
auto *right_tensor = ctx.Input("Y");
PADDLE_ENFORCE_EQ(
- framework::product(left_tensor->dims()),
- framework::product(right_tensor->dims()),
+ left_tensor->numel(), right_tensor->numel(),
"Minus operator must take two tensor with same num of elements");
ctx.Output("Out")->Resize(left_tensor->dims());
}
diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc
index 631d406fd45e35e8a1e5b9bd237c7fe733af6b5b..235bfe47b372d5b9d41196f72a05635b0e2c94b2 100644
--- a/paddle/operators/modified_huber_loss_op.cc
+++ b/paddle/operators/modified_huber_loss_op.cc
@@ -31,11 +31,10 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(x->dims(), y->dims(),
"Dimensions of X and Y must be the same.");
- PADDLE_ENFORCE_EQ(framework::arity(x->dims()), 2,
- "Tensor rank of X must be 2.");
- PADDLE_ENFORCE_EQ(x->dims()[1], 1, "Second dimension of X must be 1.");
+ PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Tensor rank of X must be 2.");
+ PADDLE_ENFORCE_EQ(x->dims()[1], 1, "2nd dimension of X must be 1.");
- context.Output("intermediate_val")->Resize(x->dims());
+ context.Output("IntermediateVal")->Resize(x->dims());
context.Output("Out")->Resize({x->dims()[0], 1});
}
};
@@ -47,7 +46,7 @@ class ModifiedHuberLossOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input value of ModifiedHuberLossOp.");
AddInput("Y", "Target labels of ModifiedHuberLossOp.");
- AddOutput("intermediate_val",
+ AddOutput("IntermediateVal",
"Variable to save intermediate result which will be reused in "
"backward processing.")
.AsIntermediate();
@@ -75,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
void InferShape(const framework::InferShapeContext& context) const override {
auto* x = context.Input("X");
auto* y = context.Input("Y");
- auto* intermediate_val = context.Input("intermediate_val");
+ auto* intermediate_val = context.Input("IntermediateVal");
auto* out_grad = context.Input(framework::GradVarName("Out"));
auto* x_grad = context.Output(framework::GradVarName("X"));
diff --git a/paddle/operators/modified_huber_loss_op.cu b/paddle/operators/modified_huber_loss_op.cu
index f8aa5043dd8e8a95c4b36ff14022b8a0a57c74cf..bce760f95e72cfec05b07591e0fa1250168b112f 100644
--- a/paddle/operators/modified_huber_loss_op.cu
+++ b/paddle/operators/modified_huber_loss_op.cu
@@ -43,7 +43,7 @@ class ModifiedHuberLossGradGPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input("Y");
- auto* in1 = context.Input("intermediate_val");
+ auto* in1 = context.Input("IntermediateVal");
auto* in2 = context.Input(framework::GradVarName("Out"));
auto* out0 = context.Output(framework::GradVarName("X"));
diff --git a/paddle/operators/modified_huber_loss_op.h b/paddle/operators/modified_huber_loss_op.h
index ffb89e806f0e5b546c1c4c386e0e3a2e9cfc7a04..e78be06ebd2a89e9086908c9c07c63c990f578f5 100644
--- a/paddle/operators/modified_huber_loss_op.h
+++ b/paddle/operators/modified_huber_loss_op.h
@@ -52,7 +52,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input("X");
auto* in1 = context.Input("Y");
- auto* out0 = context.Output("intermediate_val");
+ auto* out0 = context.Output("IntermediateVal");
auto* out1 = context.Output("Out");
out0->mutable_data(context.GetPlace());
@@ -77,7 +77,7 @@ class ModifiedHuberLossGradCPUKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input("Y");
- auto* in1 = context.Input("intermediate_val");
+ auto* in1 = context.Input("IntermediateVal");
auto* in2 = context.Input(framework::GradVarName("Out"));
auto* out0 = context.Output(framework::GradVarName("X"));
diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc
index 28a47cdff2e9b7a965ff9f99e787bb8315010823..710a56a0e8e2d17162d7d000df226f1537104eb9 100644
--- a/paddle/operators/mul_op.cc
+++ b/paddle/operators/mul_op.cc
@@ -25,18 +25,27 @@ class MulOp : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
- auto dim0 = ctx.Input("X")->dims();
- auto dim1 = ctx.Input("Y")->dims();
- PADDLE_ENFORCE_EQ(dim0.size(), 2,
- "input X(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("X"));
- PADDLE_ENFORCE_EQ(dim1.size(), 2,
- "input Y(%s) should be a tensor with 2 dims, a matrix",
- ctx.op().Input("Y"));
+ auto x_dims = ctx.Input("X")->dims();
+ auto y_dims = ctx.Input("Y")->dims();
+ int x_num_col_dims = Attr("x_num_col_dims");
+ int y_num_col_dims = Attr("y_num_col_dims");
+
+ PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
+ "The rank of input tensor X(%s) should be larger than "
+ "`mul_op`'s `x_num_col_dims`.",
+ ctx.op().Input("X"));
+ PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
+ "The rank of input tensor Y(%s) should be larger than "
+ "`mul_op`'s `y_num_col_dims`.",
+ ctx.op().Input("Y"));
+
+ auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
+ auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
+
PADDLE_ENFORCE_EQ(
- dim0[1], dim1[0],
+ x_mat_dims[1], y_mat_dims[0],
"First matrix's width must be equal with second matrix's height.");
- ctx.Output("Out")->Resize({dim0[0], dim1[1]});
+ ctx.Output("Out")->Resize({x_mat_dims[0], y_mat_dims[1]});
}
};
@@ -47,6 +56,23 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X", "The first input of mul op");
AddInput("Y", "The second input of mul op");
AddOutput("Out", "The output of mul op");
+ AddAttr(
+ "x_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `X`,
+ in that case, tensors will be reshaped to a matrix. The matrix's first
+ dimension(column length) will be the product of tensor's last
+ `num_col_dims` dimensions, and the matrix's second dimension(row length)
+ will be the product of tensor's first `rank - num_col_dims` dimensions.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
+ AddAttr(
+ "y_num_col_dims",
+ R"DOC(mul_op can take tensors with more than two dimensions as input `Y`,
+ in that case, tensors will be reshaped to a matrix. Just like input `X`.
+ )DOC")
+ .SetDefault(1)
+ .EqualGreaterThan(1);
AddComment(R"DOC(
Two Element Mul Operator.
@@ -70,10 +96,20 @@ class MulOpGrad : public framework::OperatorWithKernel {
auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims();
auto *x_grad = ctx.Output(framework::GradVarName("X"));
auto *y_grad = ctx.Output(framework::GradVarName("Y"));
- PADDLE_ENFORCE(x_dims[0] == out_dims[0],
- "Out@GRAD M X N must equal to X dims 0, M ");
- PADDLE_ENFORCE(y_dims[1] == out_dims[1],
- "Out@GRAD M X N must equal to Y dims 1, N ");
+
+ auto x_mat_dims =
+ framework::flatten_to_2d(x_dims, Attr("x_num_col_dims"));
+ auto y_mat_dims =
+ framework::flatten_to_2d(y_dims, Attr("y_num_col_dims"));
+
+ PADDLE_ENFORCE_EQ(
+ x_mat_dims[0], out_dims[0],
+ "The first dimension of Out@GRAD must equal to the first dimension of "
+ "the first operand.");
+ PADDLE_ENFORCE_EQ(
+ y_mat_dims[1], out_dims[1],
+ "The second dimension of Out@GRAD must equal to the second "
+ "dimension of the second operand.");
if (x_grad) x_grad->Resize(x_dims);
if (y_grad) y_grad->Resize(y_dims);
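With x_num_col_dims, both operands are flattened to 2-D before the usual matrix-shape check. The sketch below mimics that flattening and the x_mat_dims[1] == y_mat_dims[0] enforcement, assuming flatten_to_2d groups the first num_col_dims dimensions into the matrix height and the rest into the width; the helper here is hypothetical, not the Paddle implementation.

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Assumed behaviour of flatten_to_2d: rows = product of the first num_col_dims
// dims, cols = product of the remaining dims (hypothetical stand-in).
std::pair<long, long> flatten2d(const std::vector<long>& dims, int num_col_dims) {
  long rows = 1, cols = 1;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    (static_cast<int>(i) < num_col_dims ? rows : cols) *= dims[i];
  }
  return {rows, cols};
}

int main() {
  // X: (2, 3, 4) with x_num_col_dims = 2 -> (6, 4); Y: (4, 5) with y_num_col_dims = 1 -> (4, 5).
  auto x_mat = flatten2d({2, 3, 4}, 2);
  auto y_mat = flatten2d({4, 5}, 1);
  assert(x_mat.second == y_mat.first);  // mirrors the PADDLE_ENFORCE_EQ in MulOp::InferShape
  std::printf("Out dims: (%ld, %ld)\n", x_mat.first, y_mat.second);  // (6, 5)
  return 0;
}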
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 05a79e13b3470e39a5ebd0394ba05629553a5075..3c01f868bda8cba488b3403df456d63d6b082fa6 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -1,7 +1,7 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
+ You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
@@ -31,13 +31,25 @@ template
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
- auto* x = context.Input("X");
- auto* y = context.Input("Y");
- auto* z = context.Output("Out");
+ const Tensor* x = context.Input("X");
+ const Tensor* y = context.Input("Y");
+ Tensor* z = context.Output("Out");
+ const Tensor x_matrix =
+ x->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *x, context.template Attr("x_num_col_dims"))
+ : *x;
+ const Tensor y_matrix =
+ y->dims().size() > 2
+ ? framework::ReshapeToMatrix(
+ *y, context.template Attr("y_num_col_dims"))
+ : *y;
+
z->mutable_data(context.GetPlace());
auto* device_context =
const_cast(context.device_context_);
- math::matmul(*x, false, *y, false, 1, z, 0, device_context);
+ math::matmul(x_matrix, false, y_matrix, false, 1, z, 0,
+ device_context);
}
};
@@ -45,23 +57,39 @@ template
class MulGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
- auto* x = ctx.Input("X");
- auto* y = ctx.Input("Y");
- auto* dout = ctx.Input(framework::GradVarName("Out"));
+ int x_num_col_dims = ctx.template Attr