“3bea13c48579c9b28470983b07dc90316528e6bb”上不存在“examples/C++/PaddleRec/criteo_ctr/local_train.py”
提交 62b56e76 编写于 作者: Y Yang Yang(Tony) 提交者: GitHub

Merge branch 'develop' into update-doc-pybind

# Design Doc: Refactorization Overview # Design Doc: Refactorization Overview
The goal of refactorizaiton include: The goals of refactoring include:
1. Make it easy for external contributors to write new elementory computaiton operations. 1. Making it easy for external contributors to write new elementary computation operations.
1. Make the codebase clean and readable. 1. Making the codebase clean and readable.
1. Introduce a new design of computation representation -- a computation graph of operators and variables. 1. Designing a new computation representation -- a computation graph of operators and variables.
1. The graph representation helps implementing auto-scalable and auto fault recoverable distributed computing. 1. Implementing auto-scalability and auto fault recoverable distributed computing with the help of computation graphs.
## Computation Graphs ## Computation Graphs
1. PaddlePaddle represent the computation, training and inference of DL models, by computation graphs. 1. PaddlePaddle represents the computation, training and inference of Deep Learning models, by computation graphs.
1. Please dig into [computation graphs](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/graph.md) for a solid example. 1. Please refer to [computation graphs](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/graph.md) for a concrete example.
1. Users write Python programs to describe the graphs and run it (locally or remotely). 1. Users write Python programs to describe the graphs and run them (locally or remotely).
1. A graph is composed of *variables* and *operators*. 1. A graph is composed of *variables* and *operators*.
1. The description of graphs must be able to be serialized/deserialized, so it 1. The description of graphs must be capable of being serialized/deserialized, so that:
1. could to be sent to the cloud for distributed execution, and 1. It can to be sent to the cloud for distributed execution, and
1. be sent to clients for mobile or enterprise deployment. 1. It can be sent to clients for mobile or enterprise deployment.
1. The Python program do 1. The Python program does the following steps
1. *compilation*: runs a Python program to generate a protobuf message representation of the graph and send it to 1. *compilation*: run a Python program to generate a protobuf message representation of the graph and send it to
1. the C++ library `libpaddle.so` for local execution, 1. the C++ library `libpaddle.so` for local execution,
1. the master process of a distributed training job for training, or 1. the master process of a distributed training job for training, or
1. the server process of a Kubernetes serving job for distributed serving. 1. the server process of a Kubernetes serving job for distributed serving.
1. *execution*: according to the protobuf message, constructs instances of class `Variable` and `OperatorBase`, and run them. 1. *execution*: execute the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message.
## Description and Realization ## Description and Realization of Computation Graph
At compile time, the Python program generates protobuf message representation of the graph, or the description of the graph. At compile time, the Python program generates a protobuf message representation of the graph, or the description of the graph.
At runtime, the C++ program realizes the graph and run it. At runtime, the C++ program realizes the graph and runs it.
| | Representation (protobuf messages) | Realization (C++ class objects) | | | Representation (protobuf messages) | Realization (C++ class objects) |
|---|---|---| |---|---|---|
...@@ -42,30 +42,31 @@ At runtime, the C++ program realizes the graph and run it. ...@@ -42,30 +42,31 @@ At runtime, the C++ program realizes the graph and run it.
|Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)| |Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)|
|Block|BlockDesc|Block| |Block|BlockDesc|Block|
The word *graph* is exchangable with *block* in this document. A graph represent computation steps and local variables as a C++/Java program block, or a pair of { and }. The word *graph* is interchangeable with *block* in this document. A graph represents computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`).
## Compilation and Execution ## Compilation and Execution
1. Run an applicaton Python program to describe the graph. In particular, 1. Run an application Python program to describe the graph. In particular, the Python application program does the following:
1. create VarDesc to represent local/intermediate variables, 1. Create `VarDesc` to represent local/intermediate variables,
1. create operators and set attributes, 1. Create operators and set attributes,
1. validate attribute values, 1. Validate attribute values,
1. inference the type and the shape of variables, 1. Infer the type and the shape of variables,
1. plan for memory-reuse for variables, 1. Plan memory-reuse for variables,
1. generate backward and optimization part of the Graph. 1. Generate the backward graph
1. possiblly split the graph for distributed training. 1. Optimize the computation graph.
1. Potentially, split the graph for distributed training.
1. The invocation of `train` or `infer` in the application Python program: 1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the application Python program does the following:
1. create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block, 1. Create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block,
1. realize local variables defined in the BlockDesc message in the new scope, 1. realize local variables defined in the BlockDesc message in the new scope,
1. a scope is similar to the stack frame in programming languages, 1. a scope is similar to the stack frame in programming languages,
1. create an instance of class `Block`, in which, 1. Create an instance of class `Block`, in which,
1. realize operators in the BlockDesc message, 1. realize operators in the BlockDesc message,
1. run the Block by calling 1. Run the Block by calling
1. `Block::Eval(vector<Variable>* targets)` for forward and backward computations, or 1. `Block::Eval(vector<Variable>* targets)` for forward and backward computations, or
1. `Block::Eval(vector<Operator>* targets)` for optimization. 1. `Block::Eval(vector<Operator>* targets)` for optimization.
...@@ -76,14 +77,14 @@ The word *graph* is exchangable with *block* in this document. A graph represen ...@@ -76,14 +77,14 @@ The word *graph* is exchangable with *block* in this document. A graph represen
Compile Time -> IR -> Runtime Compile Time -> IR -> Runtime
``` ```
### Benefit ### Benefits of IR
- Optimization - Optimization
```text ```text
Compile Time -> IR -> Optimized IR -> Runtime Compile Time -> IR -> Optimized IR -> Runtime
``` ```
- Send automatically partitioned IR to different nodes. - Automatically send partitioned IR to different nodes.
- Automatic data parallel - Automatic Data Parallelism
```text ```text
Compile Time Compile Time
|-> Single GPU IR |-> Single GPU IR
...@@ -92,7 +93,7 @@ Compile Time -> IR -> Runtime ...@@ -92,7 +93,7 @@ Compile Time -> IR -> Runtime
|-> Node-1 (runs trainer-IR-1) |-> Node-1 (runs trainer-IR-1)
|-> Node-2 (runs pserver-IR) |-> Node-2 (runs pserver-IR)
``` ```
- Automatic model parallel (planned for future) - Automatic Model Parallelism (planned for future)
--- ---
...@@ -105,10 +106,10 @@ Compile Time -> IR -> Runtime ...@@ -105,10 +106,10 @@ Compile Time -> IR -> Runtime
# Operator # Operator
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot) ![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
* `Operator` is the fundamental building block as the user interface. * `Operator` is the fundamental building block of the user interface.
* Operator stores input/output variable name, and attributes. * Operator stores input/output variable names, and attributes.
* The `InferShape` interface is used to infer output variable shapes by its input shapes. * The `InferShape` interface is used to infer the shape of the output variable shapes based on the shapes of the input variables.
* Use `Run` to compute `input variables` to `output variables`. * Use `Run` to compute the `output` variables from the `input` variables.
--- ---
...@@ -126,30 +127,29 @@ Compile Time -> IR -> Runtime ...@@ -126,30 +127,29 @@ Compile Time -> IR -> Runtime
# Why separate Kernel and Operator # Why separate Kernel and Operator
* Separate GPU and CPU code. * Separate GPU and CPU code.
* Make Paddle can run without GPU. * Make Paddle capable of running without GPU.
* Make one operator (which is user interface) can contain many implementations. * Make one operator (which is a user interface) and create many implementations.
* Same mul op, different FP16, FP32 Kernel. different MKL, eigen kernel. * For example, same multiplication op can have different implementations kernels such as FP16 kernel, FP32 kernel, MKL, eigen kernel.
--- ---
# Libraries for Kernel development # Libraries for Kernel development
* `Eigen::Tensor` contains basic math and element-wise functions. * `Eigen::Tensor` contains basic math and element-wise functions.
* Note that `Eigen::Tensor` has broadcast implementation. * Note that `Eigen::Tensor` has broadcast implementation.
* Limit number of `tensor.device(dev) = ` in your code. * Limit the number of `tensor.device(dev) = ` in your code.
* `thrust::tranform` and `std::transform`. * `thrust::transform` and `std::transform`.
* `thrust` has the same API as C++ standard library. Using `transform` can quickly implement a customized elementwise kernel. * `thrust` has the same API as C++ standard library. Using `transform`, one can quickly implement customized element-wise kernels.
* `thrust` has more complex API, like `scan`, `reduce`, `reduce_by_key`. * `thrust` also has more complex APIs, like `scan`, `reduce`, `reduce_by_key`.
* Hand-writing `GPUKernel` and `CPU` code * Hand-writing `GPUKernel` and `CPU` code
* Do not write `.h`. CPU Kernel should be in `.cc`. GPU kernel should be in `.cu`. (`GCC` cannot compile GPU code.) * Do not write in header (`.h`) files. CPU Kernel should be in cpp source (`.cc`) and GPU kernels should be in cuda (`.cu`) files. (GCC cannot compile GPU code.)
--- ---
# Operator Register # Operator Registration
## Why register is necessary? ## Why is registration necessary?
We need a method to build mappings between Op type names and Op classes. We need a method to build mappings between Op type names and Op classes.
## How to do the register? ## How is registration implemented?
Maintaining a map, whose key is the type name and the value is the corresponding Op constructor.
Maintain a map, whose key is the type name and value is corresponding Op constructor.
--- ---
# The Registry Map # The Registry Map
...@@ -169,7 +169,7 @@ Maintain a map, whose key is the type name and value is corresponding Op constru ...@@ -169,7 +169,7 @@ Maintain a map, whose key is the type name and value is corresponding Op constru
# Related Concepts # Related Concepts
### Op_Maker ### Op_Maker
It's constructor takes `proto` and `checker`. They are compeleted during Op_Maker's construction. ([ScaleOpMaker](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37)) It's constructor takes `proto` and `checker`. They are completed during Op_Maker's construction. ([ScaleOpMaker](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37))
### Register Macros ### Register Macros
```cpp ```cpp
...@@ -177,32 +177,30 @@ REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class) ...@@ -177,32 +177,30 @@ REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class)
REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class)
``` ```
--- ---
# Register Process # Registration Process
1. Write Op class, as well as its gradient Op class if there is. 1. Write an Op class and its gradient Op class, if required.
2. Write Op maker class. In the constructor, describe its inputs, outputs, and attributes. 2. Write an Op maker class. In the constructor of this class, describe the inputs, outputs and attributes of the operator.
3. Invoke macro `REGISTER_OP`. The macro will 3. Invoke the macro `REGISTER_OP`. This macro will
1. call maker class to complete `proto` and `checker` 1. Call maker class to complete the `proto` and the `checker`
2. with the completed `proto` and `checker`, build a new key-value pair in the `OpInfoMap` 2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap`
--- ---
# Backward Module (1/2) # Backward Module (1/2)
### Create Backward Operator ### Create Backward Operator
- Mapping from forwarding Op to backward Op - Mapping from forward Op to backward Op
![backward](https://gist.githubusercontent.com/dzhwinter/a6fbd4623ee76c459f7f94591fd1abf0/raw/61026ab6e518e66bde66a889bc42557a1fccff33/backward.png) ![backward](https://gist.githubusercontent.com/dzhwinter/a6fbd4623ee76c459f7f94591fd1abf0/raw/61026ab6e518e66bde66a889bc42557a1fccff33/backward.png)
--- ---
# Backward Module (2/2) # Backward Module (2/2)
### Build Backward Network ### Build Backward Network
- **Input** graph of forwarding operators - **Input**: graph of forward operators
- **Output** graph of backward operators - **Output**: graph of backward operators
- **corner case in construction** - **Corner cases in construction**
- shared variable => insert `Add` operator - Shared Variables => insert an `Add` operator to combine gradients
- no gradient => insert `fill_zero_grad` operator - No Gradient => insert a `fill_zero_grad` operator
- recursive netOp => call `Backward` recursively - Recursive NetOp => call `Backward` recursively
- RNN Op => recursively call `Backward` on stepnet - RNN Op => recursively call `Backward` on stepnet
...@@ -211,41 +209,41 @@ REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) ...@@ -211,41 +209,41 @@ REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class)
* `Tensor` is an n-dimension array with type. * `Tensor` is an n-dimension array with type.
* Only dims and data pointers are stored in `Tensor`. * Only dims and data pointers are stored in `Tensor`.
* All operators on `Tensor` is written in `Operator` or global functions. * All operations on `Tensor` are written in `Operator` or global functions.
* variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md) * Variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
* `Variable` is the inputs and outputs of an operator. Not just `Tensor`. * `Variable` instances are the inputs and the outputs of an operator. Not just `Tensor`.
* step_scopes in RNN is a variable and not a tensor. * `step_scopes` in RNN is a variable and not a tensor.
* `Scope` is where variables store at. * `Scope` is where variables are stores.
* map<string/*var name */, Variable> * map<string `variable_name`, Variable>
* `Scope` has a hierarchical structure. The local scope can get variable from its parent scope. * `Scope` has a hierarchical structure. The local scope can get variables from its parent scope.
--- ---
# Block (in design) # Block (in design)
## the difference with original RNNOp ## the difference between original RNNOp and Block
- as an operator is more intuitive than `RNNOp`, - As an operator is more intuitive than `RNNOp`,
- offers new interface `Eval(targets)` to deduce the minimal block to `Run`, - Offers a new interface `Eval(targets)` to deduce the minimal block to `Run`,
- fits the compile-time/ runtime separation design. - Fits the compile-time/ runtime separation design paradigm.
- during the compilation, `SymbolTable` stores `VarDesc`s and `OpDesc`s and serialize to a `BlockDesc` - During the compilation, `SymbolTable` stores `VarDesc`s and `OpDesc`s and serialize to a `BlockDesc`
- when graph executes, a Block with `BlockDesc` passed in creates `Op` and `Var` then `Run` - When graph executes, a Block with `BlockDesc` is passed. It then creates `Op` and `Var` instances and then invokes `Run`.
--- ---
# Milestone # Milestone
- take Paddle/books as the main line, the requirement of the models motivates framework refactoring, - Take Paddle/books as the main line, the requirement of the models motivates framework refactoring,
- model migration - Model migration
- framework development gives **priority support** to model migration, for example, - Framework development gives **priority support** to model migration, for example,
- the MNIST demo needs a Python interface, - the MNIST demo needs a Python interface,
- the RNN models require the framework to support `LoDTensor`. - the RNN models require the framework to support `LoDTensor`.
- determine some timelines, - Determine some timelines,
- heavily-relied Ops need to be migrated first, - Frequently used Ops need to be migrated first,
- different models can be migrated parallelly. - Different models can be migrated in parallel.
- improve the framework at the same time - Improve the framework at the same time
- accept imperfection, concentrated on solving the specific problem at the right price. - Accept imperfection, concentrate on solving the specific problem at the right price.
--- ---
# Control the migration quality # Control the migration quality
- compare the performance of migrated models with old ones. - Compare the performance of migrated models with old ones.
- follow google C style - Follow the google C++ style
- build the automatic workflow of generating Python/C++ documentations - Build the automatic workflow of generating Python/C++ documentations.
- the documentation of layers and ops should be written inside the code - The documentation of layers and ops should be written inside the code.
- take the documentation quality into account when doing PR - Take the documentation quality into account when submitting pull requests.
- preview the documentations, read and improve them from users' perspective - Preview the documentations, read and improve them from a user's perspective.
...@@ -285,41 +285,27 @@ class TestMulGradOp(GradientChecker): ...@@ -285,41 +285,27 @@ class TestMulGradOp(GradientChecker):
'Y': np.random.random((84, 100)).astype("float32") 'Y': np.random.random((84, 100)).astype("float32")
} }
def test_cpu_gpu_compare(self): def test_check_grad_normal(self):
self.compare_grad(self.op, self.inputs)
def test_normal(self):
# mul op will enlarge the relative error # mul op will enlarge the relative error
self.check_grad( self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)
self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
def test_ignore_x(self): def test_check_grad_ingore_x(self):
self.check_grad( self.check_grad(
self.op, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
self.inputs, ["Y"],
"Out",
max_relative_error=0.5,
no_grad_set={"X"})
def test_ignore_y(self): def test_check_grad_ingore_y(self):
self.check_grad( self.check_grad(
self.op, ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
self.inputs, ["X"],
"Out",
max_relative_error=0.5,
no_grad_set={"Y"})
``` ```
下面解释代码中一些关键的地方: 下面解释代码中一些关键的地方:
- 调用`create_op("mul")`创建反向Op对应的前向Op。 - 调用`create_op("mul")`创建反向Op对应的前向Op。
- 调用`compare_grad`函数对比CPU、GPU计算结果。 - `test_check_grad_normal`中调用`check_grad`使用数值法检测梯度正确性和稳定性。
- `test_normal`中调用`check_grad`使用数值法检测梯度正确性和稳定性。 - 第一个参数`["X", "Y"]` : 指定对输入变量`X``Y`做梯度检测。
- 第一个参数`self.op` : 前向Op。 - 第二个参数`"Out"` : 指定前向网络最终的输出目标变量`Out`
- 第二个参数`self.inputs` : 输入词典,词典的Key和`ProtoMaker`定义保持一致。 - 第三个参数`max_relative_error`:指定检测梯度时能容忍的最大错误值。
- 第三个参数`["X", "Y"]` : 指定对输入变量`X``Y`做梯度检测。 - `test_check_grad_ingore_x``test_check_grad_ingore_y`分支用来测试只需要计算一个输入梯度的情况。
- 第四个参数`"Out"` : 指定前向网络最终的输出目标变量`Out`
- `test_ignore_x``test_ignore_y`分支用来测试只需要计算一个输入梯度的情况。
### 编译和执行单元测试 ### 编译和执行单元测试
......
...@@ -293,41 +293,27 @@ class TestMulGradOp(GradientChecker): ...@@ -293,41 +293,27 @@ class TestMulGradOp(GradientChecker):
'Y': np.random.random((84, 100)).astype("float32") 'Y': np.random.random((84, 100)).astype("float32")
} }
def test_cpu_gpu_compare(self): def test_check_grad_normal(self):
self.compare_grad(self.op, self.inputs)
def test_normal(self):
# mul op will enlarge the relative error # mul op will enlarge the relative error
self.check_grad( self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.5)
self.op, self.inputs, ["X", "Y"], "Out", max_relative_error=0.5)
def test_ignore_x(self): def test_check_grad_ingore_x(self):
self.check_grad( self.check_grad(
self.op, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
self.inputs, ["Y"],
"Out",
max_relative_error=0.5,
no_grad_set={"X"})
def test_ignore_y(self): def test_check_grad_ingore_y(self):
self.check_grad( self.check_grad(
self.op, ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
self.inputs, ["X"],
"Out",
max_relative_error=0.5,
no_grad_set={"Y"})
``` ```
Some key points in the code above include: Some key points in the code above include:
- `create_op("mul")` creates the backward operator's corresponding forward operator. - `create_op("mul")` creates the backward operator's corresponding forward operator.
- `compare_grad` compares results between utilizing the CPU and the GPU.
- `test_normal` calls `check_grad` to validate scaling tests' correctness and stability through numeric methods. - `test_normal` calls `check_grad` to validate scaling tests' correctness and stability through numeric methods.
- The first variable `self.op` denotes the forward operator. - The first variable `["X", "Y"]` appoints `X` and `Y` to be scale tested.
- The second variable `self.inputs` denotes the input dictionary, which has its key value identical to its `ProtoMaker` definitions. - The second variable `"Out"` points to the network's final output target `Out`.
- The third variable `["X", "Y"]` appoints `X` and `Y` to be scale tested. - The third variable `max_relative_error` points to the maximum relative tolerance error during scaling tests.
- The fourth variable `"Out"` points to the network's final output target `Out`. - `test_check_grad_ingore_x` and `test_check_grad_ingore_y`branches test the cases where there is only one scaling input.
- `test_ignore_x` and `test_ignore_y`branches test the cases where there is only one scaling input.
### Compiling and Running ### Compiling and Running
......
...@@ -26,7 +26,7 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) ...@@ -26,7 +26,7 @@ cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info) cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry) cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op) cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
......
...@@ -12,12 +12,25 @@ ...@@ -12,12 +12,25 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #pragma once
#include "paddle/operators/rowwise_add_op.h" #include <typeindex>
#include "paddle/framework/framework.pb.h"
namespace ops = paddle::operators; namespace paddle {
REGISTER_OP_GPU_KERNEL( namespace framework {
rowwise_add, ops::RowwiseAddKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL( inline DataType ToDataType(std::type_index type) {
rowwise_add_grad, if (typeid(float).hash_code() == type.hash_code()) {
ops::RowwiseAddGradKernel<paddle::platform::GPUPlace, float>); return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32;
} else {
PADDLE_THROW("Not supported");
return static_cast<DataType>(-1);
}
}
} // namespace framework
} // namespace paddle
...@@ -54,5 +54,44 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -54,5 +54,44 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs()); return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
} }
static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
bool is_grad, OpDescBind* dst_op,
const OpArgType& dst_type) {
PADDLE_ENFORCE(dst_op != nullptr,
"Protobuf desc of gradient op must be initialized first.");
const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
std::vector<std::string> vars = src_type == OpArgType::IN
? src_op->Input(src_name)
: src_op->Output(src_name);
if (is_grad) {
for (std::string& var : vars) {
var = GradVarName(var);
}
}
std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars)
: dst_op->SetOutput(dst_name, vars);
}
}
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
auto& info = OpInfoMap::Instance().Get(forw_op->Type());
PADDLE_ENFORCE(info.HasGradientOp());
grad_op->SetType(info.grad_op_type_);
TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);
grad_op->SetAttrMap(forw_op->GetAttrMap());
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
namespace paddle { namespace paddle {
...@@ -21,5 +22,7 @@ namespace framework { ...@@ -21,5 +22,7 @@ namespace framework {
OperatorBase* BuildGradOp(const OperatorBase* op); OperatorBase* BuildGradOp(const OperatorBase* op);
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -120,3 +120,82 @@ TEST(GradOpBuilder, IOIgnoredInGradient) { ...@@ -120,3 +120,82 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
std::vector<std::string>( std::vector<std::string>(
{f::GradVarName("in3_1"), f::GradVarName("in3_2")})); {f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
} }
TEST(GradOpDescBuilder, MutiInOut) {
f::OpDescBind *forw_op = new f::OpDescBind();
forw_op->SetType("mult_io");
forw_op->SetInput("In1", {"in1"});
forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"});
forw_op->SetInput("In3", {"in3"});
forw_op->SetOutput("Out1", {"out1"});
forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"});
f::OpDescBind *grad_op = new f::OpDescBind();
f::CompleteGradOpDesc(forw_op, grad_op);
EXPECT_EQ(grad_op->Type(), "mult_io_grad");
ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
EXPECT_EQ(grad_op->Input("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
EXPECT_EQ(grad_op->Input("In3"), std::vector<std::string>({"in3"}));
EXPECT_EQ(grad_op->Input("Out1"), std::vector<std::string>({"out1"}));
EXPECT_EQ(grad_op->Input("Out2_mult"),
std::vector<std::string>({"out2_1", "out2_2"}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")),
std::vector<std::string>({f::GradVarName("out1")}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")),
std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
std::vector<std::string>({f::GradVarName("in1")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"),
f::GradVarName("in2_2"),
f::GradVarName("in2_3")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In3")),
std::vector<std::string>({f::GradVarName("in3")}));
delete forw_op;
delete grad_op;
}
TEST(GradOpDescBuilder, IOIgnoredInGradient) {
f::OpDescBind *forw_op = new f::OpDescBind();
forw_op->SetType("io_ignored");
forw_op->SetInput("In1", {"in1"});
forw_op->SetInput("In2_mult", {"in2_1", "in2_2"});
forw_op->SetInput("In3_mult", {"in3_1", "in3_2"});
forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"});
forw_op->SetOutput("Out2", {"out2"});
f::OpDescBind *grad_op = new f::OpDescBind();
f::CompleteGradOpDesc(forw_op, grad_op);
EXPECT_EQ(grad_op->Type(), "io_ignored_grad");
// 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
EXPECT_EQ(grad_op->Input("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_op->Input("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")),
std::vector<std::string>(
{f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")),
std::vector<std::string>({f::GradVarName("out2")}));
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
std::vector<std::string>({f::GradVarName("in1")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
std::vector<std::string>(
{f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")),
std::vector<std::string>(
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
delete forw_op;
delete grad_op;
}
\ No newline at end of file
...@@ -89,6 +89,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { ...@@ -89,6 +89,12 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
need_update_ = true; need_update_ = true;
} }
void OpDescBind::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map;
need_update_ = true;
}
Attribute OpDescBind::GetAttr(const std::string &name) const { Attribute OpDescBind::GetAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
...@@ -101,6 +107,11 @@ int OpDescBind::GetBlockAttr(const std::string &name) const { ...@@ -101,6 +107,11 @@ int OpDescBind::GetBlockAttr(const std::string &name) const {
return boost::get<BlockDesc *>(it->second)->idx(); return boost::get<BlockDesc *>(it->second)->idx();
} }
const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
const {
return attrs_;
}
void OpDescBind::Sync() { void OpDescBind::Sync() {
if (need_update_) { if (need_update_) {
this->op_desc_.mutable_inputs()->Clear(); this->op_desc_.mutable_inputs()->Clear();
......
...@@ -60,10 +60,16 @@ class OpDescBind { ...@@ -60,10 +60,16 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block); void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const; int GetBlockAttr(const std::string &name) const;
// Only be used in C++
const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
private: private:
struct SetAttrDescVisitor : public boost::static_visitor<void> { struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
......
...@@ -100,13 +100,39 @@ class OpRegistrar : public Registrar { ...@@ -100,13 +100,39 @@ class OpRegistrar : public Registrar {
} }
}; };
template <typename PlaceType, typename KernelType> template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
func(op_type);
}
};
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type) const {}
};
// User can register many kernel in one place. The data type could be different.
template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar { class OpKernelRegistrar : public Registrar {
public: public:
explicit OpKernelRegistrar(const char* op_type) { explicit OpKernelRegistrar(const char* op_type) {
OperatorWithKernel::OpKernelKey key; OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
key.place_ = PlaceType(); func(op_type);
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
} }
}; };
......
...@@ -22,14 +22,14 @@ namespace framework { ...@@ -22,14 +22,14 @@ namespace framework {
template <> template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice< Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>(); return *device_context_.GetEigenDevice<platform::CPUPlace>();
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <> template <>
Eigen::GpuDevice& Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>(); return *device_context_.GetEigenDevice<platform::GPUPlace>();
} }
#endif #endif
......
...@@ -22,6 +22,7 @@ limitations under the License. */ ...@@ -22,6 +22,7 @@ limitations under the License. */
#include "op_info.h" #include "op_info.h"
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
...@@ -295,21 +296,6 @@ template <> ...@@ -295,21 +296,6 @@ template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>( std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
const std::string& name) const; const std::string& name) const;
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public InferShapeContext { class ExecutionContext : public InferShapeContext {
public: public:
ExecutionContext(const OperatorBase& op, const Scope& scope, ExecutionContext(const OperatorBase& op, const Scope& scope,
...@@ -317,8 +303,8 @@ class ExecutionContext : public InferShapeContext { ...@@ -317,8 +303,8 @@ class ExecutionContext : public InferShapeContext {
: InferShapeContext(op, scope), device_context_(device_context) {} : InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType, template <typename PlaceType,
typename DeviceType = typename DeviceType = typename platform::EigenDeviceConverter<
typename EigenDeviceConverter<PlaceType>::EigenDeviceType> PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const; DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); } platform::Place GetPlace() const { return device_context_.GetPlace(); }
...@@ -403,7 +389,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase { ...@@ -403,7 +389,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
const Scope& scope_; const Scope& scope_;
}; };
class OpKernel { class OpKernelBase {
public: public:
/** /**
* ExecutionContext is the only parameter of Kernel Run function. * ExecutionContext is the only parameter of Kernel Run function.
...@@ -414,33 +400,47 @@ class OpKernel { ...@@ -414,33 +400,47 @@ class OpKernel {
virtual void Compute(const ExecutionContext& context) const = 0; virtual void Compute(const ExecutionContext& context) const = 0;
virtual ~OpKernel() {} virtual ~OpKernelBase() = default;
};
template <typename T>
class OpKernel : public OpKernelBase {
public:
using ELEMENT_TYPE = T;
}; };
class OperatorWithKernel : public OperatorBase { class OperatorWithKernel : public OperatorBase {
public: public:
struct OpKernelKey { struct OpKernelKey {
platform::Place place_; platform::Place place_;
DataType data_type_;
OpKernelKey() = default; OpKernelKey(DataType data_type, platform::Place place)
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) { : place_(place), data_type_(data_type) {}
place_ = dev_ctx.GetPlace();
} OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelKey& o) const { bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_); return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
} }
}; };
struct OpKernelHash { struct OpKernelHash {
std::hash<bool> hash_; std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const { size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_)); int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
} }
}; };
using OpKernelMap = using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>; std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs, OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs) const VariableNameMap& outputs, const AttributeMap& attrs)
...@@ -451,8 +451,10 @@ class OperatorWithKernel : public OperatorBase { ...@@ -451,8 +451,10 @@ class OperatorWithKernel : public OperatorBase {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx)); ExecutionContext ctx(*this, scope, dev_ctx);
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx)); auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
} }
static std::unordered_map<std::string /* op_type */, OpKernelMap>& static std::unordered_map<std::string /* op_type */, OpKernelMap>&
...@@ -462,13 +464,43 @@ class OperatorWithKernel : public OperatorBase { ...@@ -462,13 +464,43 @@ class OperatorWithKernel : public OperatorBase {
} }
bool SupportGPU() const override { bool SupportGPU() const override {
OperatorWithKernel::OpKernelKey key; auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
key.place_ = platform::GPUPlace(); return std::any_of(op_kernels.begin(), op_kernels.end(),
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0; [](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_);
});
} }
protected: protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0; virtual void InferShape(InferShapeContextBase* ctx) const = 0;
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same.");
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
}; };
} // namespace framework } // namespace framework
......
...@@ -114,10 +114,13 @@ class OpWithKernelTest : public OperatorWithKernel { ...@@ -114,10 +114,13 @@ class OpWithKernelTest : public OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {} void InferShape(framework::InferShapeContextBase* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}
}; };
template <typename T1, typename T2> template <typename T1, typename T2>
class CPUKernelTest : public OpKernel { class CPUKernelTest : public OpKernel<float> {
public: public:
void Compute(const ExecutionContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl; std::cout << "this is cpu kernel" << std::endl;
...@@ -144,7 +147,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker ...@@ -144,7 +147,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
} }
}; };
class CPUKernalMultiInputsTest : public OpKernel { class CPUKernalMultiInputsTest : public OpKernel<float> {
public: public:
void Compute(const ExecutionContext& ctx) const { void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op().Inputs("xs"); auto xs = ctx.op().Inputs("xs");
......
...@@ -29,20 +29,10 @@ limitations under the License. */ ...@@ -29,20 +29,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace pybind {
namespace details {
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
}
} // namespace pybind
namespace framework { namespace framework {
class Tensor { class Tensor {
public: public:
template <bool less, size_t i, typename... args>
friend struct pybind::details::CastToPyBufferImpl;
template <typename T, size_t D, int MajorType, typename IndexType> template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor; friend struct EigenTensor;
...@@ -119,6 +109,8 @@ class Tensor { ...@@ -119,6 +109,8 @@ class Tensor {
return holder_->place(); return holder_->place();
} }
std::type_index type() const { return holder_->type(); }
private: private:
template <typename T> template <typename T>
inline void check_memory_size() const; inline void check_memory_size() const;
......
...@@ -215,13 +215,13 @@ struct testActDesc { ...@@ -215,13 +215,13 @@ struct testActDesc {
static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) { static void getAddtoConfig(TestConfig& cfg, const testActDesc& pm) {
cfg.biasSize = 0; cfg.biasSize = 0;
cfg.layerConfig.set_type("addto"); cfg.layerConfig.set_type("addto");
size_t layerSize = pm.ih * pm.ih * pm.iw; size_t layerSize = pm.ic * pm.ih * pm.iw;
cfg.layerConfig.set_size(layerSize); cfg.layerConfig.set_size(layerSize);
cfg.inputDefs.push_back({INPUT_DATA, "layer_0", layerSize, 0}); cfg.inputDefs.push_back({INPUT_DATA, "layer_0", layerSize, 0});
cfg.layerConfig.add_inputs(); cfg.layerConfig.add_inputs();
} }
void testActivation(std::string& actType, const testActDesc& pm) { void testActivation(std::string actType, const testActDesc& pm) {
// TODO(TJ): remove me when paddle support elu activation // TODO(TJ): remove me when paddle support elu activation
if (actType == "mkldnn_elu") { if (actType == "mkldnn_elu") {
return; return;
...@@ -240,6 +240,7 @@ TEST(MKLDNNActivation, Activations) { ...@@ -240,6 +240,7 @@ TEST(MKLDNNActivation, Activations) {
for (auto type : types) { for (auto type : types) {
/* bs, c, h, w*/ /* bs, c, h, w*/
testActivation(type, {16, 64, 32, 32}); testActivation(type, {16, 64, 32, 32});
testActivation(type, {2, 8, 1, 1});
} }
} }
......
...@@ -99,7 +99,11 @@ public: ...@@ -99,7 +99,11 @@ public:
/** /**
* @brief clear local buffer. It only affect auto-growth buffer. * @brief clear local buffer. It only affect auto-growth buffer.
*/ */
inline void clear() { rowStore_.clear(); } inline void clear() {
// swap an empty vector to it to free the memory.
std::vector<real, AlignedAllocator<real, 32>> empty;
rowStore_.swap(empty);
}
/** /**
* @brief get current number of rows. * @brief get current number of rows.
......
...@@ -101,8 +101,8 @@ set(DEPS_OPS ...@@ -101,8 +101,8 @@ set(DEPS_OPS
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op) DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op) op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy_function) op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS}) foreach(src ${GENERAL_OPS})
......
...@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata, ...@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
} }
template <typename T> template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel { class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>; using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel { class AccuracyKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference"); auto* inference = ctx.Input<Tensor>("Inference");
......
...@@ -132,6 +132,17 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -132,6 +132,17 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
} }
}; };
class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftsignOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softsign operator");
AddOutput("Y", "Output of Softsign operator");
AddComment("Softsign activation operator, softsign(x) = x / (1 + |x|)");
}
};
template <typename AttrType> template <typename AttrType>
class BReluOpMaker : public framework::OpProtoAndCheckerMaker { class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
...@@ -195,111 +206,57 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -195,111 +206,57 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_CPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_CPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_CPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_CPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
reciprocal_grad, ops::ActivationOpGrad); reciprocal_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::LogFunctor>);
REGISTER_OP_CPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUPlace, float, REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
ops::SquareFunctor>); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad, REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>, REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad); soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad, REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad, REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::CPUPlace, float>); #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL(stanh_grad, REGISTER_OP_CPU_KERNEL( \
ops::STanhGradKernel<paddle::platform::CPUPlace, float>); act_type, \
paddle::operators::ActivationKernel<paddle::platform::CPUPlace, \
paddle::operators::functor<float>>); \
REGISTER_OP_CPU_KERNEL(act_type##_grad, \
paddle::operators::ActivationGradKernel< \
paddle::platform::CPUPlace, \
paddle::operators::grad_functor<float>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
...@@ -15,86 +15,14 @@ ...@@ -15,86 +15,14 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h" #include "paddle/operators/activation_op.h"
namespace ops = paddle::operators; #define REGISTER_ACTIVATION_GPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_GPU_KERNEL( \
REGISTER_OP_GPU_KERNEL(sigmoid, act_type, \
ops::ActivationKernel<paddle::platform::GPUPlace, float, paddle::operators::ActivationKernel<paddle::platform::GPUPlace, \
ops::SigmoidFunctor<float>>); paddle::operators::functor<float>>); \
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(act_type##_grad, \
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float, paddle::operators::ActivationGradKernel< \
ops::SigmoidGradFunctor<float>>); paddle::platform::GPUPlace, \
paddle::operators::grad_functor<float>>);
REGISTER_OP_GPU_KERNEL(
exp, FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_GPU_KERNEL);
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_GPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP_GPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_GPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_GPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_GPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP_GPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::LogFunctor>);
REGISTER_OP_GPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_GPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::GPUPlace, float>);
...@@ -19,9 +19,12 @@ ...@@ -19,9 +19,12 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T, typename Functor> template <typename Place, typename Functor>
class ActivationKernel : public framework::OpKernel { class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public: public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y"); auto* Y = context.Output<framework::Tensor>("Y");
...@@ -31,13 +34,20 @@ class ActivationKernel : public framework::OpKernel { ...@@ -31,13 +34,20 @@ class ActivationKernel : public framework::OpKernel {
auto y = framework::EigenVector<T>::Flatten(*Y); auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
Functor functor; Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y); functor(place, x, y);
} }
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename Functor>
class ActivationGradKernel : public framework::OpKernel { class ActivationGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public: public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X"); auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y"); auto* Y = context.Input<framework::Tensor>("Y");
...@@ -51,303 +61,322 @@ class ActivationGradKernel : public framework::OpKernel { ...@@ -51,303 +61,322 @@ class ActivationGradKernel : public framework::OpKernel {
auto dx = framework::EigenVector<T>::Flatten(*dX); auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
Functor functor; Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(place, x, y, dy, dx); functor(place, x, y, dy, dx);
} }
}; };
template <typename T>
struct BaseActivationFunctor {
using ELEMENT_TYPE = T;
using AttrPair = std::vector<std::pair<const char*, float*>>;
AttrPair GetAttrs() { return AttrPair(); }
};
// sigmoid(x) = 1 / (1 + exp(-x)) // sigmoid(x) = 1 / (1 + exp(-x))
template <typename T> template <typename T>
struct SigmoidFunctor { struct SigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp()); y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
} }
}; };
template <typename T> template <typename T>
struct SigmoidGradFunctor { struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y * (static_cast<T>(1) - y); dx.device(d) = dy * y * (static_cast<T>(1) - y);
} }
}; };
// exp(x) = e^x // exp(x) = e^x
struct ExpFunctor { template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.exp(); y.device(d) = x.exp();
} }
}; };
struct ExpGradFunctor { template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y; dx.device(d) = dy * y;
} }
}; };
// relu(x) = max(x, 0) // relu(x) = max(x, 0)
template <typename T> template <typename T>
struct ReluFunctor { struct ReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)); y.device(d) = x.cwiseMax(static_cast<T>(0));
} }
}; };
template <typename T> template <typename T>
struct ReluGradFunctor { struct ReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>(); dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
} }
}; };
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
struct TanhFunctor { template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.tanh(); y.device(d) = x.tanh();
} }
}; };
template <typename T> template <typename T>
struct TanhGradFunctor { struct TanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) - y * y); dx.device(d) = dy * (static_cast<T>(1) - y * y);
} }
}; };
// sqrt(x) = x^(1/2) // sqrt(x) = x^(1/2)
struct SqrtFunctor { template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.sqrt(); y.device(d) = x.sqrt();
} }
}; };
template <typename T> template <typename T>
struct SqrtGradFunctor { struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
const Y y_conj = Eigen::numext::conj(y); const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj; dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
} }
}; };
// abs(x) = |x| // abs(x) = |x|
struct AbsFunctor { template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.abs(); y.device(d) = x.abs();
} }
}; };
struct AbsGradFunctor { template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * x.sign(); dx.device(d) = dy * x.sign();
} }
}; };
// reciprocal(x) = 1 / x // reciprocal(x) = 1 / x
template <typename T> template <typename T>
struct ReciprocalFunctor { struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / x; y.device(d) = static_cast<T>(1) / x;
} }
}; };
template <typename T> template <typename T>
struct ReciprocalGradFunctor { struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(-1) * y * y; dx.device(d) = dy * static_cast<T>(-1) * y * y;
} }
}; };
// log(x) = natural logarithm of x // log(x) = natural logarithm of x
struct LogFunctor { template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.log(); y.device(d) = x.log();
} }
}; };
template <typename T> template <typename T>
struct LogGradFunctor { struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) / x); dx.device(d) = dy * (static_cast<T>(1) / x);
} }
}; };
// square(x) = x^2 // square(x) = x^2
struct SquareFunctor { template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y> template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) { void operator()(Device d, X x, Y y) const {
y.device(d) = x.square(); y.device(d) = x.square();
} }
}; };
template <typename T> template <typename T>
struct SquareGradFunctor { struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX> template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) { void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(2) * x; dx.device(d) = dy * static_cast<T>(2) * x;
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class BReluKernel : public framework::OpKernel { struct BReluFunctor : public BaseActivationFunctor<T> {
public: float t_min;
void Compute(const framework::ExecutionContext& context) const override { float t_max;
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y"); // NOTE: Explicit hides the `BaseActivationFunctor<T>::GetAttrs`
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min")); // not polymorphism for speed.
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max")); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
Y->mutable_data<T>(context.GetPlace()); return {{"t_min", &t_min}, {"t_max", &t_max}};
}
auto x = framework::EigenVector<T>::Flatten(*X); template <typename Device, typename X, typename Y>
auto y = framework::EigenVector<T>::Flatten(*Y); void operator()(Device d, X x, Y y) const {
auto place = context.GetEigenDevice<Place>(); y.device(d) = x.cwiseMax(t_min).cwiseMin(t_max);
y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max);
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class BReluGradKernel : public framework::OpKernel { struct BReluGradFunctor : public BaseActivationFunctor<T> {
public: float t_min;
void Compute(const framework::ExecutionContext& context) const override { float t_max;
auto* X = context.Input<framework::Tensor>("X"); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); return {{"t_min", &t_min}, {"t_max", &t_max}};
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); }
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min")); template <typename Device, typename X, typename Y, typename dY, typename dX>
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max")); void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dX->mutable_data<T>(context.GetPlace()); dx.device(d) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
}
};
auto dy = framework::EigenVector<T>::Flatten(*dY); // softsign(x) = x / (1 + |x|)
auto x = framework::EigenVector<T>::Flatten(*X); template <typename T>
auto dx = framework::EigenVector<T>::Flatten(*dX); struct SoftsignFunctor : public BaseActivationFunctor<T> {
auto place = context.GetEigenDevice<Place>(); template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast<T>(); // d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) =
dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class SoftReluKernel : public framework::OpKernel { struct SoftReluFunctor : public BaseActivationFunctor<T> {
public: float threshold;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"threshold", &threshold}};
auto* Y = context.Output<framework::Tensor>("Y"); }
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X); template <typename Device, typename X, typename Y>
auto y = framework::EigenVector<T>::Flatten(*Y); void operator()(Device d, X x, Y y) const {
auto place = context.GetEigenDevice<Place>(); auto temp = x.cwiseMax(-threshold).cwiseMin(threshold);
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval(); y.device(d) = (static_cast<T>(1) + temp.exp()).log();
y.device(place) = (static_cast<T>(1) + temp.exp()).log();
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class SoftReluGradKernel : public framework::OpKernel { struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
public: float threshold;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"threshold", &threshold}};
auto* Y = context.Input<framework::Tensor>("Y"); }
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); template <typename Device, typename X, typename Y, typename dY, typename dX>
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval(); auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval();
dx.device(place) = dy * (static_cast<T>(1) - (-y).exp()) * temp; dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class PowKernel : public framework::OpKernel { struct PowFunctor : public BaseActivationFunctor<T> {
public: float factor;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"factor", &factor}};
auto* Y = context.Output<framework::Tensor>("Y"); }
auto factor = static_cast<T>(context.Attr<AttrType>("factor")); template <typename Device, typename X, typename Y>
Y->mutable_data<T>(context.GetPlace()); void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(factor);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.pow(factor);
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class PowGradKernel : public framework::OpKernel { struct PowGradFunctor : public BaseActivationFunctor<T> {
public: float factor;
void Compute(const framework::ExecutionContext& context) const override { typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* X = context.Input<framework::Tensor>("X"); return {{"factor", &factor}};
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); }
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); template <typename Device, typename X, typename Y, typename dY, typename dX>
auto factor = static_cast<T>(context.Attr<AttrType>("factor")); void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dX->mutable_data<T>(context.GetPlace()); dx.device(d) = dy * factor * x.pow(factor - static_cast<T>(1));
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * factor * x.pow(factor - static_cast<T>(1));
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class STanhKernel : public framework::OpKernel { struct STanhFunctor : public BaseActivationFunctor<T> {
public: float scale_a;
void Compute(const framework::ExecutionContext& context) const override { float scale_b;
auto* X = context.Input<framework::Tensor>("X"); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* Y = context.Output<framework::Tensor>("Y"); return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a")); }
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X); template <typename Device, typename X, typename Y>
auto y = framework::EigenVector<T>::Flatten(*Y); void operator()(Device d, X x, Y y) const {
auto place = context.GetEigenDevice<Place>(); y.device(d) = scale_b * (scale_a * x).tanh();
y.device(place) = scale_b * (scale_a * x).tanh();
} }
}; };
template <typename Place, typename T, typename AttrType = T> template <typename T>
class STanhGradKernel : public framework::OpKernel { struct STanhGradFunctor : public BaseActivationFunctor<T> {
public: float scale_a;
void Compute(const framework::ExecutionContext& context) const override { float scale_b;
auto* X = context.Input<framework::Tensor>("X"); typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y")); return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X")); }
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh(); auto temp = (scale_a * x).tanh() * (scale_a * x).tanh();
dx.device(place) = dy * scale_a * scale_b * (static_cast<T>(1) - temp); dx.device(d) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
#define FOR_EACH_KERNEL_FUNCTOR(__macro) \
__macro(sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(exp, ExpFunctor, ExpGradFunctor); \
__macro(relu, ReluFunctor, ReluGradFunctor); \
__macro(tanh, TanhFunctor, TanhGradFunctor); \
__macro(sqrt, SqrtFunctor, SqrtGradFunctor); \
__macro(abs, AbsFunctor, AbsGradFunctor); \
__macro(reciprocal, ReciprocalFunctor, ReciprocalGradFunctor); \
__macro(log, LogFunctor, LogGradFunctor); \
__macro(square, SquareFunctor, SquareGradFunctor); \
__macro(brelu, BReluFunctor, BReluGradFunctor); \
__macro(soft_relu, SoftReluFunctor, SoftReluGradFunctor); \
__macro(pow, PowFunctor, PowGradFunctor); \
__macro(stanh, STanhFunctor, STanhGradFunctor); \
__macro(softsign, SoftsignFunctor, SoftsignGradFunctor)
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class AddKernel : public framework::OpKernel { class AddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X"); auto* input0 = context.Input<Tensor>("X");
......
...@@ -56,7 +56,7 @@ class ClipGradFunctor { ...@@ -56,7 +56,7 @@ class ClipGradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ClipKernel : public framework::OpKernel { class ClipKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
...@@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel { class ClipGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max"); auto max = context.Attr<T>("max");
......
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ConcatKernel : public framework::OpKernel { class ConcatKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
...@@ -44,7 +44,7 @@ class ConcatKernel : public framework::OpKernel { ...@@ -44,7 +44,7 @@ class ConcatKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ConcatGradKernel : public framework::OpKernel { class ConcatGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class CosSimKernel : public framework::OpKernel { class CosSimKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
// get Tensor // get Tensor
...@@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel { ...@@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class CosSimGradKernel : public framework::OpKernel { class CosSimGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
// get Tensor // get Tensor
......
...@@ -27,7 +27,7 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>; ...@@ -27,7 +27,7 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor; using framework::Tensor;
template <typename T> template <typename T>
class CropKernel : public framework::OpKernel { class CropKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
...@@ -69,7 +69,7 @@ void CropGradFunction(const framework::ExecutionContext& context) { ...@@ -69,7 +69,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class CropGradKernel : public framework::OpKernel { class CropGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
size_t rank = size_t rank =
......
...@@ -47,6 +47,12 @@ class CrossEntropyOp : public framework::OperatorWithKernel { ...@@ -47,6 +47,12 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Y", {x_dims[0], 1}); ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Y"); ctx->ShareLoD("X", /*->*/ "Y");
} }
// CrossEntropy's data type just determined by "X"
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class CrossEntropyGradientOp : public framework::OperatorWithKernel { class CrossEntropyGradientOp : public framework::OperatorWithKernel {
...@@ -87,6 +93,12 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -87,6 +93,12 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
} }
// CrossEntropy's data type just determined by "X"
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker { class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -18,14 +18,6 @@ namespace paddle { ...@@ -18,14 +18,6 @@ namespace paddle {
namespace operators { namespace operators {
namespace { namespace {
// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
}
}
template <typename T> template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X, __global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
...@@ -53,7 +45,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X, ...@@ -53,7 +45,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
} // namespace } // namespace
template <typename T> template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel { class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -64,12 +56,12 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel { ...@@ -64,12 +56,12 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
y->mutable_data<T>(ctx.GetPlace()); y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::GPUPlace, T>()( math::CrossEntropyFunctor<platform::GPUPlace, T>()(
ctx, y, x, label, ctx.Attr<bool>("softLabel")); ctx.device_context(), y, x, label, ctx.Attr<bool>("softLabel"));
} }
}; };
template <typename T> template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -99,11 +91,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel { ...@@ -99,11 +91,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
.stream()>>>(dx_data, dy_data, x_data, label_data, .stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num); batch_size, class_num);
} else { } else {
Zero<T><<<grid, block, 0, math::SetConstant<platform::GPUPlace, T>(ctx.device_context(), dx, 0);
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, batch_size * class_num);
auto* label_data = label->data<int>(); auto* label_data = label->data<int>();
grid = (batch_size + block - 1) / block; grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<< CrossEntropyGradientKernel<T><<<
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cross_entropy.h" #include "paddle/operators/math/cross_entropy.h"
#include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -26,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
class CrossEntropyOpKernel : public framework::OpKernel { class CrossEntropyOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
...@@ -37,12 +38,12 @@ class CrossEntropyOpKernel : public framework::OpKernel { ...@@ -37,12 +38,12 @@ class CrossEntropyOpKernel : public framework::OpKernel {
y->mutable_data<T>(ctx.GetPlace()); y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::CPUPlace, T>()( math::CrossEntropyFunctor<platform::CPUPlace, T>()(
ctx, y, x, labels, ctx.Attr<bool>("softLabel")); ctx.device_context(), y, x, labels, ctx.Attr<bool>("softLabel"));
} }
}; };
template <typename T> template <typename T>
class CrossEntropyGradientOpKernel : public framework::OpKernel { class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
...@@ -69,8 +70,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel { ...@@ -69,8 +70,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
const T* x_data = x->data<T>(); const T* x_data = x->data<T>();
const int* label_data = label->data<int>(); const int* label_data = label->data<int>();
// TODO(qingqing): make zero setting a common function. math::SetConstant<platform::CPUPlace, T>(ctx.device_context(), dx, 0);
memset(dx_data, 0, sizeof(T) * batch_size * class_num);
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num); PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
......
...@@ -47,7 +47,7 @@ struct MaskGenerator { ...@@ -47,7 +47,7 @@ struct MaskGenerator {
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename Place, typename T, typename AttrType> template <typename Place, typename T, typename AttrType>
class GPUDropoutKernel : public framework::OpKernel { class GPUDropoutKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
......
...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType> template <typename Place, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel { class CPUDropoutKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
...@@ -62,7 +62,7 @@ class CPUDropoutKernel : public framework::OpKernel { ...@@ -62,7 +62,7 @@ class CPUDropoutKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel { class DropoutGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Attr<bool>("is_training"), PADDLE_ENFORCE(context.Attr<bool>("is_training"),
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseAddKernel : public framework::OpKernel { class ElementwiseAddKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenAddFunctor, Place, T>(ctx); ElementwiseCompute<EigenAddFunctor, Place, T>(ctx);
...@@ -101,7 +101,7 @@ struct ElementwiseAddBroadCast2GradFunctor { ...@@ -101,7 +101,7 @@ struct ElementwiseAddBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel { class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>,
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseDivKernel : public framework::OpKernel { class ElementwiseDivKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, Place, T>(ctx); ElementwiseCompute<EigenDivFunctor, Place, T>(ctx);
...@@ -103,7 +103,7 @@ struct ElementwiseDivBroadCast2GradFunctor { ...@@ -103,7 +103,7 @@ struct ElementwiseDivBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel { class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>,
......
...@@ -36,7 +36,9 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker, ...@@ -36,7 +36,9 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
elementwise_mul_grad, ops::ElementwiseOpGrad); elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>); ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>); ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
...@@ -19,7 +19,9 @@ namespace ops = paddle::operators; ...@@ -19,7 +19,9 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
elementwise_mul, elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>); ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL( REGISTER_OP_GPU_KERNEL(
elementwise_mul_grad, elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>); ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>);
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseMulKernel : public framework::OpKernel { class ElementwiseMulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenMulFunctor, Place, T>(ctx); ElementwiseCompute<EigenMulFunctor, Place, T>(ctx);
...@@ -102,7 +102,7 @@ struct ElementwiseMulBroadCast2GradFunctor { ...@@ -102,7 +102,7 @@ struct ElementwiseMulBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel { class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>,
......
...@@ -19,7 +19,7 @@ namespace paddle { ...@@ -19,7 +19,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseSubKernel : public framework::OpKernel { class ElementwiseSubKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, Place, T>(ctx); ElementwiseCompute<EigenSubFunctor, Place, T>(ctx);
...@@ -102,7 +102,7 @@ struct ElementwiseSubBroadCast2GradFunctor { ...@@ -102,7 +102,7 @@ struct ElementwiseSubBroadCast2GradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel { class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>, ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>,
......
...@@ -100,7 +100,7 @@ class FCOp : public NetOp { ...@@ -100,7 +100,7 @@ class FCOp : public NetOp {
add_out = Output("AddOut"); add_out = Output("AddOut");
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(
"rowwise_add", {{"X", {sum_out}}, {"b", {Input("B")}}}, "elementwise_add", {{"X", {sum_out}}, {"Y", {Input("B")}}},
{{"Out", {add_out}}}, {})); {{"Out", {add_out}}}, {}));
} else { } else {
if (Output("AddOut") != framework::kEmptyVarName) { if (Output("AddOut") != framework::kEmptyVarName) {
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel { class FillZerosLikeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>("Y"); auto* output = context.Output<framework::Tensor>("Y");
......
...@@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel {
output_dims[0] = batch_size; output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims); ctx->SetOutputDim("Out", output_dims);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class GatherGradOp : public framework::OperatorWithKernel { class GatherGradOp : public framework::OperatorWithKernel {
...@@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContextBase* ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
}; };
class GatherOpMaker : public framework::OpProtoAndCheckerMaker { class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class GatherOpKernel : public framework::OpKernel { class GatherOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X"); auto *X = ctx.Input<Tensor>("X");
...@@ -37,7 +37,7 @@ class GatherOpKernel : public framework::OpKernel { ...@@ -37,7 +37,7 @@ class GatherOpKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GatherGradientOpKernel : public framework::OpKernel { class GatherGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *Index = ctx.Input<Tensor>("Index"); auto *Index = ctx.Input<Tensor>("Index");
......
...@@ -16,7 +16,7 @@ namespace paddle { ...@@ -16,7 +16,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean"); float mean = context.Attr<float>("mean");
...@@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
"dims can be one int or array. dims must be set."); "dims can be one int or array. dims must be set.");
ctx->SetOutputDim("Out", framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
}
}; };
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker { class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator. ...@@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator.
"Random seed of generator." "Random seed of generator."
"0 means use system wide seed") "0 means use system wide seed")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("data_type", "output data type")
.SetDefault(framework::DataType::FP32);
} }
}; };
......
...@@ -37,7 +37,7 @@ struct GaussianGenerator { ...@@ -37,7 +37,7 @@ struct GaussianGenerator {
}; };
template <typename T> template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel { class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
......
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class GemmConv2DKernel : public framework::OpKernel { class GemmConv2DKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
...@@ -98,7 +98,7 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -98,7 +98,7 @@ class GemmConv2DKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class GemmConvGrad2DKernel : public framework::OpKernel { class GemmConvGrad2DKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input"); const Tensor* input = context.Input<Tensor>("Input");
......
...@@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]}); ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out"); ctx->ShareLoD("Ids", /*->*/ "Out");
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
}
}; };
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker { class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W"); auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims); ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -61,7 +61,7 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids, ...@@ -61,7 +61,7 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
} }
template <typename T> template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel { class LookupTableCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); auto table_t = context.Input<Tensor>("W");
...@@ -85,7 +85,7 @@ class LookupTableCUDAKernel : public framework::OpKernel { ...@@ -85,7 +85,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel { class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<Tensor>("Ids");
......
...@@ -23,7 +23,7 @@ namespace operators { ...@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
class LookupTableKernel : public framework::OpKernel { class LookupTableKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor auto table_t = context.Input<Tensor>("W"); // float tensor
...@@ -44,7 +44,7 @@ class LookupTableKernel : public framework::OpKernel { ...@@ -44,7 +44,7 @@ class LookupTableKernel : public framework::OpKernel {
}; };
template <typename T> template <typename T>
class LookupTableGradKernel : public framework::OpKernel { class LookupTableGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids"); auto ids_t = context.Input<Tensor>("Ids");
......
...@@ -90,7 +90,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim, ...@@ -90,7 +90,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
} }
template <typename T, typename AttrType = T> template <typename T, typename AttrType = T>
class LstmUnitOpCUDAKernel : public framework::OpKernel { class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
...@@ -121,7 +121,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel { ...@@ -121,7 +121,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel {
}; };
template <typename T, typename AttrType = T> template <typename T, typename AttrType = T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel { class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -33,7 +33,7 @@ inline T tanh(T x) { ...@@ -33,7 +33,7 @@ inline T tanh(T x) {
} }
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class LstmUnitKernel : public framework::OpKernel { class LstmUnitKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
...@@ -76,7 +76,7 @@ class LstmUnitKernel : public framework::OpKernel { ...@@ -76,7 +76,7 @@ class LstmUnitKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class LstmUnitGradKernel : public framework::OpKernel { class LstmUnitGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu DEPS cblas device_context operator) im2col.cu DEPS cblas device_context operator)
nv_library(softmax_function SRCS softmax.cc softmax.cu nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
DEPS operator) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
DEPS operator)
else() else()
cc_library(math_function SRCS math_function.cc im2col.cc cc_library(math_function SRCS math_function.cc im2col.cc
DEPS cblas device_context operator) DEPS cblas device_context operator)
cc_library(softmax_function SRCS softmax.cc DEPS operator) cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator) cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
...@@ -26,8 +26,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; ...@@ -26,8 +26,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
class CrossEntropyFunctor<platform::CPUPlace, T> { class CrossEntropyFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const framework::ExecutionContext& ctx, void operator()(const platform::DeviceContext& ctx, framework::Tensor* out,
framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel) { const framework::Tensor* labels, const bool softLabel) {
const int batch_size = prob->dims()[0]; const int batch_size = prob->dims()[0];
if (softLabel) { if (softLabel) {
...@@ -35,7 +35,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> { ...@@ -35,7 +35,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
auto lbl = EigenMatrix<T>::From(*labels); auto lbl = EigenMatrix<T>::From(*labels);
auto loss = EigenMatrix<T>::From(*out); auto loss = EigenMatrix<T>::From(*out);
loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) = loss.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-((lbl * in.log().unaryExpr(math::TolerableValue<T>())) -((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1)) .sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1))); .reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
......
...@@ -74,8 +74,8 @@ using Tensor = framework::Tensor; ...@@ -74,8 +74,8 @@ using Tensor = framework::Tensor;
template <typename T> template <typename T>
class CrossEntropyFunctor<platform::GPUPlace, T> { class CrossEntropyFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const framework::ExecutionContext& ctx, void operator()(const platform::DeviceContext& ctx, framework::Tensor* out,
framework::Tensor* out, const framework::Tensor* prob, const framework::Tensor* prob,
const framework::Tensor* labels, bool softLabel) { const framework::Tensor* labels, bool softLabel) {
const T* prob_data = prob->data<T>(); const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace()); T* loss_data = out->mutable_data<T>(ctx.GetPlace());
...@@ -87,20 +87,18 @@ class CrossEntropyFunctor<platform::GPUPlace, T> { ...@@ -87,20 +87,18 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
const T* label_data = labels->data<T>(); const T* label_data = labels->data<T>();
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num))); int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel< SoftCrossEntropyKernel<T><<<
T><<<batch_size, block, block * sizeof(T), batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>( reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
ctx.device_context()) loss_data, prob_data, label_data, class_num);
.stream()>>>(loss_data, prob_data, label_data, class_num);
} else { } else {
const int* label_data = labels->data<int>(); const int* label_data = labels->data<int>();
int block = 512; int block = 512;
int grid = (batch_size + block - 1) / block; int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<< CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>( grid, block, 0,
ctx.device_context()) reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
.stream()>>>(loss_data, prob_data, label_data, loss_data, prob_data, label_data, batch_size, class_num);
batch_size, class_num);
} }
} }
}; };
......
...@@ -37,9 +37,7 @@ struct TolerableValue { ...@@ -37,9 +37,7 @@ struct TolerableValue {
template <typename Place, typename T> template <typename Place, typename T>
class CrossEntropyFunctor { class CrossEntropyFunctor {
public: public:
// (TODO caoying) it is much better to use DeviceContext as the first void operator()(const platform::DeviceContext& context,
// parameter.
void operator()(const framework::ExecutionContext& context,
framework::Tensor* out, const framework::Tensor* prob, framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel); const framework::Tensor* labels, const bool softLabel);
}; };
......
...@@ -52,6 +52,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, ...@@ -52,6 +52,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <cmath> #include <cmath>
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
...@@ -84,6 +85,13 @@ void matmul(const platform::DeviceContext& context, ...@@ -84,6 +85,13 @@ void matmul(const platform::DeviceContext& context,
const framework::Tensor& matrix_b, bool trans_b, T alpha, const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta); framework::Tensor* matrix_out, T beta);
template <typename Place, typename T>
void SetConstant(const platform::DeviceContext& context,
framework::Tensor* tensor, T num) {
auto t = framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(num));
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -243,3 +243,24 @@ TEST(math_function, gemm_trans_clbas) { ...@@ -243,3 +243,24 @@ TEST(math_function, gemm_trans_clbas) {
EXPECT_EQ(input3_ptr[6], 86); EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99); EXPECT_EQ(input3_ptr[7], 99);
} }
TEST(math_function, zero) {
paddle::framework::Tensor tensor;
auto* cpu_place = new paddle::platform::CPUPlace();
float* t = tensor.mutable_data<float>({2, 2}, *cpu_place);
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>(
context, &tensor, 0);
EXPECT_EQ(t[0], 0);
EXPECT_EQ(t[1], 0);
EXPECT_EQ(t[2], 0);
EXPECT_EQ(t[3], 0);
paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>(
context, &tensor, 1);
EXPECT_EQ(t[0], 1);
EXPECT_EQ(t[1], 1);
EXPECT_EQ(t[2], 1);
EXPECT_EQ(t[3], 1);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/softmax.h" #include "paddle/operators/math/softmax.h"
...@@ -19,6 +19,7 @@ namespace operators { ...@@ -19,6 +19,7 @@ namespace operators {
namespace math { namespace math {
template class SoftmaxFunctor<platform::CPUPlace, float>; template class SoftmaxFunctor<platform::CPUPlace, float>;
template class SoftmaxGradFunctor<platform::CPUPlace, float>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
...@@ -21,6 +21,7 @@ namespace operators { ...@@ -21,6 +21,7 @@ namespace operators {
namespace math { namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>; template class SoftmaxFunctor<platform::GPUPlace, float>;
template class SoftmaxGradFunctor<platform::GPUPlace, float>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/eigen.h" #include "paddle/framework/eigen.h"
...@@ -36,7 +36,7 @@ struct ValueClip { ...@@ -36,7 +36,7 @@ struct ValueClip {
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxFunctor { class SoftmaxFunctor {
public: public:
void operator()(const framework::ExecutionContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y) { const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X); auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y); auto softmax = EigenMatrix<T>::From(*Y);
...@@ -58,8 +58,8 @@ class SoftmaxFunctor { ...@@ -58,8 +58,8 @@ class SoftmaxFunctor {
.broadcast(one_by_class)) .broadcast(one_by_class))
.unaryExpr(ValueClip<T>()); .unaryExpr(ValueClip<T>());
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp(); softmax.device(*context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) = softmax.device(*context.GetEigenDevice<Place>()) =
(softmax * (softmax *
softmax.sum(along_class) softmax.sum(along_class)
.inverse() .inverse()
...@@ -68,6 +68,37 @@ class SoftmaxFunctor { ...@@ -68,6 +68,37 @@ class SoftmaxFunctor {
.broadcast(one_by_class)); .broadcast(one_by_class));
} }
}; };
template <typename Place, typename T>
class SoftmaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto dot = (softmax * softmax_grad)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
logits_grad.device(*context.GetEigenDevice<Place>()) =
(softmax_grad - dot) * softmax;
}
};
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class MeanKernel : public framework::OpKernel { class MeanKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X"); auto* input = context.Input<Tensor>("X");
...@@ -45,7 +45,7 @@ class MeanKernel : public framework::OpKernel { ...@@ -45,7 +45,7 @@ class MeanKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MeanGradKernel : public framework::OpKernel { class MeanGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out")); auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -20,7 +20,7 @@ namespace paddle { ...@@ -20,7 +20,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MinusKernel : public framework::OpKernel { class MinusKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* left_tensor = context.Input<framework::Tensor>("X"); auto* left_tensor = context.Input<framework::Tensor>("X");
......
...@@ -39,7 +39,7 @@ struct ModifiedHuberLossBackward { ...@@ -39,7 +39,7 @@ struct ModifiedHuberLossBackward {
}; };
template <typename T> template <typename T>
class ModifiedHuberLossGradGPUKernel : public framework::OpKernel { class ModifiedHuberLossGradGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y"); auto* in0 = context.Input<Tensor>("Y");
......
...@@ -47,7 +47,7 @@ struct ModifiedHuberLossForward { ...@@ -47,7 +47,7 @@ struct ModifiedHuberLossForward {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ModifiedHuberLossKernel : public framework::OpKernel { class ModifiedHuberLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X"); auto* in0 = context.Input<Tensor>("X");
...@@ -73,7 +73,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel {
// CPU backward kernel // CPU backward kernel
template <typename T> template <typename T>
class ModifiedHuberLossGradCPUKernel : public framework::OpKernel { class ModifiedHuberLossGradCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y"); auto* in0 = context.Input<Tensor>("Y");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/mul_op.h" #include "paddle/operators/mul_op.h"
...@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel { ...@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel {
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims"); int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims"); int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims, PADDLE_ENFORCE_GT(
"The rank of input tensor X should be larger than " x_dims.size(), x_num_col_dims,
"`mul_op`'s `x_num_col_dims`."); "The input tensor X's rank of MulOp should be larger than "
PADDLE_ENFORCE(y_dims.size() > y_num_col_dims, "x_num_col_dims.");
"The rank of input tensor Y should be larger than " PADDLE_ENFORCE_GT(
"`mul_op`'s `y_num_col_dims`."); y_dims.size(), y_num_col_dims,
"The input tensor Y's rank of MulOp should be larger than "
"y_num_col_dims.");
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims); auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims); auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class MulKernel : public framework::OpKernel { class MulKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X"); const Tensor* x = context.Input<Tensor>("X");
...@@ -52,7 +52,7 @@ class MulKernel : public framework::OpKernel { ...@@ -52,7 +52,7 @@ class MulKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel { class MulGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims"); int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
......
...@@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", in_dim); ctx->SetOutputDim("Out", in_dim);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
}
}; };
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker { class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputsDim(framework::GradVarName("X"), d_ins); ctx->SetOutputsDim(framework::GradVarName("X"), d_ins);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel { class MultiplexGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<Tensor>("X"); auto ins = ctx.MultiInput<Tensor>("X");
...@@ -51,7 +51,7 @@ class MultiplexGPUKernel : public framework::OpKernel { ...@@ -51,7 +51,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel { class MultiplexGradGPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -23,7 +23,7 @@ namespace paddle { ...@@ -23,7 +23,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexCPUKernel : public framework::OpKernel { class MultiplexCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X"); auto ins = ctx.MultiInput<framework::Tensor>("X");
...@@ -48,7 +48,7 @@ class MultiplexCPUKernel : public framework::OpKernel { ...@@ -48,7 +48,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class MultiplexGradCPUKernel : public framework::OpKernel { class MultiplexGradCPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
...@@ -47,7 +47,7 @@ void PadFunction(const framework::ExecutionContext& context) { ...@@ -47,7 +47,7 @@ void PadFunction(const framework::ExecutionContext& context) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class PadKernel : public framework::OpKernel { class PadKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
...@@ -97,7 +97,7 @@ void PadGradFunction(const framework::ExecutionContext& context) { ...@@ -97,7 +97,7 @@ void PadGradFunction(const framework::ExecutionContext& context) {
} }
template <typename Place, typename T> template <typename Place, typename T>
class PadGradKernel : public framework::OpKernel { class PadGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
size_t rank = size_t rank =
......
...@@ -40,7 +40,7 @@ class PReluFunctor { ...@@ -40,7 +40,7 @@ class PReluFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class PReluKernel : public framework::OpKernel { class PReluKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
...@@ -77,7 +77,7 @@ class PReluGradFunctor { ...@@ -77,7 +77,7 @@ class PReluGradFunctor {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class PReluGradKernel : public framework::OpKernel { class PReluGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<Tensor>(framework::GradVarName("X")); auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class RankLossKernel : public framework::OpKernel { class RankLossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out"); auto* out_t = ctx.Output<framework::Tensor>("Out");
...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel { ...@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class RankLossGradKernel : public framework::OpKernel { class RankLossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_left_t = auto* d_left_t =
......
...@@ -87,7 +87,7 @@ struct MaxOrMinGradFunctor { ...@@ -87,7 +87,7 @@ struct MaxOrMinGradFunctor {
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename T, typename Functor>
class ReduceKernel : public framework::OpKernel { class ReduceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
...@@ -141,7 +141,7 @@ class ReduceKernel : public framework::OpKernel { ...@@ -141,7 +141,7 @@ class ReduceKernel : public framework::OpKernel {
}; };
template <typename Place, typename T, typename Functor> template <typename Place, typename T, typename Functor>
class ReduceGradKernel : public framework::OpKernel { class ReduceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size(); int rank = context.Input<Tensor>("X")->dims().size();
......
...@@ -21,7 +21,7 @@ namespace paddle { ...@@ -21,7 +21,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class ReshapeKernel : public framework::OpKernel { class ReshapeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out"); auto* out = ctx.Output<framework::Tensor>("Out");
...@@ -39,7 +39,7 @@ class ReshapeKernel : public framework::OpKernel { ...@@ -39,7 +39,7 @@ class ReshapeKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ReshapeGradKernel : public framework::OpKernel { class ReshapeGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rowwise_add_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class RowwiseAddOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of RowwiseAddOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("b"),
"Input(b) of RowwiseAddOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of RowwiseAddOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
auto b_dims = ctx->GetInputDim("b");
PADDLE_ENFORCE_GT(
x_dims.size(), b_dims.size(),
"The rank of input `X` must be larger than the one of input `b`.");
int num_col_dims = x_dims.size() - b_dims.size();
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same");
PADDLE_ENFORCE_EQ(ctx->Outputs("Out").size(), 1,
"The output size must be 1");
ctx->SetOutputDim("Out", x_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class RowwiseAddOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RowwiseAddOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The left input of row-wise add op, must be matrix");
AddInput("b", "The right input of row-wise add op, must be vector");
AddOutput("Out", "The output of row-wise add op");
AddComment(R"DOC(Row-wise Add operator
for i in xrange(X.shape[0]):
Out = X[i] + b
)DOC");
}
};
class RowwiseAddGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X should not be null");
PADDLE_ENFORCE(ctx->HasInput("b"), "b should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx->GetInputDim("X");
auto b_dims = ctx->GetInputDim("b");
PADDLE_ENFORCE_GT(
x_dims.size(), b_dims.size(),
"The rank of input `X` must be larger than the one of input `b`.");
int64_t num_col_dims = x_dims.size() - b_dims.size();
PADDLE_ENFORCE_EQ(
framework::slice_ddim(x_dims, num_col_dims, x_dims.size()), b_dims,
"The width of two operands must be same");
auto x_grad_name = framework::GradVarName("X");
auto b_grad_name = framework::GradVarName("b");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
if (ctx->HasOutput(b_grad_name)) {
ctx->SetOutputDim(b_grad_name, b_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(rowwise_add, ops::RowwiseAddOp, ops::RowwiseAddOpMaker,
rowwise_add_grad, ops::RowwiseAddGradOp);
REGISTER_OP_CPU_KERNEL(
rowwise_add, ops::RowwiseAddKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rowwise_add_grad,
ops::RowwiseAddGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class RowwiseAddKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
int num_col_dims = context.Input<Tensor>("X")->dims().size() -
context.Input<Tensor>("b")->dims().size();
auto input =
EigenMatrix<T>::Reshape(*context.Input<Tensor>("X"), num_col_dims);
auto bias = EigenVector<T>::Flatten(*context.Input<Tensor>("b"));
auto output = EigenMatrix<T>::Reshape(*out, num_col_dims);
const int bias_size = bias.dimension(0);
const int rest_size = input.size() / bias_size;
Eigen::DSizes<int, 1> one_d(input.size());
Eigen::DSizes<int, 1> bcast(rest_size);
output.reshape(one_d).device(context.GetEigenDevice<Place>()) =
input.reshape(one_d) + bias.broadcast(bcast).reshape(one_d);
}
};
template <typename Place, typename T>
class RowwiseAddGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dout = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
auto* db = context.Output<Tensor>(framework::GradVarName("b"));
int num_col_dims = context.Input<Tensor>("X")->dims().size() -
context.Input<Tensor>("b")->dims().size();
auto out_grad = EigenMatrix<T>::Reshape(*dout, num_col_dims);
auto place = context.GetEigenDevice<Place>();
if (dx) {
dx->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::Reshape(*dx, num_col_dims).device(place) = out_grad;
}
if (db) {
db->mutable_data<T>(context.GetPlace());
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add
Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = out_grad.sum(dims);
}
}
};
} // namespace operators
} // namespace paddle
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class ScaleKernel : public framework::OpKernel { class ScaleKernel : public framework::OpKernel<T> {
public: public:
virtual void Compute(const framework::ExecutionContext& context) const { virtual void Compute(const framework::ExecutionContext& context) const {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
......
...@@ -48,6 +48,11 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -48,6 +48,11 @@ class ScatterOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", ref_dims); ctx->SetOutputDim("Out", ref_dims);
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
}
}; };
class ScatterGradOp : public framework::OperatorWithKernel { class ScatterGradOp : public framework::OperatorWithKernel {
...@@ -60,6 +65,11 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -60,6 +65,11 @@ class ScatterGradOp : public framework::OperatorWithKernel {
ctx->GetInputDim("Updates")); ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Ref")->type());
}
}; };
class ScatterOpMaker : public framework::OpProtoAndCheckerMaker { class ScatterOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -24,7 +24,7 @@ namespace operators { ...@@ -24,7 +24,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename Place, typename T> template <typename Place, typename T>
class ScatterOpKernel : public framework::OpKernel { class ScatterOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *Ref = ctx.Input<Tensor>("Ref"); auto *Ref = ctx.Input<Tensor>("Ref");
...@@ -40,7 +40,7 @@ class ScatterOpKernel : public framework::OpKernel { ...@@ -40,7 +40,7 @@ class ScatterOpKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class ScatterGradientOpKernel : public framework::OpKernel { class ScatterGradientOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref")); auto *dRef = ctx.Output<Tensor>(framework::GradVarName("Ref"));
......
...@@ -24,9 +24,9 @@ class SequencePoolOp : public framework::OperatorWithKernel { ...@@ -24,9 +24,9 @@ class SequencePoolOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContextBase* ctx) const override { void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceAvgPoolOp should not be null."); "Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceAvgPoolOp should not be null."); "Output(Out) of SequencePoolOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
} }
}; };
......
...@@ -38,7 +38,7 @@ enum SeqPoolType { ...@@ -38,7 +38,7 @@ enum SeqPoolType {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SequencePoolKernel : public framework::OpKernel { class SequencePoolKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
...@@ -85,7 +85,7 @@ class SequencePoolKernel : public framework::OpKernel { ...@@ -85,7 +85,7 @@ class SequencePoolKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SequencePoolGradKernel : public framework::OpKernel { class SequencePoolGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); auto* in = context.Input<LoDTensor>("X");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sequence_softmax_op.h"
namespace paddle {
namespace operators {
class SequenceSoftmaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SequenceSoftmaxOp should not be null.");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
class SequenceSoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SequenceSoftmaxOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"(LoDTensor) 1-D or 2-D input LoDTensor with the 2-nd dimension "
"of length 1.");
AddOutput("Out",
"(LoDTensor) 1-D or 2-D output LoDTensor with the 2-nd dimension "
"of length 1.");
AddComment(R"DOC(
SequenceSoftmaxOp computes softmax activation among all time-steps for each
sequence. The dimension of each time-step should be 1. Thus, the shape of
input Tensor can be either [N, 1] or [N], where N is the sum of all sequences'
lengths.
Equation:
for i-th sequence in a mini-batch:
Out(X[lod[i]:lod[i+1]], :) =
exp(X[lod[i]:lod[i+1], :]) / sum(exp(X[lod[i]:lod[i+1], :]))
For example, for a mini-batch of 3 sequences with variable-length,
each containing 2, 3, 2 time-steps, the lod of which is [0, 2, 5, 7],
then softmax will be computed among X[0:2, :], X[2:5, :], X[5:7, :]
and N turns out to be 7.
)DOC");
}
};
class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Out"),
"Input(Out) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(
ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE_EQ(
ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Out) and Input(Out@GRAD) of SequenceSoftmaxGradOp should be of "
"the same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sequence_softmax, ops::SequenceSoftmaxOp,
ops::SequenceSoftmaxOpMaker, sequence_softmax_grad,
ops::SequenceSoftmaxGradOp);
REGISTER_OP_CPU_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sequence_softmax_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
sequence_softmax,
ops::SequenceSoftmaxKernel<paddle::platform::GPUPlace, float>)
REGISTER_OP_GPU_KERNEL(
sequence_softmax_grad,
ops::SequenceSoftmaxGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/softmax.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename Place, typename T>
class SequenceSoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<LoDTensor>("Out");
auto lod = x->lod();
auto dims = x->dims();
const size_t level = lod.size() - 1;
PADDLE_ENFORCE_EQ(dims[0], static_cast<int64_t>(lod[level].back()),
"The first dimension of Input(X) should be equal to the "
"sum of all sequences' lengths.");
PADDLE_ENFORCE_EQ(dims[0], x->numel(),
"The width of each timestep in Input(X) of "
"SequenceSoftmaxOp should be 1.");
out->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor x_i = x->Slice<T>(start_pos, end_pos);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
x_i.Resize(dims_i);
out_i.Resize(dims_i);
math::SoftmaxFunctor<Place, T>()(ctx.device_context(), &x_i, &out_i);
}
}
};
template <typename Place, typename T>
class SequenceSoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* out = ctx.Input<LoDTensor>("Out");
auto* out_grad = ctx.Input<LoDTensor>(framework::GradVarName("Out"));
auto* x = ctx.Input<LoDTensor>("X");
auto* x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto lod = x->lod();
const size_t level = lod.size() - 1;
x_grad->mutable_data<T>(ctx.GetPlace());
for (int i = 0; i < static_cast<int>(lod[level].size()) - 1; ++i) {
int start_pos = static_cast<int>(lod[level][i]);
int end_pos = static_cast<int>(lod[level][i + 1]);
Tensor out_i = out->Slice<T>(start_pos, end_pos);
Tensor out_grad_i = out_grad->Slice<T>(start_pos, end_pos);
Tensor x_grad_i = x_grad->Slice<T>(start_pos, end_pos);
// Reshape from (end_pos - start_pos) x 1UL to 1UL x (end_pos - start_pos)
framework::DDim dims_i = framework::make_ddim({1UL, end_pos - start_pos});
out_i.Resize(dims_i);
out_grad_i.Resize(dims_i);
x_grad_i.Resize(dims_i);
math::SoftmaxGradFunctor<Place, T>()(ctx.device_context(), &out_i,
&out_grad_i, &x_grad_i);
}
}
};
} // namespace operators
} // namespace paddle
...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SGDOpKernel : public framework::OpKernel { class SGDOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X))) // Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X"); const framework::Tensor *X = context.Input<framework::Tensor>("X");
...@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel { ...@@ -48,7 +48,7 @@ class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel {
// dX = sigmoid(X) - labels // dX = sigmoid(X) - labels
template <typename Place, typename T> template <typename Place, typename T>
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel { class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
const framework::Tensor *X = context.Input<framework::Tensor>("X"); const framework::Tensor *X = context.Input<framework::Tensor>("X");
......
...@@ -45,7 +45,7 @@ struct SmoothL1LossForward { ...@@ -45,7 +45,7 @@ struct SmoothL1LossForward {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossKernel : public framework::OpKernel { class SmoothL1LossKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X"); auto* in0 = context.Input<Tensor>("X");
...@@ -115,7 +115,7 @@ struct SmoothL1LossBackward { ...@@ -115,7 +115,7 @@ struct SmoothL1LossBackward {
}; };
template <typename Place, typename T, typename AttrType = T> template <typename Place, typename T, typename AttrType = T>
class SmoothL1LossGradKernel : public framework::OpKernel { class SmoothL1LossGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("InsideWeight"); auto* in0 = context.Input<Tensor>("InsideWeight");
......
...@@ -26,46 +26,31 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,46 +26,31 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxKernel : public framework::OpKernel { class SoftmaxKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto X = context.Input<Tensor>("X"); auto* X = context.Input<Tensor>("X");
auto Y = context.Output<Tensor>("Y"); auto* Y = context.Output<Tensor>("Y");
// allocate memory on device. // allocate memory on device.
Y->mutable_data<T>(context.GetPlace()); Y->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<Place, T>()(context, X, Y); math::SoftmaxFunctor<Place, T>()(context.device_context(), X, Y);
} }
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SoftmaxGradKernel : public framework::OpKernel { class SoftmaxGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto Y = context.Input<Tensor>("Y"); auto* Y = context.Input<Tensor>("Y");
auto dY = context.Input<Tensor>(framework::GradVarName("Y")); auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX = context.Output<Tensor>(framework::GradVarName("X")); auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
const int batch_size = Y->dims()[0];
const int class_num = Y->dims()[1];
Eigen::DSizes<int, 1> along_class(1);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, class_num);
auto Y_eigen = EigenMatrix<T>::From(*Y); // allocate memory on device.
auto dY_eigen = EigenMatrix<T>::From(*dY); dX->mutable_data<T>(context.GetPlace());
auto dX_eigen = EigenMatrix<T>::From(*dX);
auto place = context.GetEigenDevice<Place>();
auto dot = (Y_eigen * dY_eigen) math::SoftmaxGradFunctor<Place, T>()(context.device_context(), Y, dY, dX);
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
dX_eigen.device(place) = (dY_eigen - dot) * Y_eigen;
} }
}; };
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
limitations under the License. */ limitations under the License. */
#include "paddle/operators/softmax_with_cross_entropy_op.h" #include "paddle/operators/softmax_with_cross_entropy_op.h"
#include <paddle/function/TensorType.h>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -115,6 +116,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
ctx->ShareLoD("Logits", /*->*/ "Softmax"); ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss"); ctx->ShareLoD("Logits", /*->*/ "Loss");
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("Logits")->type());
}
}; };
class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
...@@ -149,6 +155,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -149,6 +155,12 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("Logits"), ctx->SetOutputDim(framework::GradVarName("Logits"),
ctx->GetInputDim("Softmax")); ctx->GetInputDim("Softmax"));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type());
}
}; };
} // namespace operators } // namespace operators
......
...@@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad, ...@@ -53,7 +53,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
} // namespace } // namespace
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
...@@ -66,14 +66,16 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel { ...@@ -66,14 +66,16 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel {
softmax->mutable_data<T>(context.GetPlace()); softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace()); loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::GPUPlace, T>()(context, logits, softmax); math::SoftmaxFunctor<platform::GPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::GPUPlace, T>()( math::CrossEntropyFunctor<platform::GPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel")); context.device_context(), loss, softmax, labels,
context.Attr<bool>("softLabel"));
} }
}; };
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(context.GetPlace()),
......
...@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -27,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()), PADDLE_ENFORCE(platform::is_cpu_place(context.GetPlace()),
...@@ -40,14 +40,16 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel { ...@@ -40,14 +40,16 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel {
softmax->mutable_data<T>(context.GetPlace()); softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace()); loss->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<platform::CPUPlace, T>()(context, logits, softmax); math::SoftmaxFunctor<platform::CPUPlace, T>()(context.device_context(),
logits, softmax);
math::CrossEntropyFunctor<platform::CPUPlace, T>()( math::CrossEntropyFunctor<platform::CPUPlace, T>()(
context, loss, softmax, labels, context.Attr<bool>("softLabel")); context.device_context(), loss, softmax, labels,
context.Attr<bool>("softLabel"));
} }
}; };
template <typename T> template <typename T>
class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel { class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const Tensor* out_grad = const Tensor* out_grad =
......
...@@ -22,7 +22,7 @@ namespace paddle { ...@@ -22,7 +22,7 @@ namespace paddle {
namespace operators { namespace operators {
template <typename Place, typename T> template <typename Place, typename T>
class SplitOpKernel : public framework::OpKernel { class SplitOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X"); auto* in = ctx.Input<framework::Tensor>("X");
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SquaredL2DistanceKernel : public framework::OpKernel { class SquaredL2DistanceKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X"); auto* in0 = context.Input<Tensor>("X");
...@@ -68,7 +68,7 @@ class SquaredL2DistanceKernel : public framework::OpKernel { ...@@ -68,7 +68,7 @@ class SquaredL2DistanceKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SquaredL2DistanceGradKernel : public framework::OpKernel { class SquaredL2DistanceGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("sub_result"); auto* in0 = context.Input<Tensor>("sub_result");
......
...@@ -22,7 +22,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -22,7 +22,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class SumKernel : public framework::OpKernel { class SumKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto ins = context.MultiInput<Tensor>("X"); auto ins = context.MultiInput<Tensor>("X");
...@@ -43,7 +43,7 @@ class SumKernel : public framework::OpKernel { ...@@ -43,7 +43,7 @@ class SumKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class SumGradKernel : public framework::OpKernel { class SumGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>(framework::GradVarName("Out")); auto* input = context.Input<Tensor>(framework::GradVarName("Out"));
......
...@@ -279,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int* indices, ...@@ -279,7 +279,7 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
} }
template <typename T> template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel { class TopkOpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T> template <typename Place, typename T>
class TopkKernel : public framework::OpKernel { class TopkKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor // Get the top k elements of each row of input tensor
......
...@@ -38,7 +38,7 @@ void EigenTranspose(const framework::ExecutionContext& context, ...@@ -38,7 +38,7 @@ void EigenTranspose(const framework::ExecutionContext& context,
} }
template <typename Place, typename T> template <typename Place, typename T>
class TransposeKernel : public framework::OpKernel { class TransposeKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X"); auto* x = context.Input<framework::Tensor>("X");
...@@ -73,7 +73,7 @@ class TransposeKernel : public framework::OpKernel { ...@@ -73,7 +73,7 @@ class TransposeKernel : public framework::OpKernel {
}; };
template <typename Place, typename T> template <typename Place, typename T>
class TransposeGradKernel : public framework::OpKernel { class TransposeGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* out_grad = auto* out_grad =
......
...@@ -21,7 +21,7 @@ namespace operators { ...@@ -21,7 +21,7 @@ namespace operators {
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename T> template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel { class CPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* tensor = ctx.Output<framework::Tensor>("Out"); auto* tensor = ctx.Output<framework::Tensor>("Out");
...@@ -62,6 +62,11 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -62,6 +62,11 @@ class UniformRandomOp : public framework::OperatorWithKernel {
} }
ctx->SetOutputDim("Out", framework::make_ddim(temp)); ctx->SetOutputDim("Out", framework::make_ddim(temp));
} }
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
}
}; };
class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker { class UniformRandomOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -80,6 +85,8 @@ Used to initialize tensor with uniform random generator. ...@@ -80,6 +85,8 @@ Used to initialize tensor with uniform random generator.
"Random seed of uniform random. " "Random seed of uniform random. "
"0 means generate a seed by system") "0 means generate a seed by system")
.SetDefault(0); .SetDefault(0);
AddAttr<int>("data_type", "output tensor data type")
.SetDefault(framework::DataType::FP32);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -40,7 +40,7 @@ struct UniformGenerator { ...@@ -40,7 +40,7 @@ struct UniformGenerator {
// Use std::random and thrust::random(thrust is a std library in CUDA) to // Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random. // implement uniform random.
template <typename T> template <typename T>
class GPUUniformRandomKernel : public framework::OpKernel { class GPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
......
...@@ -16,8 +16,8 @@ namespace paddle { ...@@ -16,8 +16,8 @@ namespace paddle {
namespace platform { namespace platform {
template <> template <>
Eigen::DefaultDevice* DeviceContext::get_eigen_device<Eigen::DefaultDevice>() Eigen::DefaultDevice* DeviceContext::GetEigenDevice<
const { platform::CPUPlace, Eigen::DefaultDevice>() const {
return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device(); return reinterpret_cast<const CPUDeviceContext*>(this)->eigen_device();
} }
...@@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); } ...@@ -37,6 +37,12 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice*
DeviceContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
class EigenCudaStreamDevice : public Eigen::StreamInterface { class EigenCudaStreamDevice : public Eigen::StreamInterface {
public: public:
EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) { EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
...@@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface { ...@@ -90,11 +96,6 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
mutable unsigned int* semaphore_; mutable unsigned int* semaphore_;
}; };
template <>
Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
return reinterpret_cast<const CUDADeviceContext*>(this)->eigen_device();
}
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) { CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device); SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_)); PADDLE_ENFORCE(cudaStreamCreate(&stream_));
......
...@@ -27,13 +27,23 @@ limitations under the License. */ ...@@ -27,13 +27,23 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace platform { namespace platform {
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
class DeviceContext { class DeviceContext {
public: public:
virtual ~DeviceContext() {} virtual ~DeviceContext() {}
virtual Place GetPlace() const = 0; virtual Place GetPlace() const = 0;
template <typename DeviceType> template <typename PlaceType,
DeviceType* get_eigen_device() const; typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
DeviceType* GetEigenDevice() const;
virtual void Wait() const {} virtual void Wait() const {}
}; };
...@@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext { ...@@ -52,6 +62,11 @@ class CPUDeviceContext : public DeviceContext {
}; };
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
class EigenCudaStreamDevice; class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext { class CUDADeviceContext : public DeviceContext {
......
...@@ -24,7 +24,7 @@ TEST(Device, Init) { ...@@ -24,7 +24,7 @@ TEST(Device, Init) {
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i)); DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
Eigen::GpuDevice* gpu_device = Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<Eigen::GpuDevice>(); device_context->template GetEigenDevice<GPUPlace>();
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
delete device_context; delete device_context;
} }
......
...@@ -47,7 +47,7 @@ bool is_cpu_place(const Place &p) { ...@@ -47,7 +47,7 @@ bool is_cpu_place(const Place &p) {
} }
bool places_are_same_class(const Place &p1, const Place &p2) { bool places_are_same_class(const Place &p1, const Place &p2) {
return is_gpu_place(p1) == is_gpu_place(p2); return p1.which() == p2.which();
} }
std::ostream &operator<<(std::ostream &os, const Place &p) { std::ostream &operator<<(std::ostream &os, const Place &p) {
......
...@@ -15,6 +15,7 @@ limitations under the License. */ ...@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once #pragma once
#include <iostream> #include <iostream>
#include "paddle/platform/variant.h" #include "paddle/platform/variant.h"
namespace paddle { namespace paddle {
...@@ -46,8 +47,18 @@ struct IsGPUPlace : public boost::static_visitor<bool> { ...@@ -46,8 +47,18 @@ struct IsGPUPlace : public boost::static_visitor<bool> {
bool operator()(const GPUPlace &gpu) const { return true; } bool operator()(const GPUPlace &gpu) const { return true; }
}; };
// Define the max number of Place in bit length. i.e., the max number of places
// should be less equal than 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
#define NUM_PLACE_TYPE_LIMIT_IN_BIT 4
typedef boost::variant<GPUPlace, CPUPlace> Place; typedef boost::variant<GPUPlace, CPUPlace> Place;
// static check number of place types is less equal than
// 2^(NUM_PLACE_TYPE_LIMIT_IN_BIT)
BOOST_MPL_ASSERT((boost::mpl::less_equal<
Place::types::size,
boost::mpl::long_<1 << NUM_PLACE_TYPE_LIMIT_IN_BIT>>));
void set_place(const Place &); void set_place(const Place &);
const Place &get_place(); const Place &get_place();
......
...@@ -29,4 +29,6 @@ ...@@ -29,4 +29,6 @@
#endif #endif
#endif #endif
#include <boost/mpl/comparison.hpp>
#include <boost/mpl/less_equal.hpp>
#include <boost/variant.hpp> #include <boost/variant.hpp>
...@@ -77,20 +77,18 @@ PYBIND11_PLUGIN(core) { ...@@ -77,20 +77,18 @@ PYBIND11_PLUGIN(core) {
}) })
.def("set", PyCPUTensorSetFromArray<float>) .def("set", PyCPUTensorSetFromArray<float>)
.def("set", PyCPUTensorSetFromArray<int>) .def("set", PyCPUTensorSetFromArray<int>)
.def("set", PyCPUTensorSetFromArray<double>)
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
.def("set", PyCUDATensorSetFromArray<float>) .def("set", PyCUDATensorSetFromArray<float>)
.def("set", PyCUDATensorSetFromArray<int>) .def("set", PyCUDATensorSetFromArray<int>)
.def("set", PyCUDATensorSetFromArray<double>)
#endif #endif
.def("shape", [](Tensor &self) { return vectorize(self.dims()); }) .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
.def("set_float_element", .def("set_float_element", TensorSetElement<float>)
[](Tensor &self, size_t offset, float f) { .def("get_float_element", TensorGetElement<float>)
// TODO(yuyang18): Only support GPU now. .def("set_double_element", TensorSetElement<double>)
self.data<float>()[offset] = f; .def("get_double_element", TensorGetElement<double>)
}) .def("dtype", [](Tensor &self) { return ToDataType(self.type()); });
.def("get_float_element", [](Tensor &self, size_t offset) -> float {
// TODO(yuyang18): Only support GPU now.
return self.data<float>()[offset];
});
py::class_<LoDTensor, Tensor>(m, "LoDTensor") py::class_<LoDTensor, Tensor>(m, "LoDTensor")
.def_buffer( .def_buffer(
......
...@@ -42,7 +42,7 @@ template <size_t I, typename... ARGS> ...@@ -42,7 +42,7 @@ template <size_t I, typename... ARGS>
struct CastToPyBufferImpl<true, I, ARGS...> { struct CastToPyBufferImpl<true, I, ARGS...> {
using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type; using CUR_TYPE = typename std::tuple_element<I, std::tuple<ARGS...>>::type;
py::buffer_info operator()(framework::Tensor &tensor) { py::buffer_info operator()(framework::Tensor &tensor) {
if (std::type_index(typeid(CUR_TYPE)) == tensor.holder_->type()) { if (std::type_index(typeid(CUR_TYPE)) == tensor.type()) {
auto dim_vec = framework::vectorize(tensor.dims()); auto dim_vec = framework::vectorize(tensor.dims());
std::vector<size_t> dims_outside; std::vector<size_t> dims_outside;
std::vector<size_t> strides; std::vector<size_t> strides;
...@@ -56,13 +56,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -56,13 +56,13 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
framework::Tensor dst_tensor; framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) { if (paddle::platform::is_gpu_place(tensor.place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace()); dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_cpu_place(tensor.holder_->place())) { } else if (paddle::platform::is_cpu_place(tensor.place())) {
dst_tensor = tensor; dst_tensor = tensor;
} }
return py::buffer_info( return py::buffer_info(
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.holder_->place()), dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.place()),
sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(), sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides); (size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
} else { } else {
...@@ -73,10 +73,23 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -73,10 +73,23 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
}; };
} // namespace details } // namespace details
inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
auto buffer_info = details::CastToPyBufferImpl<true, 0, float, int>()(tensor); auto buffer_info =
details::CastToPyBufferImpl<true, 0, float, int, double>()(tensor);
return buffer_info; return buffer_info;
} }
template <typename T>
T TensorGetElement(framework::Tensor &self, size_t offset) {
PADDLE_ENFORCE(platform::is_cpu_place(self.place()));
return self.data<T>()[offset];
}
template <typename T>
void TensorSetElement(framework::Tensor &self, size_t offset, T elem) {
PADDLE_ENFORCE(platform::is_cpu_place(self.place()));
self.data<T>()[offset] = elem;
}
template <typename T> template <typename T>
void PyCPUTensorSetFromArray( void PyCPUTensorSetFromArray(
framework::Tensor &self, framework::Tensor &self,
......
...@@ -18,7 +18,7 @@ function version(){ ...@@ -18,7 +18,7 @@ function version(){
echo "PaddlePaddle @PADDLE_VERSION@, compiled with" echo "PaddlePaddle @PADDLE_VERSION@, compiled with"
echo " with_avx: @WITH_AVX@" echo " with_avx: @WITH_AVX@"
echo " with_gpu: @WITH_GPU@" echo " with_gpu: @WITH_GPU@"
echo " with_mkldnn: @WITH_MKLDNN" echo " with_mkldnn: @WITH_MKLDNN@"
echo " with_mklml: @WITH_MKLML@" echo " with_mklml: @WITH_MKLML@"
echo " with_double: @WITH_DOUBLE@" echo " with_double: @WITH_DOUBLE@"
echo " with_python: @WITH_PYTHON@" echo " with_python: @WITH_PYTHON@"
......
import unittest import unittest
import numpy as np import numpy as np
import random
import itertools import itertools
import paddle.v2.framework.core as core import paddle.v2.framework.core as core
from paddle.v2.framework.op import Operator from paddle.v2.framework.op import Operator
...@@ -12,17 +13,19 @@ def grad_var_name(var_name): ...@@ -12,17 +13,19 @@ def grad_var_name(var_name):
def create_op(scope, op_type, inputs, outputs, attrs): def create_op(scope, op_type, inputs, outputs, attrs):
kwargs = dict() kwargs = dict()
def __create_var__(name, var_name):
scope.new_var(var_name)
kwargs[name].append(var_name)
for in_name, in_dup in Operator.get_op_inputs(op_type): for in_name, in_dup in Operator.get_op_inputs(op_type):
if in_name in inputs: if in_name in inputs:
kwargs[in_name] = [] kwargs[in_name] = []
if in_dup: if in_dup:
sub_in = inputs[in_name] sub_in = inputs[in_name]
for sub_in_name, _ in sub_in: for sub_in_name, _ in sub_in:
var = scope.new_var(sub_in_name) __create_var__(in_name, sub_in_name)
kwargs[in_name].append(sub_in_name)
else: else:
var = scope.new_var(in_name) __create_var__(in_name, in_name)
kwargs[in_name].append(in_name)
for out_name, out_dup in Operator.get_op_outputs(op_type): for out_name, out_dup in Operator.get_op_outputs(op_type):
if out_name in outputs: if out_name in outputs:
...@@ -30,11 +33,9 @@ def create_op(scope, op_type, inputs, outputs, attrs): ...@@ -30,11 +33,9 @@ def create_op(scope, op_type, inputs, outputs, attrs):
if out_dup: if out_dup:
sub_out = outputs[out_name] sub_out = outputs[out_name]
for sub_out_name, _ in sub_out: for sub_out_name, _ in sub_out:
var = scope.new_var(sub_out_name) __create_var__(out_name, sub_out_name)
kwargs[out_name].append(sub_out_name)
else: else:
var = scope.new_var(out_name) __create_var__(out_name, out_name)
kwargs[out_name].append(out_name)
for attr_name in Operator.get_op_attr_names(op_type): for attr_name in Operator.get_op_attr_names(op_type):
if attr_name in attrs: if attr_name in attrs:
...@@ -44,49 +45,46 @@ def create_op(scope, op_type, inputs, outputs, attrs): ...@@ -44,49 +45,46 @@ def create_op(scope, op_type, inputs, outputs, attrs):
def set_input(scope, op, inputs, place): def set_input(scope, op, inputs, place):
def __set_input__(var_name, var):
tensor = scope.find_var(var_name).get_tensor()
if isinstance(var, tuple):
tensor.set_lod(var[1])
var = var[0]
tensor.set_dims(var.shape)
tensor.set(var, place)
for in_name, in_dup in Operator.get_op_inputs(op.type()): for in_name, in_dup in Operator.get_op_inputs(op.type()):
if in_name in inputs: if in_name in inputs:
if in_dup: if in_dup:
sub_in = inputs[in_name] sub_in = inputs[in_name]
for sub_in_name, sub_in_val in sub_in: for sub_in_name, sub_in_val in sub_in:
var = scope.find_var(sub_in_name) __set_input__(sub_in_name, sub_in_val)
tensor = var.get_tensor()
sub_in_array = sub_in_val[0] \
if isinstance(sub_in_val, tuple) else sub_in_val
tensor.set_dims(sub_in_array.shape)
tensor.set(sub_in_array, place)
if isinstance(sub_in_val, tuple):
tensor.set_lod(sub_in_val[1])
else: else:
var = scope.find_var(in_name) __set_input__(in_name, inputs[in_name])
tensor = var.get_tensor()
in_val = inputs[in_name]
in_array = in_val[0] if isinstance(in_val, tuple) else in_val
tensor.set_dims(in_array.shape)
tensor.set(in_array, place)
if isinstance(in_val, tuple):
tensor.set_lod(in_val[1])
def set_output_grad(scope, op, outputs, place): def set_output_grad(scope, op, outputs, place):
def __set_tensor__(name):
out_tensor = scope.find_var(name).get_tensor()
grad_tensor = scope.new_var(grad_var_name(name)).get_tensor()
out_dtype = out_tensor.dtype()
if out_dtype == core.DataType.FP64:
data = np.ones(out_tensor.shape(), dtype=np.float64)
elif out_dtype == core.DataType.FP32:
data = np.ones(out_tensor.shape(), dtype=np.float32)
else:
raise ValueError("Not supported data type " + str(out_dtype))
grad_tensor.set(data, place)
for out_name, out_dup in Operator.get_op_outputs(op.type()): for out_name, out_dup in Operator.get_op_outputs(op.type()):
if out_name in outputs: if out_name in outputs:
if out_dup: if out_dup:
sub_out = outputs[out_name] sub_out = outputs[out_name]
for sub_out_name, _ in sub_out: for sub_out_name, _ in sub_out:
out_tensor = scope.find_var(sub_out_name).get_tensor() __set_tensor__(sub_out_name)
grad_tensor = scope.new_var(grad_var_name(
sub_out_name)).get_tensor()
grad_tensor.set_dims(out_tensor.shape())
data = np.ones(out_tensor.shape(), dtype=np.float32)
grad_tensor.set(data, place)
else: else:
out_tensor = scope.find_var(out_name).get_tensor() __set_tensor__(out_name)
grad_tensor = scope.new_var(grad_var_name(out_name)).get_tensor(
)
grad_tensor.set_dims(out_tensor.shape())
data = np.ones(out_tensor.shape(), dtype=np.float32)
grad_tensor.set(data, place)
def get_numeric_gradient(scope, def get_numeric_gradient(scope,
...@@ -96,7 +94,6 @@ def get_numeric_gradient(scope, ...@@ -96,7 +94,6 @@ def get_numeric_gradient(scope,
output_names, output_names,
delta=0.005, delta=0.005,
in_place=False): in_place=False):
set_input(scope, op, inputs, core.CPUPlace()) set_input(scope, op, inputs, core.CPUPlace())
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
...@@ -115,7 +112,29 @@ def get_numeric_gradient(scope, ...@@ -115,7 +112,29 @@ def get_numeric_gradient(scope,
tensor_to_check = scope.find_var(input_to_check).get_tensor() tensor_to_check = scope.find_var(input_to_check).get_tensor()
tensor_size = product(tensor_to_check.get_dims()) tensor_size = product(tensor_to_check.get_dims())
gradient_flat = np.zeros(shape=(tensor_size, ), dtype='float32') tensor_to_check_dtype = tensor_to_check.dtype()
if tensor_to_check_dtype == core.DataType.FP32:
tensor_to_check_dtype = np.float32
elif tensor_to_check_dtype == core.DataType.FP64:
tensor_to_check_dtype = np.float64
else:
raise ValueError("Not supported data type " + str(
tensor_to_check_dtype))
gradient_flat = np.zeros(shape=(tensor_size, ), dtype=tensor_to_check_dtype)
def __get_elem__(tensor, i):
if tensor_to_check_dtype == np.float32:
return tensor.get_float_element(i)
else:
return tensor.get_double_element(i)
def __set_elem__(tensor, i, e):
if tensor_to_check_dtype == np.float32:
tensor.set_float_element(i, e)
else:
tensor.set_double_element(i, e)
# we only compute gradient of one element each time. # we only compute gradient of one element each time.
# we use a for loop to compute the gradient of every element. # we use a for loop to compute the gradient of every element.
for i in xrange(tensor_size): for i in xrange(tensor_size):
...@@ -123,20 +142,20 @@ def get_numeric_gradient(scope, ...@@ -123,20 +142,20 @@ def get_numeric_gradient(scope,
set_input(scope, op, inputs, core.CPUPlace()) set_input(scope, op, inputs, core.CPUPlace())
# get one input element throw it's index i. # get one input element throw it's index i.
origin = tensor_to_check.get_float_element(i) origin = __get_elem__(tensor_to_check, i)
# add delta to it, run op and then get the sum of the result tensor. # add delta to it, run op and then get the sum of the result tensor.
x_pos = origin + delta x_pos = origin + delta
tensor_to_check.set_float_element(i, x_pos) __set_elem__(tensor_to_check, i, x_pos)
y_pos = get_output() y_pos = get_output()
if in_place: if in_place:
set_input(scope, op, inputs, core.CPUPlace()) set_input(scope, op, inputs, core.CPUPlace())
x_neg = origin - delta x_neg = origin - delta
tensor_to_check.set_float_element(i, x_neg) __set_elem__(tensor_to_check, i, x_neg)
y_neg = get_output() y_neg = get_output()
tensor_to_check.set_float_element(i, origin) __set_elem__(tensor_to_check, i, origin)
gradient_flat[i] = (y_pos - y_neg) / delta / 2 gradient_flat[i] = (y_pos - y_neg) / delta / 2
return gradient_flat.reshape(tensor_to_check.get_dims()) return gradient_flat.reshape(tensor_to_check.get_dims())
...@@ -174,6 +193,21 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place, ...@@ -174,6 +193,21 @@ def get_gradient(scope, op, inputs, outputs, grad_name, place,
class OpTest(unittest.TestCase): class OpTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
'''Fix random seeds to remove randomness from tests'''
cls._np_rand_state = np.random.get_state()
cls._py_rand_state = random.getstate()
np.random.seed(123)
random.seed(124)
@classmethod
def tearDownClass(cls):
'''Restore random seeds'''
np.random.set_state(cls._np_rand_state)
random.setstate(cls._py_rand_state)
def check_output_with_place(self, place, atol): def check_output_with_place(self, place, atol):
self.scope = core.Scope() self.scope = core.Scope()
op_inputs = self.inputs if hasattr(self, "inputs") else dict() op_inputs = self.inputs if hasattr(self, "inputs") else dict()
......
...@@ -219,5 +219,22 @@ class TestSTanh(OpTest): ...@@ -219,5 +219,22 @@ class TestSTanh(OpTest):
self.check_grad(['X'], 'Y', max_relative_error=0.007) self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSoftsign(OpTest):
def setUp(self):
self.op_type = "softsign"
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
self.outputs = {
'Y': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X']))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -80,7 +80,7 @@ class TestCrossEntropyOp3(OpTest): ...@@ -80,7 +80,7 @@ class TestCrossEntropyOp3(OpTest):
cross_entropy2 = (-label * np.log(X)).sum( cross_entropy2 = (-label * np.log(X)).sum(
axis=1, keepdims=True).astype("float32") axis=1, keepdims=True).astype("float32")
self.inputs = {"X": X, "Label": label} self.inputs = {"X": X, "Label": label.astype(np.float32)}
self.outputs = {"Y": cross_entropy} self.outputs = {"Y": cross_entropy}
self.attrs = {"softLabel": True} self.attrs = {"softLabel": True}
......
...@@ -7,8 +7,8 @@ class ElementwiseMulOp(OpTest): ...@@ -7,8 +7,8 @@ class ElementwiseMulOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
self.inputs = { self.inputs = {
'X': np.random.uniform(0.1, 1, [13, 17]).astype("float32"), 'X': np.random.uniform(0.1, 1, [13, 17]).astype("float64"),
'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float32") 'Y': np.random.uniform(0.1, 1, [13, 17]).astype("float64")
} }
self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])}
...@@ -16,23 +16,21 @@ class ElementwiseMulOp(OpTest): ...@@ -16,23 +16,21 @@ class ElementwiseMulOp(OpTest):
self.check_output() self.check_output()
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1) self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self): def test_check_grad_ingore_x(self):
self.check_grad( self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
def test_check_grad_ingore_y(self): def test_check_grad_ingore_y(self):
self.check_grad( self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
class TestElementwiseMulOp_Vector(ElementwiseMulOp): class TestElementwiseMulOp_Vector(ElementwiseMulOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
self.inputs = { self.inputs = {
'X': np.random.random((32, )).astype("float32"), 'X': np.random.random((32, )).astype("float64"),
'Y': np.random.random((32, )).astype("float32") 'Y': np.random.random((32, )).astype("float64")
} }
self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])}
...@@ -41,8 +39,8 @@ class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp): ...@@ -41,8 +39,8 @@ class TestElementwiseMulOp_broadcast_0(ElementwiseMulOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
self.inputs = { self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32), 'X': np.random.rand(2, 3, 4).astype(np.float64),
'Y': np.random.rand(2).astype(np.float32) 'Y': np.random.rand(2).astype(np.float64)
} }
self.attrs = {'axis': 0} self.attrs = {'axis': 0}
...@@ -55,8 +53,8 @@ class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp): ...@@ -55,8 +53,8 @@ class TestElementwiseMulOp_broadcast_1(ElementwiseMulOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
self.inputs = { self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32), 'X': np.random.rand(2, 3, 4).astype(np.float64),
'Y': np.random.rand(3).astype(np.float32) 'Y': np.random.rand(3).astype(np.float64)
} }
self.attrs = {'axis': 1} self.attrs = {'axis': 1}
...@@ -69,8 +67,8 @@ class TestElementwiseMulOp_broadcast_2(ElementwiseMulOp): ...@@ -69,8 +67,8 @@ class TestElementwiseMulOp_broadcast_2(ElementwiseMulOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
self.inputs = { self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32), 'X': np.random.rand(2, 3, 4).astype(np.float64),
'Y': np.random.rand(4).astype(np.float32) 'Y': np.random.rand(4).astype(np.float64)
} }
self.outputs = { self.outputs = {
...@@ -82,8 +80,8 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp): ...@@ -82,8 +80,8 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
def setUp(self): def setUp(self):
self.op_type = "elementwise_mul" self.op_type = "elementwise_mul"
self.inputs = { self.inputs = {
'X': np.random.rand(2, 3, 4, 5).astype(np.float32), 'X': np.random.rand(2, 3, 4, 5).astype(np.float64),
'Y': np.random.rand(3, 4).astype(np.float32) 'Y': np.random.rand(3, 4).astype(np.float64)
} }
self.attrs = {'axis': 1} self.attrs = {'axis': 1}
......
...@@ -17,7 +17,7 @@ class PReluTest(OpTest): ...@@ -17,7 +17,7 @@ class PReluTest(OpTest):
x_np_sign = np.sign(x_np) x_np_sign = np.sign(x_np)
x_np = x_np_sign * np.maximum(x_np, .005) x_np = x_np_sign * np.maximum(x_np, .005)
alpha_np = np.array([.1]) alpha_np = np.array([.1], dtype="float32")
self.inputs = {'X': x_np, 'Alpha': alpha_np} self.inputs = {'X': x_np, 'Alpha': alpha_np}
out_np = np.maximum(self.inputs['X'], 0.) out_np = np.maximum(self.inputs['X'], 0.)
out_np = out_np + np.minimum(self.inputs['X'], out_np = out_np + np.minimum(self.inputs['X'],
......
import unittest
import numpy as np
from op_test import OpTest
class TestRowwiseAddOp(OpTest):
def setUp(self):
self.op_type = "rowwise_add"
self.inputs = {
'X': np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
'b': np.random.uniform(0.1, 1, [10]).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'b'], 'Out')
def test_check_grad_ingore_b(self):
self.check_grad(['X'], 'Out', no_grad_set=set('b'))
def test_check_grad_ingore_x(self):
self.check_grad(['b'], 'Out', no_grad_set=set('X'))
class TestRowwiseAddOp2(OpTest):
def setUp(self):
self.op_type = "rowwise_add"
self.inputs = {
'X': np.random.uniform(0.1, 1, [2, 3, 2, 5]).astype("float32"),
'b': np.random.uniform(0.1, 1, [2, 5]).astype("float32")
}
self.outputs = {'Out': np.add(self.inputs['X'], self.inputs['b'])}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'b'], 'Out')
def test_check_grad_ignore_b(self):
self.check_grad(['X'], 'Out', no_grad_set=set('b'))
def test_check_grad_ignore_x(self):
self.check_grad(['b'], 'Out', no_grad_set=set('X'))
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
def stable_softmax(x):
"""Compute the softmax of vector x in a numerically stable way."""
shiftx = x - np.max(x).clip(-64.)
exps = np.exp(shiftx)
return exps / np.sum(exps)
class TestSequenceSoftmaxOp(OpTest):
def setUp(self):
self.op_type = "sequence_softmax"
x = np.random.uniform(0.1, 1, (11, 1)).astype("float32")
lod = [[0, 4, 5, 8, 11]]
out = np.zeros((11, 1)).astype("float32")
for i in range(4):
sub_x = x[lod[0][i]:lod[0][i + 1], :]
sub_x = sub_x.reshape(1, lod[0][i + 1] - lod[0][i])
sub_out = stable_softmax(sub_x)
out[lod[0][i]:lod[0][i + 1], :] = sub_out.reshape(
lod[0][i + 1] - lod[0][i], 1)
self.inputs = {"X": (x, lod)}
self.outputs = {"Out": out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out", max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
...@@ -43,7 +43,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest): ...@@ -43,7 +43,7 @@ class TestSoftmaxWithCrossEntropyOp2(OpTest):
def setUp(self): def setUp(self):
self.op_type = "softmax_with_cross_entropy" self.op_type = "softmax_with_cross_entropy"
batch_size = 2 batch_size = 2
class_num = 17 class_num = 37
logits = np.random.uniform(0.1, 1.0, logits = np.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
......
...@@ -96,6 +96,9 @@ class Inference(object): ...@@ -96,6 +96,9 @@ class Inference(object):
for i, item in enumerate(result): for i, item in enumerate(result):
retv[i].append(item) retv[i].append(item)
if retv == None:
return []
if flatten_result: if flatten_result:
retv = [numpy.concatenate(out) for out in retv] retv = [numpy.concatenate(out) for out in retv]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册