提交 df598899 编写于 作者: C chengduoZH

remove conflict

......@@ -51,19 +51,19 @@ Please refer to our [release announcement](https://github.com/PaddlePaddle/Paddl
- **Connected to Products**
In addition, PaddlePaddle is also designed to be easily deployable. At Baidu,
PaddlePaddle has been deployed into products or service with a vast number
PaddlePaddle has been deployed into products and services with a vast number
of users, including ad click-through rate (CTR) prediction, large-scale image
classification, optical character recognition(OCR), search ranking, computer
virus detection, recommendation, etc. It is widely utilized in products at
Baidu and it has achieved a significant impact. We hope you can also exploit
the capability of PaddlePaddle to make a huge impact for your product.
Baidu and it has achieved a significant impact. We hope you can also explore
the capability of PaddlePaddle to make an impact on your product.
## Installation
It is recommended to check out the
[Docker installation guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/docker_install_en.html)
before looking into the
[build from source guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html)
[build from source guide](http://doc.paddlepaddle.org/develop/doc/getstarted/build_and_install/build_from_source_en.html).
## Documentation
......@@ -72,7 +72,7 @@ We provide [English](http://doc.paddlepaddle.org/develop/doc/) and
- [Deep Learning 101](http://book.paddlepaddle.org/index.html)
You might want to start from this online interactive book that can run in Jupyter Notebook.
You might want to start from this online interactive book that can run in a Jupyter Notebook.
- [Distributed Training](http://doc.paddlepaddle.org/develop/doc/howto/usage/cluster/cluster_train_en.html)
......
set -e
unset OMP_NUM_THREADS MKL_NUM_THREADS
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"
function train() {
unset OMP_NUM_THREADS MKL_NUM_THREADS
export OMP_DYNAMIC="FALSE"
export KMP_AFFINITY="granularity=fine,compact,0,0"
topology=$1
bs=$2
use_mkldnn=$3
......
# Design Doc: Refactorization Overview
The goal of refactorizaiton include:
The goals of refactoring include:
1. Make it easy for external contributors to write new elementory computaiton operations.
1. Make the codebase clean and readable.
1. Introduce a new design of computation representation -- a computation graph of operators and variables.
1. The graph representation helps implementing auto-scalable and auto fault recoverable distributed computing.
1. Making it easy for external contributors to write new elementary computation operations.
1. Making the codebase clean and readable.
1. Designing a new computation representation -- a computation graph of operators and variables.
1. Implementing auto-scalability and auto fault recoverable distributed computing with the help of computation graphs.
## Computation Graphs
1. PaddlePaddle represent the computation, training and inference of DL models, by computation graphs.
1. PaddlePaddle represents the computation, training and inference of Deep Learning models, by computation graphs.
1. Please dig into [computation graphs](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/graph.md) for a solid example.
1. Please refer to [computation graphs](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/graph.md) for a concrete example.
1. Users write Python programs to describe the graphs and run it (locally or remotely).
1. Users write Python programs to describe the graphs and run them (locally or remotely).
1. A graph is composed of *variables* and *operators*.
1. The description of graphs must be able to be serialized/deserialized, so it
1. The description of graphs must be capable of being serialized/deserialized, so that
1. could to be sent to the cloud for distributed execution, and
1. be sent to clients for mobile or enterprise deployment.
1. It can to be sent to the cloud for distributed execution, and
1. It can be sent to clients for mobile or enterprise deployment.
1. The Python program do
1. The Python program does the following steps
1. *compilation*: runs a Python program to generate a protobuf message representation of the graph and send it to
1. *compilation*: run a Python program to generate a protobuf message representation of the graph and send it to
1. the C++ library `libpaddle.so` for local execution,
1. the master process of a distributed training job for training, or
1. the server process of a Kubernetes serving job for distributed serving.
1. *execution*: according to the protobuf message, constructs instances of class `Variable` and `OperatorBase`, and run them.
1. *execution*: execute the graph by constructing instances of class [`Variable`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/variable.h#L24) and [`OperatorBase`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L70), according to the protobuf message.
## Description and Realization
## Description and Realization of Computation Graph
At compile time, the Python program generates protobuf message representation of the graph, or the description of the graph.
At compile time, the Python program generates a protobuf message representation of the graph, or the description of the graph.
At runtime, the C++ program realizes the graph and run it.
At runtime, the C++ program realizes the graph and runs it.
| | Representation (protobuf messages) | Realization (C++ class objects) |
|---|---|---|
......@@ -42,30 +42,31 @@ At runtime, the C++ program realizes the graph and run it.
|Operation|[OpDesc](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/framework.proto#L35)|[Operator](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/operator.h#L64)|
|Block|BlockDesc|Block|
The word *graph* is exchangable with *block* in this document. A graph represent computation steps and local variables as a C++/Java program block, or a pair of { and }.
The word *graph* is interchangeable with *block* in this document. A graph represents computation steps and local variables similar to a C++/Java program block, or a pair of parentheses(`{` and `}`).
## Compilation and Execution
1. Run an applicaton Python program to describe the graph. In particular,
1. Run an application Python program to describe the graph. In particular, the Python application program does the following:
1. create VarDesc to represent local/intermediate variables,
1. create operators and set attributes,
1. validate attribute values,
1. inference the type and the shape of variables,
1. plan for memory-reuse for variables,
1. generate backward and optimization part of the Graph.
1. possiblly split the graph for distributed training.
1. Create `VarDesc` to represent local/intermediate variables,
1. Create operators and set attributes,
1. Validate attribute values,
1. Infer the type and the shape of variables,
1. Plan memory-reuse for variables,
1. Generate the backward graph
1. Optimize the computation graph.
1. Potentially, split the graph for distributed training.
1. The invocation of `train` or `infer` in the application Python program:
1. The invocation of `train` or [`infer`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/inference.py#L108) methods in the application Python program does the following:
1. create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block,
1. Create a new Scope instance in the [scope hierarchy](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/design/scope.md) for each run of a block,
1. realize local variables defined in the BlockDesc message in the new scope,
1. a scope is similar to the stack frame in programming languages,
1. create an instance of class `Block`, in which,
1. Create an instance of class `Block`, in which,
1. realize operators in the BlockDesc message,
1. run the Block by calling
1. Run the Block by calling
1. `Block::Eval(vector<Variable>* targets)` for forward and backward computations, or
1. `Block::Eval(vector<Operator>* targets)` for optimization.
......@@ -76,14 +77,14 @@ The word *graph* is exchangable with *block* in this document. A graph represen
Compile Time -> IR -> Runtime
```
### Benefit
### Benefits of IR
- Optimization
```text
Compile Time -> IR -> Optimized IR -> Runtime
```
- Send automatically partitioned IR to different nodes.
- Automatic data parallel
- Automatically send partitioned IR to different nodes.
- Automatic Data Parallelism
```text
Compile Time
|-> Single GPU IR
......@@ -92,7 +93,7 @@ Compile Time -> IR -> Runtime
|-> Node-1 (runs trainer-IR-1)
|-> Node-2 (runs pserver-IR)
```
- Automatic model parallel (planned for future)
- Automatic Model Parallelism (planned for future)
---
......@@ -105,10 +106,10 @@ Compile Time -> IR -> Runtime
# Operator
![class_diagram](http://api.paddlepaddle.org/graphviz?dot=https://gist.githubusercontent.com/reyoung/53df507f6749762675dff3e7ce53372f/raw/dd598e8f1976f5759f58af5e5ef94738a6b2e661/op.dot)
* `Operator` is the fundamental building block as the user interface.
* Operator stores input/output variable name, and attributes.
* The `InferShape` interface is used to infer output variable shapes by its input shapes.
* Use `Run` to compute `input variables` to `output variables`.
* `Operator` is the fundamental building block of the user interface.
* Operator stores input/output variable names, and attributes.
* The `InferShape` interface is used to infer the shape of the output variable shapes based on the shapes of the input variables.
* Use `Run` to compute the `output` variables from the `input` variables.
---
......@@ -126,30 +127,30 @@ Compile Time -> IR -> Runtime
# Why separate Kernel and Operator
* Separate GPU and CPU code.
* Make Paddle can run without GPU.
* Make one operator (which is user interface) can contain many implementations.
* Same mul op, different FP16, FP32 Kernel. different MKL, eigen kernel.
* Make Paddle capable of running without GPU.
* Make one operator (which is a user interface) and create many implementations.
* For example, same multiplication op can have different implementations kernels such as FP16 kernel, FP32 kernel, MKL, eigen kernel.
---
# Libraries for Kernel development
* `Eigen::Tensor` contains basic math and element-wise functions.
* Note that `Eigen::Tensor` has broadcast implementation.
* Limit number of `tensor.device(dev) = ` in your code.
* Limit the number of `tensor.device(dev) = ` in your code.
* `thrust::tranform` and `std::transform`.
* `thrust` has the same API as C++ standard library. Using `transform` can quickly implement a customized elementwise kernel.
* `thrust` has more complex API, like `scan`, `reduce`, `reduce_by_key`.
* `thrust` has the same API as C++ standard library. Using `transform`, one can quickly implement customized elementwise kernels.
* `thrust` also has more complex APIs, like `scan`, `reduce`, `reduce_by_key`.
* Hand-writing `GPUKernel` and `CPU` code
* Do not write `.h`. CPU Kernel should be in `.cc`. GPU kernel should be in `.cu`. (`GCC` cannot compile GPU code.)
* Do not write in header (`.h`) files. CPU Kernel should be in cpp source (`.cc`) and GPU kernels should be in cuda (`.cu`) files. (GCC cannot compile GPU code.)
---
# Operator Register
# Operator Registration
## Why register is necessary?
## Why registration is necessary?
We need a method to build mappings between Op type names and Op classes.
## How to do the register?
## How is registration implemented?
Maintain a map, whose key is the type name and value is corresponding Op constructor.
Maintaining a map, whose key is the type name and the value is the corresponding Op constructor.
---
# The Registry Map
......@@ -177,34 +178,34 @@ REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type, grad_op_class)
REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class)
```
### `USE` Macros
make sure the registration process is executed and linked.
### USE Macros
Make sure the registration process is executed and linked.
---
# Register Process
1. Write Op class, as well as its gradient Op class if there is.
2. Write Op maker class. In the constructor, describe its inputs, outputs, and attributes.
3. Invoke macro `REGISTER_OP`. The macro will
1. call maker class to complete `proto` and `checker`
2. with the completed `proto` and `checker`, build a new key-value pair in the `OpInfoMap`
# Registration Process
1. Write an Op class and its gradient Op class, if required.
2. Write an Op maker class. In the constructor of this class, describe the inputs, outputs and attributes of the operator.
3. Invoke the macro `REGISTER_OP`. This macro will
1. Call maker class to complete the `proto` and the `checker`
2. Using the completed `proto` and `checker`, it will add a new key-value pair to the `OpInfoMap`
4. Invoke `USE` macro in where the Op is used to make sure it is linked.
4. Invoke the `USE` macro in which the Op is used, to make sure that it is linked.
---
# Backward Module (1/2)
### Create Backward Operator
- Mapping from forwarding Op to backward Op
- Mapping from forward Op to backward Op
![backward](https://gist.githubusercontent.com/dzhwinter/a6fbd4623ee76c459f7f94591fd1abf0/raw/61026ab6e518e66bde66a889bc42557a1fccff33/backward.png)
---
# Backward Module (2/2)
### Build Backward Network
- **Input** graph of forwarding operators
- **Output** graph of backward operators
- **corner case in construction**
- shared variable => insert `Add` operator
- no gradient => insert `fill_zero_grad` operator
- recursive netOp => call `Backward` recursively
- **Input**: graph of forwarding operators
- **Output**: graph of backward operators
- **Corner cases in construction**
- Shared Variables => insert an `Add` operator to combine gradients
- No Gradient => insert a `fill_zero_grad` operator
- Recursive NetOp => call `Backward` recursively
- RNN Op => recursively call `Backward` on stepnet
......@@ -213,41 +214,41 @@ make sure the registration process is executed and linked.
* `Tensor` is an n-dimension array with type.
* Only dims and data pointers are stored in `Tensor`.
* All operators on `Tensor` is written in `Operator` or global functions.
* variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
* `Variable` is the inputs and outputs of an operator. Not just `Tensor`.
* step_scopes in RNN is a variable and not a tensor.
* `Scope` is where variables store at.
* map<string/*var name */, Variable>
* `Scope` has a hierarchical structure. The local scope can get variable from its parent scope.
* All operations on `Tensor` are written in `Operator` or global functions.
* Variable length Tensor design [LoDTensor](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/lod_tensor.md)
* `Variable` instances are the inputs and the outputs of an operator. Not just `Tensor`.
* `step_scopes` in RNN is a variable and not a tensor.
* `Scope` is where variables are stores.
* map<string `variable_name`, Variable>
* `Scope` has a hierarchical structure. The local scope can get variables from its parent scope.
---
# Block (in design)
## the difference with original RNNOp
- as an operator is more intuitive than `RNNOp`,
- offers new interface `Eval(targets)` to deduce the minimal block to `Run`,
- fits the compile-time/ runtime separation design.
- during the compilation, `SymbolTable` stores `VarDesc`s and `OpDesc`s and serialize to a `BlockDesc`
- when graph executes, a Block with `BlockDesc` passed in creates `Op` and `Var` then `Run`
- As an operator is more intuitive than `RNNOp`,
- Offers a new interface `Eval(targets)` to deduce the minimal block to `Run`,
- Fits the compile-time/ runtime separation design paradigm.
- During the compilation, `SymbolTable` stores `VarDesc`s and `OpDesc`s and serialize to a `BlockDesc`
- When graph executes, a Block with `BlockDesc` is passed. It then creates `Op` and `Var` instances and then invokes `Run`.
---
# Milestone
- take Paddle/books as the main line, the requirement of the models motivates framework refactoring,
- model migration
- framework development gives **priority support** to model migration, for example,
- Take Paddle/books as the main line, the requirement of the models motivates framework refactoring,
- Model migration
- Framework development gives **priority support** to model migration, for example,
- the MNIST demo needs a Python interface,
- the RNN models require the framework to support `LoDTensor`.
- determine some timelines,
- heavily-relied Ops need to be migrated first,
- different models can be migrated parallelly.
- improve the framework at the same time
- accept imperfection, concentrated on solving the specific problem at the right price.
- Determine some timelines,
- Frequently used Ops need to be migrated first,
- Different models can be migrated in parallel.
- Improve the framework at the same time
- Accept imperfection, concentrate on solving the specific problem at the right price.
---
# Control the migration quality
- compare the performance of migrated models with old ones.
- follow google C style
- build the automatic workflow of generating Python/C++ documentations
- the documentation of layers and ops should be written inside the code
- take the documentation quality into account when doing PR
- preview the documentations, read and improve them from users' perspective
- Compare the performance of migrated models with old ones.
- Follow the google C++ style
- Build the automatic workflow of generating Python/C++ documentations.
- The documentation of layers and ops should be written inside the code.
- Take the documentation quality into account when submitting pull requests.
- Preview the documentations, read and improve them from a user's perspective.
# Design for TensorArray
TensorArray as a new concept is borrowed from TensorFlow,
it is meant to be used with dynamic iteration primitives such as `while_loop` and `map_fn`.
This concept can be used to support our new design of dynamic operations, and help to refactor some existing variant-sentence-related layers,
such as `RecurrentGradientMachine`.
In [our design for dynamic RNN](https://github.com/PaddlePaddle/Paddle/pull/4401),
`TensorArray` is used to segment inputs and store states in all time steps.
By providing some methods similar to a C++ array,
the definition of some state-based dynamic models such as RNN could be more natural and highly flexible.
## Dynamic-Related Methods
Some basic methods should be proposed as follows:
### stack()
Pack the values in a `TensorArray` into a tensor with rank one higher than each tensor in `values`.
### unstack(axis=0)
Unpacks the given dimension of a rank-`R` tensor into rank-`(R-1)` tensors.
### concat()
Return the values in the `TensorArray` as a concatenated Tensor.
### write(index, value, data_shared=true)
Write value into index of the TensorArray.
### read(index)
Read the value at location `index` in the `TensorArray`.
### size()
Return the number of values.
## LoDTensor-related Supports
The `RecurrentGradientMachine` in Paddle serves as a flexible RNN layer; it takes variant length sequences as input,
because each step of RNN could only take a tensor-represented batch of data as input,
some preprocess should be taken on the inputs such as sorting the sentences by their length in descending order and cut each word and pack to new batches.
Such cut-like operations can be embedded into `TensorArray` as general methods called `unpack` and `pack`.
With these two methods, a variant-sentence-RNN can be implemented like
```c++
// input is the varient-length data
LodTensor sentence_input(xxx);
TensorArray ta;
Tensor indice_map;
Tensor boot_state = xxx; // to initialize rnn's first state
TensorArray::unpack(input, 1/*level*/, true/*sort_by_length*/, &ta, &indice_map);
TessorArray step_outputs;
TensorArray states;
for (int step = 0; step = ta.size(); step++) {
auto state = states.read(step);
// rnnstep is a function which acts like a step of RNN
auto step_input = ta.read(step);
auto step_output = rnnstep(step_input, state);
step_outputs.write(step_output, true/*data_shared*/);
}
// rnn_output is the final output of an rnn
LoDTensor rnn_output = ta.pack(ta, indice_map);
```
the code above shows that by embedding the LoDTensor-related preprocess operations into `TensorArray`,
the implementation of a RNN that supports varient-length sentences is far more concise than `RecurrentGradientMachine` because the latter mixes all the codes together, hard to read and extend.
some details are as follows.
### unpack(level, sort_by_length)
Split LodTensor in some `level` and generate batches, if set `sort_by_length`, will sort by length.
Returns:
- a new `TensorArray`, whose values are LodTensors and represents batches of data.
- an int32 Tensor, which stores the map from the new batch's indices to original LoDTensor
### pack(level, indices_map)
Recover the original LoD-arranged LoDTensor with the values in a `TensorArray` and `level` and `indices_map`.
......@@ -182,7 +182,7 @@ Note that **different devices (CPU, GPU)share an Op definition; whether or not t
`MulOp`'s CPU and GPU share the same `Kernel`. A non-sharing `OpKernel` example can be seen in [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
To ease the writing of `OpKernel` compute, and for reusing code cross-device, `Eigen unsupported Tensor` module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md).
To ease the writing of `OpKernel` compute, and for reusing code cross-device, [`Eigen-unsupported Tensor`](https://bitbucket.org/eigen/eigen/src/default/unsupported/Eigen/CXX11/src/Tensor/README.md?fileviewer=file-view-default) module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md).
This concludes the forward implementation of an operator. Next its operation and kernel need to be registered in a `.cc` file.
......
## How to use Eigen in Paddle
Essentially, a neural network is a compute graph. T data needed for the computation is stored in `Tensor`s and its computation procedure is described by `Operator`s. An `Operator` calls the `Compute` interface in its corresponding `OpKernel` and operates on the `Tensor`.
### Eigen Tensor Module
The Eigen Tensor module supports powerful element-wise computation. In addition, a piece of code written using it can be run on both the CPU and the GPU.
Note that Eigen Tensor is still being actively developed, so its tests are not completely covered and its documentation may be sparse.
For details on Eigen Tensor module, please see [doc 1](https://github.com/RLovelett/eigen/blob/master/unsupported/Eigen/CXX11/src/Tensor/README.md) and [doc 2](https://bitbucket.org/eigen/eigen/src/default/unsupported/Eigen/CXX11/src/Tensor/README.md).
### paddle::framework::Tensor
Paddle Tensor's is defined in the framework directory with the following interface:
```cpp
class Tensor {
public:
/*! Return a pointer to mutable memory block. */
template <typename T>
inline T* data();
/**
* @brief Return a pointer to mutable memory block.
* @note If not exist, then allocation.
*/
template <typename T>
inline T* mutable_data(platform::Place place);
/**
* @brief Return a pointer to mutable memory block.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
*
* @note If not exist, then allocation.
*/
template <typename T>
inline T* mutable_data(DDim dims, platform::Place place);
/*! Resize the dimensions of the memory block. */
inline Tensor& Resize(const DDim& dims);
/*! Return the dimensions of the memory block. */
inline const DDim& dims() const;
private:
/*! holds the memory block if allocated. */
std::shared_ptr<Placeholder> holder_;
/*! points to dimensions of memory block. */
DDim dim_;
};
```
`Placeholder` is used to delay memory allocation; that is, we can first define a tensor, using `Resize` to configure its shape, and then call `mutuable_data` to allocate the actual memory.
```cpp
paddle::framework::Tensor t;
paddle::platform::CPUPlace place;
// set size first
t.Resize({2, 3});
// allocate memory on CPU later
t.mutable_data(place);
```
### paddle::framework::Tensor Usage
`AddOp` demonstrates Tensor's usage.
- InferShape
When computing a neural network's compute graph, first call every `Operator`'s `InferShape` method, and use `Resize` to configure the size of the output tensor.
```cpp
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
ctx.Input<Tensor>("Y")->dims(),
"Two input of Add Op's dimension must be same.");
ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
}
```
- Run
```cpp
void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X");
auto* input1 = context.Input<Tensor>("Y");
auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
auto x = EigenVector<T>::Flatten(*input0);
auto y = EigenVector<T>::Flatten(*input1);
auto z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
z.device(place) = x + y;
}
```
### paddle::framework::Tensor到EigenTensor的转换
As shown above, in actual computation, we need to transform the input and output `Tensor`s into formats Eigen supports. We show some functions in [eigen.h](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen.h) to implement the transformation from `paddle::framework::Tensor`to `EigenTensor/EigenMatrix/EigenVector/EigenScalar`.
Using EigenTensor as an example:
```cpp
Tensor t;
float* p = t.mutable_data<float>(make_ddim({1, 2, 3}), platform::CPUPlace());
for (int i = 0; i < 1 * 2 * 3; i++) {
p[i] = static_cast<float>(i);
}
EigenTensor<float, 3>::Type et = EigenTensor<float, 3>::From(t);
```
`From` is an interfacing method provided by the EigenTensor template, which implements the transformation from a `paddle::framework::Tensor` object to an EigenTensor. Since `rank` is a template parameter, it needs to be explicitly specified at the time of the transformation.
In Eigen, tensors with different ranks are different types, with `Vector` bring a rank-1 instance. Note that `EigenVector<T>::From` uses a transformation from an 1-dimensional Paddle tensor to a 1-dimensional Eigen tensor while `EigenVector<T>::Flatten` reshapes a paddle tensor and flattens it into a 1-dimensional Eigen tensor. Both resulting tensors are still typed EigenVector.
For more transformations, see the [unit tests](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/framework/eigen_test.cc) in the `eigen_test.cc` file.
### Implementing Computation
While computing, the device interface is needed from the EigenTensors on the left hand side of the assignments. Note that the computation between EigenTensors only changes the data originally inthe Tensor and does not change all the shape information associated with the Tensor.
```cpp
auto x = EigenVector<T>::Flatten(*input0);
auto y = EigenVector<T>::Flatten(*input1);
auto z = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
z.device(place) = x + y;
```
In this code segment, input0/input1/output can be Tensors of arbitrary dimension. We are calling Flatten from EigenVector, transforming a tensor of any dimension into a 1-dimensional EigenVector. After completing computation, input0/input1/output will retain the same shape information, and they can be resized using the `Resize` interface.
Because the Eigen Tensor module is under-documented, please refer to `OpKernel`'s computation code in TensorFlow's [kernel module documentation](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/kernels).
......@@ -19,13 +19,14 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope)
proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator)
cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc)
cc_library(op_registry SRCS op_registry.cc DEPS grad_op_builder op_proto_maker op_info)
cc_test(op_registry_test SRCS op_registry_test.cc DEPS op_registry)
cc_test(grad_op_builder_test SRCS grad_op_builder_test.cc DEPS grad_op_builder op_registry add_op)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/block_desc.h"
#include "paddle/framework/program_desc.h"
namespace paddle {
namespace framework {
VarDescBind *BlockDescBind::NewVar(const std::string &name) {
need_update_ = true;
auto it = vars_.find(name);
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
auto var = new VarDescBind(name);
vars_[name].reset(var);
return var;
}
VarDescBind *BlockDescBind::Var(const std::string &name) const {
auto it = vars_.find(name);
PADDLE_ENFORCE(it != vars_.end(),
"Can not find variable %s in current block.", name);
return it->second.get();
}
std::vector<VarDescBind *> BlockDescBind::AllVars() const {
std::vector<VarDescBind *> res;
for (const auto &p : vars_) {
res.push_back(p.second.get());
}
return res;
}
OpDescBind *BlockDescBind::AppendOp() {
need_update_ = true;
ops_.emplace_back(new OpDescBind());
return ops_.back().get();
}
OpDescBind *BlockDescBind::PrependOp() {
need_update_ = true;
ops_.emplace_front(new OpDescBind());
return ops_.front().get();
}
std::vector<OpDescBind *> BlockDescBind::AllOps() const {
std::vector<OpDescBind *> res;
for (const auto &op : ops_) {
res.push_back(op.get());
}
return res;
}
void BlockDescBind::Sync() {
if (need_update_) {
auto &op_field = *this->desc_->mutable_ops();
op_field.Clear();
op_field.Reserve(static_cast<int>(ops_.size()));
for (auto &op_desc : ops_) {
op_field.AddAllocated(op_desc->Proto());
}
need_update_ = false;
}
}
BlockDescBind *BlockDescBind::ParentBlock() const {
if (this->desc_->parent_idx() == -1) {
return nullptr;
}
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
}
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
BlockDesc *desc = block.RawPtr();
this->attrs_[name] = desc;
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <deque>
#include <unordered_map>
#include <vector>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/var_desc.h"
namespace paddle {
namespace framework {
class ProgramDescBind;
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// read/write speed. Only when we want the protobuf message, the local changes
// will be synchronized (by `Sync` method).
class BlockDescBind {
public:
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &o) = delete;
BlockDescBind &operator=(const BlockDescBind &o) = delete;
int32_t ID() const { return desc_->idx(); }
int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *NewVar(const std::string &name_bytes);
VarDescBind *Var(const std::string &name_bytes) const;
std::vector<VarDescBind *> AllVars() const;
BlockDescBind *ParentBlock() const;
OpDescBind *AppendOp();
OpDescBind *PrependOp();
std::vector<OpDescBind *> AllOps() const;
void Sync();
BlockDesc *RawPtr() { return desc_; }
private:
ProgramDescBind *prog_; // not_own
BlockDesc *desc_; // not_own
bool need_update_;
std::deque<std::unique_ptr<OpDescBind>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
};
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <typeindex>
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
inline DataType ToDataType(std::type_index type) {
if (typeid(float).hash_code() == type.hash_code()) {
return DataType::FP32;
} else if (typeid(double).hash_code() == type.hash_code()) {
return DataType::FP64;
} else if (typeid(int).hash_code() == type.hash_code()) {
return DataType::INT32;
} else {
PADDLE_THROW("Not supported");
return static_cast<DataType>(-1);
}
}
} // namespace framework
} // namespace paddle
......@@ -54,5 +54,44 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
return grad_info.Creator()(info.grad_op_type_, inputs, outputs, op->Attrs());
}
static void TransOpDescArg(const OpDescBind* src_op, const OpArgType& src_type,
bool is_grad, OpDescBind* dst_op,
const OpArgType& dst_type) {
PADDLE_ENFORCE(dst_op != nullptr,
"Protobuf desc of gradient op must be initialized first.");
const auto& proto = OpInfoMap::Instance().Get(src_op->Type()).Proto();
const auto& src_arg_list =
src_type == OpArgType::IN ? proto.inputs() : proto.outputs();
for (const auto& arg : src_arg_list) {
if (arg.not_in_gradient() && !is_grad) continue;
const std::string src_name = arg.name();
std::vector<std::string> vars = src_type == OpArgType::IN
? src_op->Input(src_name)
: src_op->Output(src_name);
if (is_grad) {
for (std::string& var : vars) {
var = GradVarName(var);
}
}
std::string dst_name = is_grad ? GradVarName(src_name) : src_name;
dst_type == OpArgType::IN ? dst_op->SetInput(dst_name, vars)
: dst_op->SetOutput(dst_name, vars);
}
}
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op) {
auto& info = OpInfoMap::Instance().Get(forw_op->Type());
PADDLE_ENFORCE(info.HasGradientOp());
grad_op->SetType(info.grad_op_type_);
TransOpDescArg(forw_op, OpArgType::IN, false, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::OUT, false, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::OUT, true, grad_op, OpArgType::IN);
TransOpDescArg(forw_op, OpArgType::IN, true, grad_op, OpArgType::OUT);
grad_op->SetAttrMap(forw_op->GetAttrMap());
}
} // namespace framework
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/framework/op_desc.h"
#include "paddle/framework/operator.h"
namespace paddle {
......@@ -21,5 +22,7 @@ namespace framework {
OperatorBase* BuildGradOp(const OperatorBase* op);
void CompleteGradOpDesc(const OpDescBind* forw_op, OpDescBind* grad_op);
} // namespace framework
} // namespace paddle
......@@ -120,3 +120,82 @@ TEST(GradOpBuilder, IOIgnoredInGradient) {
std::vector<std::string>(
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
}
TEST(GradOpDescBuilder, MutiInOut) {
f::OpDescBind *forw_op = new f::OpDescBind();
forw_op->SetType("mult_io");
forw_op->SetInput("In1", {"in1"});
forw_op->SetInput("In2_mult", {"in2_1", "in2_2", "in2_3"});
forw_op->SetInput("In3", {"in3"});
forw_op->SetOutput("Out1", {"out1"});
forw_op->SetOutput("Out2_mult", {"out2_1", "out2_2"});
f::OpDescBind *grad_op = new f::OpDescBind();
f::CompleteGradOpDesc(forw_op, grad_op);
EXPECT_EQ(grad_op->Type(), "mult_io_grad");
ASSERT_EQ(grad_op->InputNames().size(), 3UL + 2UL + 2UL);
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
EXPECT_EQ(grad_op->Input("In2_mult"),
std::vector<std::string>({"in2_1", "in2_2", "in2_3"}));
EXPECT_EQ(grad_op->Input("In3"), std::vector<std::string>({"in3"}));
EXPECT_EQ(grad_op->Input("Out1"), std::vector<std::string>({"out1"}));
EXPECT_EQ(grad_op->Input("Out2_mult"),
std::vector<std::string>({"out2_1", "out2_2"}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1")),
std::vector<std::string>({f::GradVarName("out1")}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2_mult")),
std::vector<std::string>(
{f::GradVarName("out2_1"), f::GradVarName("out2_2")}));
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
std::vector<std::string>({f::GradVarName("in1")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
std::vector<std::string>({f::GradVarName("in2_1"),
f::GradVarName("in2_2"),
f::GradVarName("in2_3")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In3")),
std::vector<std::string>({f::GradVarName("in3")}));
delete forw_op;
delete grad_op;
}
TEST(GradOpDescBuilder, IOIgnoredInGradient) {
f::OpDescBind *forw_op = new f::OpDescBind();
forw_op->SetType("io_ignored");
forw_op->SetInput("In1", {"in1"});
forw_op->SetInput("In2_mult", {"in2_1", "in2_2"});
forw_op->SetInput("In3_mult", {"in3_1", "in3_2"});
forw_op->SetOutput("Out1_mult", {"out1_1", "out1_2"});
forw_op->SetOutput("Out2", {"out2"});
f::OpDescBind *grad_op = new f::OpDescBind();
f::CompleteGradOpDesc(forw_op, grad_op);
EXPECT_EQ(grad_op->Type(), "io_ignored_grad");
// 'In2' and 'Out2' are ignored in gradient calculating
ASSERT_EQ(grad_op->InputNames().size(), 2UL + 1UL + 2UL);
EXPECT_EQ(grad_op->Input("In1"), std::vector<std::string>({"in1"}));
EXPECT_EQ(grad_op->Input("In3_mult"),
std::vector<std::string>({"in3_1", "in3_2"}));
EXPECT_EQ(grad_op->Input("Out1_mult"),
std::vector<std::string>({"out1_1", "out1_2"}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out1_mult")),
std::vector<std::string>(
{f::GradVarName("out1_1"), f::GradVarName("out1_2")}));
EXPECT_EQ(grad_op->Input(f::GradVarName("Out2")),
std::vector<std::string>({f::GradVarName("out2")}));
ASSERT_EQ(grad_op->OutputNames().size(), 3UL);
EXPECT_EQ(grad_op->Output(f::GradVarName("In1")),
std::vector<std::string>({f::GradVarName("in1")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In2_mult")),
std::vector<std::string>(
{f::GradVarName("in2_1"), f::GradVarName("in2_2")}));
EXPECT_EQ(grad_op->Output(f::GradVarName("In3_mult")),
std::vector<std::string>(
{f::GradVarName("in3_1"), f::GradVarName("in3_2")}));
delete forw_op;
delete grad_op;
}
\ No newline at end of file
......@@ -72,6 +72,22 @@ bool operator==(const LoD& a, const LoD& b) {
return true;
}
size_t LoDTensor::NumElements(size_t level, size_t idx) const {
PADDLE_ENFORCE_LT(level, NumLevels());
PADDLE_ENFORCE_LT(idx, NumElements(level));
// the last level of LoD, just return number of records in Tensor
if (level == NumLevels() - 1) {
return lod_[level][idx + 1] - lod_[level][idx];
}
// high level of LoD, and there is another lower level, return number of
// lower-level elements
auto tmp = SliceInLevel(lod_, level, idx, idx + 1);
PADDLE_ENFORCE_GE(tmp.size(), 2);
// there is a 0 as a placeholder stored in LoD, so the number of elements
// equals lod.size() - 1
return tmp[1].size() - 1;
}
void LoDTensor::ShrinkLevels(size_t level_begin, size_t level_end) {
auto new_lod = framework::SliceLevels(lod_, level_begin, level_end);
lod_ = new_lod;
......
......@@ -38,6 +38,18 @@ using Vector = thrust::host_vector<
T, thrust::system::cuda::experimental::pinned_allocator<T>>;
#endif
/*
* 3-level LoD stores
*
* 0 10 20
* 0 5 10 15 20
* 0 2 5 7 10 12 15 20
*
* - in a level, each element indicates offset in the underlying Tensor
* - the first element should be 0 and that indicates that this sequence start
* from 0
* - each sequence's begin and end(no-inclusive) is level[id, id+1]
*/
using LoD = std::vector<Vector<size_t>>;
LoD SliceLevels(const LoD& in, size_t level_begin, size_t level_end);
......@@ -65,11 +77,8 @@ class LoDTensor : public Tensor {
* Get a element from LoD.
*/
size_t lod_element(size_t level, size_t elem) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE(elem < NumElements(level),
"element begin [%d] out of range [%d]", elem,
NumElements(level));
PADDLE_ENFORCE_LT(level, NumLevels());
PADDLE_ENFORCE_LT(elem, NumElements(level));
return (lod_)[level][elem];
}
......@@ -82,12 +91,23 @@ class LoDTensor : public Tensor {
* Number of elements in a level.
*/
size_t NumElements(size_t level = 0) const {
PADDLE_ENFORCE(level < NumLevels(), "level [%d] out of range [%d]", level,
NumLevels());
PADDLE_ENFORCE_LT(level, NumLevels());
// the last offset is the end of last element
return (lod_)[level].size() - 1;
}
/*
* Number of lower-level elements.
* For example, a 2-level lod-tensor
*
* 0-th level | |
* 1-th level || |||
*
* NumElements(0, 0) get 2
* NumElements(0, 1) get 3
*/
size_t NumElements(size_t level, size_t idx) const;
/*
* Shrink levels[level_begin:level_end]
*/
......
......@@ -56,6 +56,12 @@ TEST_F(LoDTensorTester, NumElements) {
ASSERT_EQ(lod_tensor_.NumElements(2), 8UL);
}
TEST_F(LoDTensorTester, NumElements2) {
ASSERT_EQ(lod_tensor_.NumElements(0, 0), 2UL);
ASSERT_EQ(lod_tensor_.NumElements(0, 1), 2UL);
ASSERT_EQ(lod_tensor_.NumElements(1, 1), 2UL);
}
TEST_F(LoDTensorTester, ShrinkLevels) {
// slice 1 level
for (size_t level = 0; level < 3UL; ++level) {
......@@ -65,7 +71,7 @@ TEST_F(LoDTensorTester, ShrinkLevels) {
ASSERT_EQ(new_lod_tensor.NumElements(0), lod_tensor_.NumElements(level));
ASSERT_EQ(new_lod_tensor.data<float>(), lod_tensor_.data<float>());
}
// slice 2 level
// shrink 2 level
for (size_t level = 0; level < 2UL; ++level) {
LoDTensor new_lod_tensor = lod_tensor_;
new_lod_tensor.ShrinkLevels(level, level + 2);
......
......@@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) {
lod_tensor.mutable_data<float>(place);
lod_tensor.set_lod(src_lod);
CHECK_EQ(lod_tensor.lod_element(0, 2), 4);
CHECK_EQ(lod_tensor.lod_element(0, 4), 8);
CHECK_EQ(lod_tensor.lod_element(0, 2), 4UL);
CHECK_EQ(lod_tensor.lod_element(0, 4), 8UL);
auto lod = lod_tensor.lod();
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_desc.h"
#include "paddle/framework/block_desc.h"
namespace paddle {
namespace framework {
OpDesc *OpDescBind::Proto() {
Sync();
return &op_desc_;
}
const std::vector<std::string> &OpDescBind::Input(
const std::string &name) const {
auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
Type());
return it->second;
}
std::vector<std::string> OpDescBind::InputNames() const {
std::vector<std::string> retv;
retv.reserve(this->inputs_.size());
for (auto &ipt : this->inputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetInput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
inputs_[param_name] = args;
}
const std::vector<std::string> &OpDescBind::Output(
const std::string &name) const {
auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
name, Type());
return it->second;
}
std::vector<std::string> OpDescBind::OutputNames() const {
std::vector<std::string> retv;
retv.reserve(this->outputs_.size());
for (auto &ipt : this->outputs_) {
retv.push_back(ipt.first);
}
return retv;
}
void OpDescBind::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) {
need_update_ = true;
this->outputs_[param_name] = args;
}
AttrType OpDescBind::GetAttrType(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<AttrType>(it->second.which() - 1);
}
std::vector<std::string> OpDescBind::AttrNames() const {
std::vector<std::string> retv;
retv.reserve(attrs_.size());
for (auto &attr : attrs_) {
retv.push_back(attr.first);
}
return retv;
}
void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = v;
need_update_ = true;
}
void OpDescBind::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map;
need_update_ = true;
}
Attribute OpDescBind::GetAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return it->second;
}
int OpDescBind::GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDesc *>(it->second)->idx();
}
const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
const {
return attrs_;
}
void OpDescBind::Sync() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) {
auto *input = op_desc_.add_inputs();
input->set_parameter(ipt.first);
VectorToRepeated(ipt.second, input->mutable_arguments());
}
this->op_desc_.mutable_outputs()->Clear();
for (auto &opt : outputs_) {
auto *output = op_desc_.add_outputs();
output->set_parameter(opt.first);
VectorToRepeated(opt.second, output->mutable_arguments());
}
this->op_desc_.mutable_attrs()->Clear();
for (auto &attr : attrs_) {
auto *attr_desc = op_desc_.add_attrs();
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
}
need_update_ = false;
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <unordered_map>
#include <vector>
#include "paddle/framework/attribute.h"
#include "paddle/framework/var_desc.h"
namespace paddle {
namespace framework {
class BlockDescBind;
class OpDescBind {
public:
OpDesc *Proto();
std::string Type() const { return op_desc_.type(); }
void SetType(const std::string &type) { op_desc_.set_type(type); }
const std::vector<std::string> &Input(const std::string &name) const;
std::vector<std::string> InputNames() const;
void SetInput(const std::string &param_name,
const std::vector<std::string> &args);
const std::vector<std::string> &Output(const std::string &name) const;
std::vector<std::string> OutputNames() const;
void SetOutput(const std::string &param_name,
const std::vector<std::string> &args);
std::string DebugString() { return this->Proto()->DebugString(); }
bool HasAttr(const std::string &name) const {
return attrs_.find(name) != attrs_.end();
}
AttrType GetAttrType(const std::string &name) const;
std::vector<std::string> AttrNames() const;
void SetAttr(const std::string &name, const Attribute &v);
void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);
Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
// Only be used in C++
const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
private:
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void Sync();
OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_;
std::unordered_map<std::string, std::vector<std::string>> outputs_;
std::unordered_map<std::string, Attribute> attrs_;
// need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true.
bool need_update_{false};
};
} // namespace framework
} // namespace paddle
......@@ -100,13 +100,39 @@ class OpRegistrar : public Registrar {
}
};
template <typename PlaceType, typename KernelType>
template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor;
template <typename PlaceType, size_t I, typename... KernelTypes>
struct OpKernelRegistrarFunctor<PlaceType, false, I, KernelTypes...> {
using KERNEL_TYPE =
typename std::tuple_element<I, std::tuple<KernelTypes...>>::type;
void operator()(const char* op_type) const {
using T = typename KERNEL_TYPE::ELEMENT_TYPE;
OperatorWithKernel::OpKernelKey key(ToDataType(std::type_index(typeid(T))),
PlaceType());
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KERNEL_TYPE);
constexpr auto size = std::tuple_size<std::tuple<KernelTypes...>>::value;
OpKernelRegistrarFunctor<PlaceType, I + 1 == size, I + 1, KernelTypes...>
func;
func(op_type);
}
};
template <typename PlaceType, size_t I, typename... KernelType>
struct OpKernelRegistrarFunctor<PlaceType, true, I, KernelType...> {
void operator()(const char* op_type) const {}
};
// User can register many kernel in one place. The data type could be different.
template <typename PlaceType, typename... KernelType>
class OpKernelRegistrar : public Registrar {
public:
explicit OpKernelRegistrar(const char* op_type) {
OperatorWithKernel::OpKernelKey key;
key.place_ = PlaceType();
OperatorWithKernel::AllOpKernels()[op_type][key].reset(new KernelType);
OpKernelRegistrarFunctor<PlaceType, false, 0, KernelType...> func;
func(op_type);
}
};
......
......@@ -10,7 +10,6 @@ class CosineOp : public OperatorBase {
using OperatorBase::OperatorBase;
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
void InferShape(const Scope& scope) const override {}
};
class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
......@@ -29,7 +28,6 @@ class CosineOpProtoAndCheckerMaker : public OpProtoAndCheckerMaker {
class MyTestOp : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
};
......
......@@ -22,14 +22,14 @@ namespace framework {
template <>
Eigen::DefaultDevice& ExecutionContext::GetEigenDevice<
platform::CPUPlace, Eigen::DefaultDevice>() const {
return *device_context_.get_eigen_device<Eigen::DefaultDevice>();
return *device_context_.GetEigenDevice<platform::CPUPlace>();
}
#ifndef PADDLE_ONLY_CPU
template <>
Eigen::GpuDevice&
ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
return *device_context_.get_eigen_device<Eigen::GpuDevice>();
return *device_context_.GetEigenDevice<platform::GPUPlace>();
}
#endif
......
......@@ -15,12 +15,14 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <atomic>
#include <string>
#include <unordered_map>
#include <vector>
#include "op_info.h"
#include "paddle/framework/attribute.h"
#include "paddle/framework/data_type.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/scope.h"
......@@ -82,10 +84,6 @@ class OperatorBase {
virtual std::string DebugString() const;
/// InferShape infer the size of Variables used by this Operator with
/// information inside scope
virtual void InferShape(const Scope& scope) const = 0;
/// Net will call this function to Run an op.
virtual void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const = 0;
......@@ -163,7 +161,6 @@ class OperatorBase {
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {}
std::unique_ptr<OperatorBase> Clone() const override {
......@@ -299,21 +296,6 @@ template <>
std::vector<Tensor*> InferShapeContext::MultiOutput<Tensor>(
const std::string& name) const;
template <typename T>
struct EigenDeviceConverter;
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
using EigenDeviceType = Eigen::DefaultDevice;
};
#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
using EigenDeviceType = Eigen::GpuDevice;
};
#endif
class ExecutionContext : public InferShapeContext {
public:
ExecutionContext(const OperatorBase& op, const Scope& scope,
......@@ -321,8 +303,8 @@ class ExecutionContext : public InferShapeContext {
: InferShapeContext(op, scope), device_context_(device_context) {}
template <typename PlaceType,
typename DeviceType =
typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
typename DeviceType = typename platform::EigenDeviceConverter<
PlaceType>::EigenDeviceType>
DeviceType& GetEigenDevice() const;
platform::Place GetPlace() const { return device_context_.GetPlace(); }
......@@ -407,7 +389,7 @@ class RuntimeInferShapeContext : public InferShapeContextBase {
const Scope& scope_;
};
class OpKernel {
class OpKernelBase {
public:
/**
* ExecutionContext is the only parameter of Kernel Run function.
......@@ -418,48 +400,61 @@ class OpKernel {
virtual void Compute(const ExecutionContext& context) const = 0;
virtual ~OpKernel() {}
virtual ~OpKernelBase() = default;
};
template <typename T>
class OpKernel : public OpKernelBase {
public:
using ELEMENT_TYPE = T;
};
class OperatorWithKernel : public OperatorBase {
public:
struct OpKernelKey {
platform::Place place_;
DataType data_type_;
OpKernelKey() = default;
explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
place_ = dev_ctx.GetPlace();
}
OpKernelKey(DataType data_type, platform::Place place)
: place_(place), data_type_(data_type) {}
OpKernelKey(DataType data_type, const platform::DeviceContext& dev_ctx)
: place_(dev_ctx.GetPlace()), data_type_(data_type) {}
bool operator==(const OpKernelKey& o) const {
return platform::places_are_same_class(place_, o.place_);
return platform::places_are_same_class(place_, o.place_) &&
data_type_ == o.data_type_;
}
};
struct OpKernelHash {
std::hash<bool> hash_;
std::hash<int> hash_;
size_t operator()(const OpKernelKey& key) const {
return hash_(platform::is_gpu_place(key.place_));
int place = key.place_.which();
int data_type = static_cast<int>(key.data_type_);
int pre_hash = data_type << NUM_PLACE_TYPE_LIMIT_IN_BIT |
(place & ((1 << NUM_PLACE_TYPE_LIMIT_IN_BIT) - 1));
return hash_(pre_hash);
}
};
using OpKernelMap =
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
std::unordered_map<OpKernelKey, std::unique_ptr<OpKernelBase>,
OpKernelHash>;
OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
// runtime infershape
void InferShape(const Scope& scope) const override {
auto c = RuntimeInferShapeContext(*this, scope);
InferShape(&c);
}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const final {
auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
opKernel->Compute(ExecutionContext(*this, scope, dev_ctx));
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx);
ExecutionContext ctx(*this, scope, dev_ctx);
auto& opKernel = AllOpKernels().at(type_).at(
OpKernelKey(IndicateDataType(ctx), dev_ctx));
opKernel->Compute(ctx);
}
static std::unordered_map<std::string /* op_type */, OpKernelMap>&
......@@ -469,13 +464,43 @@ class OperatorWithKernel : public OperatorBase {
}
bool SupportGPU() const override {
OperatorWithKernel::OpKernelKey key;
key.place_ = platform::GPUPlace();
return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
return std::any_of(op_kernels.begin(), op_kernels.end(),
[](OpKernelMap::const_reference kern_pair) {
return platform::is_gpu_place(kern_pair.first.place_);
});
}
protected:
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
// indicate kernel DataType by input data. Defaultly all input data must be
// same.
virtual DataType IndicateDataType(const ExecutionContext& ctx) const {
auto& scope = ctx.scope();
int data_type = -1;
for (auto& input : this->inputs_) {
for (auto& ipt_name : input.second) {
auto* var = scope.FindVar(ipt_name);
if (var != nullptr) {
const Tensor* t = nullptr;
if (var->IsType<Tensor>()) {
t = &var->Get<Tensor>();
} else if (var->IsType<LoDTensor>()) {
t = &var->Get<LoDTensor>();
}
if (t != nullptr) {
int tmp = static_cast<int>(ToDataType(t->type()));
PADDLE_ENFORCE(tmp == data_type || data_type == -1,
"DataType of Paddle Op must be same.");
data_type = tmp;
}
}
}
}
PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
return static_cast<DataType>(data_type);
}
};
} // namespace framework
......
......@@ -27,7 +27,6 @@ class OpWithoutKernelTest : public OperatorBase {
OpWithoutKernelTest(const std::string& type, const VariableNameMap& inputs,
const VariableNameMap& outputs, const AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs), x(1) {}
void InferShape(const Scope& scope) const override {}
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
++op_run_num;
......@@ -87,7 +86,6 @@ TEST(OperatorBase, all) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
scope.NewVar("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->InferShape(scope);
op->Run(scope, device_context);
ASSERT_EQ(paddle::framework::op_run_num, 1);
}
......@@ -116,10 +114,13 @@ class OpWithKernelTest : public OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}
};
template <typename T1, typename T2>
class CPUKernelTest : public OpKernel {
class CPUKernelTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
std::cout << "this is cpu kernel" << std::endl;
......@@ -146,7 +147,7 @@ class OpKernelTestMultiInputsProtoAndCheckerMaker
}
};
class CPUKernalMultiInputsTest : public OpKernel {
class CPUKernalMultiInputsTest : public OpKernel<float> {
public:
void Compute(const ExecutionContext& ctx) const {
auto xs = ctx.op().Inputs("xs");
......@@ -255,7 +256,6 @@ class OperatorClone : public paddle::framework::OperatorBase {
const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void InferShape(const paddle::framework::Scope& scope) const override {}
void Run(const paddle::framework::Scope& scope,
const paddle::platform::DeviceContext& dev_ctx) const override {}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "paddle/framework/block_desc.h"
namespace paddle {
namespace framework {
using ProgDescMap =
std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
static ProgDescMap *g_bind_map = nullptr;
ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) {
if (g_bind_map == nullptr) {
g_bind_map = new ProgDescMap();
}
auto &map = *g_bind_map;
auto &ptr = map[prog];
if (ptr == nullptr) {
ptr.reset(new ProgramDescBind(prog));
}
return *ptr;
}
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks();
b->set_parent_idx(parent.ID());
b->set_idx(prog_->blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Sync();
}
return prog_;
}
ProgramDescBind::ProgramDescBind(ProgramDesc *prog) {
prog_ = prog;
for (auto &block : *prog->mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block));
}
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
class BlockDescBind;
class ProgramDescBind {
public:
static ProgramDescBind &Instance(ProgramDesc *prog);
ProgramDescBind(const ProgramDescBind &o) = delete;
ProgramDescBind &operator=(const ProgramDescBind &o) = delete;
BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
std::string DebugString() { return Proto()->DebugString(); }
size_t Size() const { return blocks_.size(); }
ProgramDesc *Proto();
private:
explicit ProgramDescBind(ProgramDesc *prog);
// Not owned
ProgramDesc *prog_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
};
} // namespace framework
} // namespace paddle
......@@ -29,20 +29,10 @@ limitations under the License. */
namespace paddle {
namespace pybind {
namespace details {
template <bool less, size_t i, typename... args>
struct CastToPyBufferImpl;
}
} // namespace pybind
namespace framework {
class Tensor {
public:
template <bool less, size_t i, typename... args>
friend struct pybind::details::CastToPyBufferImpl;
template <typename T, size_t D, int MajorType, typename IndexType>
friend struct EigenTensor;
......@@ -119,6 +109,8 @@ class Tensor {
return holder_->place();
}
std::type_index type() const { return holder_->type(); }
private:
template <typename T>
inline void check_memory_size() const;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/var_desc.h"
namespace paddle {
namespace framework {
void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
}
void VarDescBind::SetDataType(DataType data_type) {
desc_.mutable_lod_tensor()->set_data_type(data_type);
}
std::vector<int64_t> VarDescBind::Shape() const {
return RepeatedToVector(desc_.lod_tensor().dims());
}
DataType VarDescBind::GetDataType() const {
return desc_.lod_tensor().data_type();
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/framework/framework.pb.h"
namespace paddle {
namespace framework {
// convert between std::vector and protobuf repeated.
template <typename T>
inline std::vector<T> RepeatedToVector(
const google::protobuf::RepeatedField<T> &repeated_field) {
std::vector<T> ret;
ret.reserve(repeated_field.size());
std::copy(repeated_field.begin(), repeated_field.end(),
std::back_inserter(ret));
return ret;
}
template <typename T, typename RepeatedField>
inline void VectorToRepeated(const std::vector<T> &vec,
RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size());
for (const auto &elem : vec) {
*repeated_field->Add() = elem;
}
}
// Specialize vector<bool>.
template <typename RepeatedField>
inline void VectorToRepeated(const std::vector<bool> &vec,
RepeatedField *repeated_field) {
repeated_field->Reserve(vec.size());
for (auto elem : vec) {
*repeated_field->Add() = elem;
}
}
class VarDescBind {
public:
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
VarDesc *Proto() { return &desc_; }
std::string Name() const { return desc_.name(); }
void SetShape(const std::vector<int64_t> &dims);
void SetDataType(DataType data_type);
std::vector<int64_t> Shape() const;
DataType GetDataType() const;
private:
VarDesc desc_;
};
} // namespace framework
} // namespace paddle
......@@ -18,7 +18,6 @@ limitations under the License. */
#include "neon_util.h"
namespace paddle {
namespace neon {
#if defined(__ARM_NEON__) || defined(__ARM_NEON)
......@@ -26,17 +25,20 @@ namespace neon {
template <int filterSize, int stride>
struct DepthwiseConvKernel {};
inline float32_t conv3x3(float32x4_t r0,
float32x4_t r1,
float32x4_t r2,
inline float32_t conv3x3(const float* r0,
const float* r1,
const float* r2,
float32x4_t k0,
float32x4_t k1,
float32x4_t k2) {
float32x4_t tmp;
tmp = vmulq_f32(r0, k0);
tmp = vmlaq_f32(tmp, r1, k1);
tmp = vmlaq_f32(tmp, r2, k2);
return vaddvq_f32(tmp);
float32_t tmp[12];
vst1q_f32(&(tmp[0]), k0);
vst1q_f32(&(tmp[4]), k1);
vst1q_f32(&(tmp[8]), k2);
float32_t sum0 = r0[0] * tmp[0] + r0[1] * tmp[1] + r0[2] * tmp[2];
float32_t sum1 = r1[0] * tmp[4] + r1[1] * tmp[5] + r1[2] * tmp[6];
float32_t sum2 = r2[0] * tmp[8] + r2[1] * tmp[9] + r2[2] * tmp[10];
return sum0 + sum1 + sum2;
}
inline float32_t conv4x4(float32x4_t r0,
......@@ -136,10 +138,7 @@ struct DepthwiseConvKernel<3, 1> {
}
for (int r = 0; r < remain; r++) {
float32x4_t i0 = vld1q_f32(r0);
float32x4_t i1 = vld1q_f32(r1);
float32x4_t i2 = vld1q_f32(r2);
*outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]);
*outputData = conv3x3(r0, r1, r2, k[0], k[1], k[2]);
r0++;
r1++;
r2++;
......@@ -243,10 +242,7 @@ struct DepthwiseConvKernel<3, 2> {
}
for (int r = 0; r < remain; r++) {
float32x4_t i0 = vld1q_f32(r0);
float32x4_t i1 = vld1q_f32(r1);
float32x4_t i2 = vld1q_f32(r2);
*outputData = conv3x3(i0, i1, i2, k[0], k[1], k[2]);
*outputData = conv3x3(r0, r1, r2, k[0], k[1], k[2]);
r0 += 2;
r1 += 2;
r2 += 2;
......
......@@ -28,7 +28,7 @@ bool MKLDNNConvLayer::init(const LayerMap& layerMap,
if (!MKLDNNLayer::init(layerMap, parameterMap)) {
return false;
}
CHECK_EQ(inputLayers_.size(), 1) << "Only support one input layer yet";
CHECK_EQ(inputLayers_.size(), 1UL) << "Only support one input layer yet";
CHECK_EQ(inputLayers_.size(), parameters_.size());
CHECK(config_.shared_biases()) << "Only support shared biases yet";
......
......@@ -28,7 +28,7 @@ bool MKLDNNFcLayer::init(const LayerMap& layerMap,
return false;
}
CHECK_EQ(inputLayers_.size(), 1) << "Only support one input layer yet";
CHECK_EQ(inputLayers_.size(), 1UL) << "Only support one input layer yet";
CHECK_EQ(inputLayers_.size(), parameters_.size());
CHECK(!parameters_[0]->isSparse()) << "Do not support sparse yet";
......
......@@ -228,7 +228,7 @@ void genGroundTruth(vector<SingleBeamExpansion>& beamExpansions,
curBeam.groundTruth[j] = *(start + n);
curBeam.inBeam[j] = 1;
} else {
CHECK_LE(curBeam.rowIdxInBeam[j] + 1,
CHECK_LE((size_t)curBeam.rowIdxInBeam[j] + 1,
curBeam.subSeqStartPos.size() - 1);
int start = curBeam.subSeqStartPos[curBeam.rowIdxInBeam[j]];
int end = curBeam.subSeqStartPos[curBeam.rowIdxInBeam[j] + 1];
......
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
../framework/.clang-format
\ No newline at end of file
......@@ -80,6 +80,15 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
platform::GpuMemcpySync(dst, src, num, cudaMemcpyHostToDevice);
}
template <>
void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpySync(dst, src, num, cudaMemcpyDeviceToDevice);
}
#endif // PADDLE_ONLY_CPU
} // namespace memory
......
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
../framework/.clang-format
\ No newline at end of file
......@@ -67,6 +67,13 @@ function(op_library TARGET)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# reduce_op contains several operators
if ("${TARGET}" STREQUAL "reduce_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(reduce_sum);\n")
endif()
# pybind USE_NO_KERNEL_OP
file(READ ${TARGET}.cc TARGET_CONTENT)
......@@ -100,8 +107,8 @@ set(DEPS_OPS
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS framework_proto tensor net_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy_function)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy_function softmax_function)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
......@@ -47,7 +47,7 @@ __global__ void AccuracyCudaKernel(const int N, const int D, const int* Xdata,
}
template <typename T>
class AccuracyOpCUDAKernel : public framework::OpKernel {
class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
......@@ -35,7 +35,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenScalar = framework::EigenScalar<T, MajorType, IndexType>;
template <typename Place, typename T>
class AccuracyKernel : public framework::OpKernel {
class AccuracyKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* inference = ctx.Input<Tensor>("Inference");
......
......@@ -132,6 +132,17 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftsignOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softsign operator");
AddOutput("Y", "Output of Softsign operator");
AddComment("Softsign activation operator, softsign(x) = x / (1 + |x|)");
}
};
template <typename AttrType>
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
......@@ -277,6 +288,15 @@ REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(softsign,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SoftsignFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
softsign_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SoftsignGradFunctor<float>>);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(brelu,
......
......@@ -80,6 +80,13 @@ REGISTER_OP_GPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(softsign,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SoftsignFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
softsign_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SoftsignGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel {
class ActivationKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -36,7 +36,7 @@ class ActivationKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel {
class ActivationGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -201,8 +201,28 @@ struct SquareGradFunctor {
}
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
// d(softsign(x))/dx = 1 / (1 + |x|)^2
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) =
dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel {
class BReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -219,7 +239,7 @@ class BReluKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel {
class BReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -239,7 +259,7 @@ class BReluGradKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel {
class SoftReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -256,7 +276,7 @@ class SoftReluKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel {
class SoftReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -277,7 +297,7 @@ class SoftReluGradKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel {
class PowKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -293,7 +313,7 @@ class PowKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel {
class PowGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -312,7 +332,7 @@ class PowGradKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel {
class STanhKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......@@ -329,7 +349,7 @@ class STanhKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel {
class STanhGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
......
......@@ -25,7 +25,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class AddKernel : public framework::OpKernel {
class AddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input0 = context.Input<Tensor>("X");
......
......@@ -56,7 +56,7 @@ class ClipGradFunctor {
};
template <typename Place, typename T>
class ClipKernel : public framework::OpKernel {
class ClipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
......@@ -73,7 +73,7 @@ class ClipKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class ClipGradKernel : public framework::OpKernel {
class ClipGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max = context.Attr<T>("max");
......
......@@ -25,12 +25,14 @@ class ConcatOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.")
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ConcatOp should not be null.");
auto ins = ctx->GetInputsDim("X");
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t n = ins.size();
const size_t n = ins.size();
PADDLE_ENFORCE_GT(n, 1, "Input tensors count should > 1.");
......@@ -72,10 +74,27 @@ class ConcatOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class ConcatOpGrad : public framework::OperatorWithKernel {
public:
ConcatOpGrad(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(concat, ops::ConcatOp, ops::ConcatOpMaker)
REGISTER_OP(concat, ops::ConcatOp, ops::ConcatOpMaker, concat_grad,
ops::ConcatOpGrad)
REGISTER_OP_CPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::CPUPlace, float>)
REGISTER_OP_CPU_KERNEL(concat_grad,
ops::ConcatGradKernel<paddle::platform::CPUPlace, float>)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/concat_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(concat,
ops::ConcatKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
concat_grad, ops::ConcatGradKernel<paddle::platform::GPUPlace, float>);
......@@ -16,46 +16,51 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class ConcatKernel : public framework::OpKernel {
class ConcatKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
size_t n = ins.size();
size_t output_axis_dim = 0;
size_t before = 1, after = 1;
for (size_t i = 0; i < n; i++) {
output_axis_dim += ins[i]->dims()[axis];
}
auto& input_zero = ins[0];
for (int64_t i = 0; i < input_zero->dims().size(); i++) {
if (i == axis) {
continue;
}
if (i < axis) {
before *= input_zero->dims()[i];
} else {
after *= input_zero->dims()[i];
}
}
const size_t n = ins.size();
size_t output_offset = 0;
out->mutable_data<T>(ctx.GetPlace());
auto out_stride = framework::stride(out->dims());
for (size_t i = 0; i < n; i++) {
auto& in = ins[i];
auto axis_dim = in->dims()[axis];
for (size_t j = 0; j < before; j++) {
size_t len = axis_dim * after * sizeof(T);
const T* src = in->data<T>() + axis_dim * after * j;
T* out_data = out->mutable_data<T>(platform::CPUPlace());
T* dest = out_data + output_offset + output_axis_dim * after * j;
memcpy(dest, src, len);
}
output_offset += axis_dim * after;
auto in_stride = framework::stride(in->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>(), in_stride,
in->dims(), out_stride, out->data<T>() + output_offset);
output_offset += axis_dim * in_stride[axis];
}
}
};
template <typename Place, typename T>
class ConcatGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* in = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto outs = ctx.MultiOutput<framework::Tensor>(framework::GradVarName("X"));
int64_t axis = static_cast<int64_t>(ctx.Attr<int>("axis"));
const size_t n = outs.size();
size_t input_offset = 0;
auto in_stride = framework::stride(in->dims());
for (size_t i = 0; i < n; i++) {
auto& out = outs[i];
out->mutable_data<T>(ctx.GetPlace());
size_t axis_dim = out->dims()[axis];
auto out_stride = framework::stride(out->dims());
StridedMemcpy<T>(ctx.device_context(), in->data<T>() + input_offset,
in_stride, out->dims(), out_stride, out->data<T>());
input_offset += axis_dim * in_stride[axis];
}
}
};
......
......@@ -82,7 +82,7 @@ void CondOp::InferShape(const Scope& scope) const {
}
// each net calls InferShape
sub_net_op_[i]->InferShape(*sub_scopes[i]);
// sub_net_op_[i]->InferShape(*sub_scopes[i]);
}
for (auto& output : Outputs("Outs")) {
......
......@@ -57,8 +57,10 @@ class CondOp : public framework::OperatorBase {
/*
* InferShape must be called before Run.
* FIXME(yuyang18): Since InferShape has been removed, this implementation
* could be wrong.
*/
void InferShape(const framework::Scope& scope) const override;
void InferShape(const framework::Scope& scope) const;
/*
* Set True Block
......
......@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class CosSimKernel : public framework::OpKernel {
class CosSimKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// get Tensor
......@@ -67,7 +67,7 @@ class CosSimKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class CosSimGradKernel : public framework::OpKernel {
class CosSimGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
// get Tensor
......
......@@ -27,7 +27,7 @@ using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor;
template <typename T>
class CropKernel : public framework::OpKernel {
class CropKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
......@@ -69,7 +69,7 @@ void CropGradFunction(const framework::ExecutionContext& context) {
}
template <typename Place, typename T>
class CropGradKernel : public framework::OpKernel {
class CropGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t rank =
......
......@@ -47,6 +47,12 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Y", {x_dims[0], 1});
ctx->ShareLoD("X", /*->*/ "Y");
}
// CrossEntropy's data type just determined by "X"
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
};
class CrossEntropyGradientOp : public framework::OperatorWithKernel {
......@@ -87,6 +93,12 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
// CrossEntropy's data type just determined by "X"
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
};
class CrossEntropyOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -18,14 +18,6 @@ namespace paddle {
namespace operators {
namespace {
// TODO(qingqing): make zero setting a common function.
template <typename T>
__global__ void Zero(T* X, const int N) {
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
X[i] = 0.0;
}
}
template <typename T>
__global__ void CrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
......@@ -53,7 +45,7 @@ __global__ void SoftCrossEntropyGradientKernel(T* dX, const T* dY, const T* X,
} // namespace
template <typename T>
class CrossEntropyOpCUDAKernel : public framework::OpKernel {
class CrossEntropyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -64,12 +56,12 @@ class CrossEntropyOpCUDAKernel : public framework::OpKernel {
y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::GPUPlace, T>()(
ctx, y, x, label, ctx.Attr<bool>("softLabel"));
ctx.device_context(), y, x, label, ctx.Attr<bool>("softLabel"));
}
};
template <typename T>
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -99,11 +91,7 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel {
.stream()>>>(dx_data, dy_data, x_data, label_data,
batch_size, class_num);
} else {
Zero<T><<<grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(dx_data, batch_size * class_num);
math::SetConstant<platform::GPUPlace, T>(ctx.device_context(), dx, 0);
auto* label_data = label->data<int>();
grid = (batch_size + block - 1) / block;
CrossEntropyGradientKernel<T><<<
......
......@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cross_entropy.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
......@@ -26,7 +27,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class CrossEntropyOpKernel : public framework::OpKernel {
class CrossEntropyOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......@@ -37,12 +38,12 @@ class CrossEntropyOpKernel : public framework::OpKernel {
y->mutable_data<T>(ctx.GetPlace());
math::CrossEntropyFunctor<platform::CPUPlace, T>()(
ctx, y, x, labels, ctx.Attr<bool>("softLabel"));
ctx.device_context(), y, x, labels, ctx.Attr<bool>("softLabel"));
}
};
template <typename T>
class CrossEntropyGradientOpKernel : public framework::OpKernel {
class CrossEntropyGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......@@ -69,8 +70,7 @@ class CrossEntropyGradientOpKernel : public framework::OpKernel {
const T* x_data = x->data<T>();
const int* label_data = label->data<int>();
// TODO(qingqing): make zero setting a common function.
memset(dx_data, 0, sizeof(T) * batch_size * class_num);
math::SetConstant<platform::CPUPlace, T>(ctx.device_context(), dx, 0);
for (int i = 0; i < batch_size; ++i) {
PADDLE_ASSERT(label_data[i] >= 0 || label_data[i] < class_num);
......
......@@ -47,7 +47,7 @@ struct MaskGenerator {
// Use std::random and thrust::random(thrust is a std library in CUDA) to
// implement uniform random.
template <typename Place, typename T, typename AttrType>
class GPUDropoutKernel : public framework::OpKernel {
class GPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
......
......@@ -26,7 +26,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType>
class CPUDropoutKernel : public framework::OpKernel {
class CPUDropoutKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
......@@ -62,7 +62,7 @@ class CPUDropoutKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel {
class DropoutGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE(context.Attr<bool>("is_training"),
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseAddKernel : public framework::OpKernel {
class ElementwiseAddKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenAddFunctor, Place, T>(ctx);
......@@ -101,7 +101,7 @@ struct ElementwiseAddBroadCast2GradFunctor {
};
template <typename Place, typename T>
class ElementwiseAddGradKernel : public framework::OpKernel {
class ElementwiseAddGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseAddGradFunctor<T>,
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseDivKernel : public framework::OpKernel {
class ElementwiseDivKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenDivFunctor, Place, T>(ctx);
......@@ -103,7 +103,7 @@ struct ElementwiseDivBroadCast2GradFunctor {
};
template <typename Place, typename T>
class ElementwiseDivGradKernel : public framework::OpKernel {
class ElementwiseDivGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseDivGradFunctor<T>,
......
......@@ -36,7 +36,9 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
elementwise_mul_grad, ops::ElementwiseOpGrad);
REGISTER_OP_CPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>);
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>);
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
......@@ -19,7 +19,9 @@ namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(
elementwise_mul,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>);
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>);
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, float>,
ops::ElementwiseMulGradKernel<paddle::platform::GPUPlace, double>);
......@@ -19,7 +19,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseMulKernel : public framework::OpKernel {
class ElementwiseMulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenMulFunctor, Place, T>(ctx);
......@@ -102,7 +102,7 @@ struct ElementwiseMulBroadCast2GradFunctor {
};
template <typename Place, typename T>
class ElementwiseMulGradKernel : public framework::OpKernel {
class ElementwiseMulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseMulGradFunctor<T>,
......
......@@ -19,7 +19,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class ElementwiseSubKernel : public framework::OpKernel {
class ElementwiseSubKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseCompute<EigenSubFunctor, Place, T>(ctx);
......@@ -102,7 +102,7 @@ struct ElementwiseSubBroadCast2GradFunctor {
};
template <typename Place, typename T>
class ElementwiseSubGradKernel : public framework::OpKernel {
class ElementwiseSubGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
ElementwiseGradCompute<Place, T, ElementwiseSubGradFunctor<T>,
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class FillZerosLikeKernel : public framework::OpKernel {
class FillZerosLikeKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* output = context.Output<framework::Tensor>("Y");
......
......@@ -37,6 +37,11 @@ class GatherOp : public framework::OperatorWithKernel {
output_dims[0] = batch_size;
ctx->SetOutputDim("Out", output_dims);
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
};
class GatherGradOp : public framework::OperatorWithKernel {
......@@ -47,6 +52,11 @@ class GatherGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContextBase* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("X")->type());
}
};
class GatherOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -24,7 +24,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class GatherOpKernel : public framework::OpKernel {
class GatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *X = ctx.Input<Tensor>("X");
......@@ -37,7 +37,7 @@ class GatherOpKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class GatherGradientOpKernel : public framework::OpKernel {
class GatherGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *Index = ctx.Input<Tensor>("Index");
......
......@@ -16,7 +16,7 @@ namespace paddle {
namespace operators {
template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel {
class CPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.Attr<float>("mean");
......@@ -56,6 +56,11 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
"dims can be one int or array. dims must be set.");
ctx->SetOutputDim("Out", framework::make_ddim(temp));
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("data_type"));
}
};
class GaussianRandomOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -76,6 +81,8 @@ Use to initialize tensor with gaussian random generator.
"Random seed of generator."
"0 means use system wide seed")
.SetDefault(0);
AddAttr<int>("data_type", "output data type")
.SetDefault(framework::DataType::FP32);
}
};
......
......@@ -37,7 +37,7 @@ struct GaussianGenerator {
};
template <typename T>
class GPUGaussianRandomKernel : public framework::OpKernel {
class GPUGaussianRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
......
......@@ -25,7 +25,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class GemmConv2DKernel : public framework::OpKernel {
class GemmConv2DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
......@@ -98,7 +98,7 @@ class GemmConv2DKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class GemmConvGrad2DKernel : public framework::OpKernel {
class GemmConvGrad2DKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* input = context.Input<Tensor>("Input");
......
......@@ -36,6 +36,11 @@ class LookupTableOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", {ids_dims[0], table_dims[1]});
ctx->ShareLoD("Ids", /*->*/ "Out");
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
}
};
class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -69,6 +74,11 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.Input<Tensor>("W")->type());
}
};
} // namespace operators
......
......@@ -61,7 +61,7 @@ __global__ void LookupTableGrad(T* table, const T* output, const int32_t* ids,
}
template <typename T>
class LookupTableCUDAKernel : public framework::OpKernel {
class LookupTableCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W");
......@@ -85,7 +85,7 @@ class LookupTableCUDAKernel : public framework::OpKernel {
};
template <typename T>
class LookupTableGradCUDAKernel : public framework::OpKernel {
class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
......
......@@ -23,7 +23,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class LookupTableKernel : public framework::OpKernel {
class LookupTableKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto table_t = context.Input<Tensor>("W"); // float tensor
......@@ -44,7 +44,7 @@ class LookupTableKernel : public framework::OpKernel {
};
template <typename T>
class LookupTableGradKernel : public framework::OpKernel {
class LookupTableGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto ids_t = context.Input<Tensor>("Ids");
......
......@@ -90,7 +90,7 @@ __global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
}
template <typename T, typename AttrType = T>
class LstmUnitOpCUDAKernel : public framework::OpKernel {
class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......@@ -121,7 +121,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel {
};
template <typename T, typename AttrType = T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
......
......@@ -33,7 +33,7 @@ inline T tanh(T x) {
}
template <typename Place, typename T, typename AttrType = T>
class LstmUnitKernel : public framework::OpKernel {
class LstmUnitKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......@@ -76,7 +76,7 @@ class LstmUnitKernel : public framework::OpKernel {
};
template <typename Place, typename T, typename AttrType = T>
class LstmUnitGradKernel : public framework::OpKernel {
class LstmUnitGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
......
if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc
im2col.cu pooling.cc pooling.cu DEPS cblas device_context)
nv_library(softmax_function SRCS softmax.cc softmax.cu
DEPS operator)
nv_library(cross_entropy_function SRCS cross_entropy.cc cross_entropy.cu
DEPS operator)
nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu pooling.cc pooling.cu DEPS cblas device_context)
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator)
else()
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc DEPS cblas device_context)
cc_library(softmax_function SRCS softmax.cc DEPS operator)
cc_library(cross_entropy_function SRCS cross_entropy.cc DEPS operator)
cc_library(math_function SRCS math_function.cc im2col.cc pooling.cc
DEPS cblas device_context operator)
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_library(softmax SRCS softmax.cc DEPS operator)
cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator)
endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
......@@ -26,8 +26,8 @@ using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class CrossEntropyFunctor<platform::CPUPlace, T> {
public:
void operator()(const framework::ExecutionContext& ctx,
framework::Tensor* out, const framework::Tensor* prob,
void operator()(const platform::DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel) {
const int batch_size = prob->dims()[0];
if (softLabel) {
......@@ -35,7 +35,7 @@ class CrossEntropyFunctor<platform::CPUPlace, T> {
auto lbl = EigenMatrix<T>::From(*labels);
auto loss = EigenMatrix<T>::From(*out);
loss.device(ctx.GetEigenDevice<platform::CPUPlace>()) =
loss.device(*ctx.GetEigenDevice<platform::CPUPlace>()) =
-((lbl * in.log().unaryExpr(math::TolerableValue<T>()))
.sum(Eigen::DSizes<int, 1>(1))
.reshape(Eigen::DSizes<int, 2>(batch_size, 1)));
......
......@@ -74,8 +74,8 @@ using Tensor = framework::Tensor;
template <typename T>
class CrossEntropyFunctor<platform::GPUPlace, T> {
public:
void operator()(const framework::ExecutionContext& ctx,
framework::Tensor* out, const framework::Tensor* prob,
void operator()(const platform::DeviceContext& ctx, framework::Tensor* out,
const framework::Tensor* prob,
const framework::Tensor* labels, bool softLabel) {
const T* prob_data = prob->data<T>();
T* loss_data = out->mutable_data<T>(ctx.GetPlace());
......@@ -87,20 +87,18 @@ class CrossEntropyFunctor<platform::GPUPlace, T> {
const T* label_data = labels->data<T>();
int block = class_num > 512 ? 512 : pow(2, int(std::log2(class_num)));
SoftCrossEntropyKernel<
T><<<batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(loss_data, prob_data, label_data, class_num);
SoftCrossEntropyKernel<T><<<
batch_size, block, block * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
loss_data, prob_data, label_data, class_num);
} else {
const int* label_data = labels->data<int>();
int block = 512;
int grid = (batch_size + block - 1) / block;
CrossEntropyKernel<T><<<
grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ctx.device_context())
.stream()>>>(loss_data, prob_data, label_data,
batch_size, class_num);
grid, block, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream()>>>(
loss_data, prob_data, label_data, batch_size, class_num);
}
}
};
......
......@@ -37,9 +37,7 @@ struct TolerableValue {
template <typename Place, typename T>
class CrossEntropyFunctor {
public:
// (TODO caoying) it is much better to use DeviceContext as the first
// parameter.
void operator()(const framework::ExecutionContext& context,
void operator()(const platform::DeviceContext& context,
framework::Tensor* out, const framework::Tensor* prob,
const framework::Tensor* labels, const bool softLabel);
};
......
......@@ -52,6 +52,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
#include <cmath>
#include "paddle/framework/eigen.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
......@@ -84,6 +85,13 @@ void matmul(const platform::DeviceContext& context,
const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, T beta);
template <typename Place, typename T>
void SetConstant(const platform::DeviceContext& context,
framework::Tensor* tensor, T num) {
auto t = framework::EigenVector<T>::Flatten(*tensor);
t.device(*context.GetEigenDevice<Place>()) = t.constant(static_cast<T>(num));
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -243,3 +243,24 @@ TEST(math_function, gemm_trans_clbas) {
EXPECT_EQ(input3_ptr[6], 86);
EXPECT_EQ(input3_ptr[7], 99);
}
TEST(math_function, zero) {
paddle::framework::Tensor tensor;
auto* cpu_place = new paddle::platform::CPUPlace();
float* t = tensor.mutable_data<float>({2, 2}, *cpu_place);
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>(
context, &tensor, 0);
EXPECT_EQ(t[0], 0);
EXPECT_EQ(t[1], 0);
EXPECT_EQ(t[2], 0);
EXPECT_EQ(t[3], 0);
paddle::operators::math::SetConstant<paddle::platform::CPUPlace, float>(
context, &tensor, 1);
EXPECT_EQ(t[0], 1);
EXPECT_EQ(t[1], 1);
EXPECT_EQ(t[2], 1);
EXPECT_EQ(t[3], 1);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/softmax.h"
......@@ -19,6 +19,7 @@ namespace operators {
namespace math {
template class SoftmaxFunctor<platform::CPUPlace, float>;
template class SoftmaxGradFunctor<platform::CPUPlace, float>;
} // namespace math
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
......@@ -21,6 +21,7 @@ namespace operators {
namespace math {
template class SoftmaxFunctor<platform::GPUPlace, float>;
template class SoftmaxGradFunctor<platform::GPUPlace, float>;
} // namespace math
} // namespace operators
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
......@@ -36,7 +36,7 @@ struct ValueClip {
template <typename Place, typename T>
class SoftmaxFunctor {
public:
void operator()(const framework::ExecutionContext& context,
void operator()(const platform::DeviceContext& context,
const framework::Tensor* X, framework::Tensor* Y) {
auto logits = EigenMatrix<T>::From(*X);
auto softmax = EigenMatrix<T>::From(*Y);
......@@ -58,8 +58,8 @@ class SoftmaxFunctor {
.broadcast(one_by_class))
.unaryExpr(ValueClip<T>());
softmax.device(context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(context.GetEigenDevice<Place>()) =
softmax.device(*context.GetEigenDevice<Place>()) = shifted_logits.exp();
softmax.device(*context.GetEigenDevice<Place>()) =
(softmax *
softmax.sum(along_class)
.inverse()
......@@ -68,6 +68,37 @@ class SoftmaxFunctor {
.broadcast(one_by_class));
}
};
template <typename Place, typename T>
class SoftmaxGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor* y, const framework::Tensor* y_grad,
framework::Tensor* x_grad) {
auto softmax = EigenMatrix<T>::From(*y);
auto softmax_grad = EigenMatrix<T>::From(*y_grad);
auto logits_grad = EigenMatrix<T>::From(*x_grad);
const int kBatchDim = 0;
const int kClassDim = 1;
const int batch_size = softmax.dimension(kBatchDim);
const int num_classes = softmax.dimension(kClassDim);
Eigen::DSizes<int, 1> along_class(kClassDim);
Eigen::DSizes<int, 2> batch_by_one(batch_size, 1);
Eigen::DSizes<int, 2> one_by_class(1, num_classes);
auto dot = (softmax * softmax_grad)
.sum(along_class)
.eval()
.reshape(batch_by_one)
.broadcast(one_by_class);
logits_grad.device(*context.GetEigenDevice<Place>()) =
(softmax_grad - dot) * softmax;
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class MeanKernel : public framework::OpKernel {
class MeanKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<Tensor>("X");
......@@ -45,7 +45,7 @@ class MeanKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class MeanGradKernel : public framework::OpKernel {
class MeanGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto OG = context.Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -20,7 +20,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class MinusKernel : public framework::OpKernel {
class MinusKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* left_tensor = context.Input<framework::Tensor>("X");
......
......@@ -39,7 +39,7 @@ struct ModifiedHuberLossBackward {
};
template <typename T>
class ModifiedHuberLossGradGPUKernel : public framework::OpKernel {
class ModifiedHuberLossGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y");
......
......@@ -47,7 +47,7 @@ struct ModifiedHuberLossForward {
};
template <typename Place, typename T>
class ModifiedHuberLossKernel : public framework::OpKernel {
class ModifiedHuberLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("X");
......@@ -73,7 +73,7 @@ class ModifiedHuberLossKernel : public framework::OpKernel {
// CPU backward kernel
template <typename T>
class ModifiedHuberLossGradCPUKernel : public framework::OpKernel {
class ModifiedHuberLossGradCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in0 = context.Input<Tensor>("Y");
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/mul_op.h"
......@@ -35,12 +35,14 @@ class MulOp : public framework::OperatorWithKernel {
int x_num_col_dims = ctx->Attrs().Get<int>("x_num_col_dims");
int y_num_col_dims = ctx->Attrs().Get<int>("y_num_col_dims");
PADDLE_ENFORCE(x_dims.size() > x_num_col_dims,
"The rank of input tensor X should be larger than "
"`mul_op`'s `x_num_col_dims`.");
PADDLE_ENFORCE(y_dims.size() > y_num_col_dims,
"The rank of input tensor Y should be larger than "
"`mul_op`'s `y_num_col_dims`.");
PADDLE_ENFORCE_GT(
x_dims.size(), x_num_col_dims,
"The input tensor X's rank of MulOp should be larger than "
"x_num_col_dims.");
PADDLE_ENFORCE_GT(
y_dims.size(), y_num_col_dims,
"The input tensor Y's rank of MulOp should be larger than "
"y_num_col_dims.");
auto x_mat_dims = framework::flatten_to_2d(x_dims, x_num_col_dims);
auto y_mat_dims = framework::flatten_to_2d(y_dims, y_num_col_dims);
......
......@@ -28,7 +28,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
class MulKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* x = context.Input<Tensor>("X");
......@@ -52,7 +52,7 @@ class MulKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class MulGradKernel : public framework::OpKernel {
class MulGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int x_num_col_dims = ctx.template Attr<int>("x_num_col_dims");
......
......@@ -50,6 +50,11 @@ class MultiplexOp : public framework::OperatorWithKernel {
}
ctx->SetOutputDim("Out", in_dim);
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
}
};
class MultiplexOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -99,6 +104,11 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
}
ctx->SetOutputsDim(framework::GradVarName("X"), d_ins);
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type());
}
};
} // namespace operators
......
......@@ -21,7 +21,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class MultiplexGPUKernel : public framework::OpKernel {
class MultiplexGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<Tensor>("X");
......@@ -42,7 +42,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
PADDLE_ENFORCE_LT(k, ins.size(),
PADDLE_ENFORCE_LT((size_t)k, ins.size(),
"index exceeds the number of candidate tensors.");
memory::Copy(place, out->data<T>() + i * cols, place,
ins[k]->data<T>() + i * cols, cols * sizeof(T), stream);
......@@ -51,7 +51,7 @@ class MultiplexGPUKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class MultiplexGradGPUKernel : public framework::OpKernel {
class MultiplexGradGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
......
......@@ -23,7 +23,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class MultiplexCPUKernel : public framework::OpKernel {
class MultiplexCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto ins = ctx.MultiInput<framework::Tensor>("X");
......@@ -48,7 +48,7 @@ class MultiplexCPUKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class MultiplexGradCPUKernel : public framework::OpKernel {
class MultiplexGradCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
......
......@@ -53,16 +53,6 @@ class NetOp : public framework::OperatorBase {
this->CompleteAddOp();
}
/**
* Infer all the operators' input and output variables' shapes, will be called
* before every mini-batch
*/
void InferShape(const framework::Scope& scope) const override {
for (auto& op : ops_) {
op->InferShape(scope);
}
}
/**
* @brief Run the network.
*
......
......@@ -7,14 +7,12 @@ namespace operators {
using Scope = framework::Scope;
using DeviceContext = platform::DeviceContext;
static int infer_shape_cnt = 0;
static int run_cnt = 0;
class TestOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;
DEFINE_OP_CLONE_METHOD(TestOp);
void InferShape(const Scope& scope) const override { ++infer_shape_cnt; }
void Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const override {
++run_cnt;
......
......@@ -47,7 +47,7 @@ void PadFunction(const framework::ExecutionContext& context) {
}
template <typename Place, typename T>
class PadKernel : public framework::OpKernel {
class PadKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
int rank = context.Input<Tensor>("X")->dims().size();
......@@ -97,7 +97,7 @@ void PadGradFunction(const framework::ExecutionContext& context) {
}
template <typename Place, typename T>
class PadGradKernel : public framework::OpKernel {
class PadGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t rank =
......
......@@ -25,7 +25,7 @@ namespace operators {
using Tensor = framework::Tensor;
template <typename Place, typename T>
class PoolKernel : public framework::OpKernel {
class PoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
......@@ -82,7 +82,7 @@ class PoolKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class PoolGradKernel : public framework::OpKernel {
class PoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_x = context.Input<Tensor>("X");
......
......@@ -40,7 +40,7 @@ class PReluFunctor {
};
template <typename Place, typename T>
class PReluKernel : public framework::OpKernel {
class PReluKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
......@@ -77,7 +77,7 @@ class PReluGradFunctor {
};
template <typename Place, typename T>
class PReluGradKernel : public framework::OpKernel {
class PReluGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* dx = context.Output<Tensor>(framework::GradVarName("X"));
......
......@@ -21,7 +21,7 @@ namespace paddle {
namespace operators {
template <typename Place, typename T>
class RankLossKernel : public framework::OpKernel {
class RankLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::Tensor>("Out");
......@@ -42,7 +42,7 @@ class RankLossKernel : public framework::OpKernel {
};
template <typename Place, typename T>
class RankLossGradKernel : public framework::OpKernel {
class RankLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_left_t =
......
......@@ -28,29 +28,6 @@ using Variable = framework::Variable;
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
void RecurrentAlgorithm::InferShape(const Scope& scope) const {
auto* input0 = scope.FindVar(arg_->inlinks[0]);
PADDLE_ENFORCE_NOT_NULL(input0);
seq_len_ = input0->GetMutable<LoDTensor>()->dims()[0];
PADDLE_ENFORCE_GT(seq_len_, 0);
CreateScopes(scope);
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
InitMemories(step_scopes[0], true /*infer_shape_mode*/);
for (size_t i = 0; i < seq_len_; i++) {
if (i > 0) {
rnn::LinkMemories(step_scopes, arg_->memories, i, -1,
true /*infer_shape_mode*/);
}
(*stepnet_)->InferShape(*step_scopes[i]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
}
void RecurrentAlgorithm::Run(const Scope& scope,
const platform::DeviceContext& dev_ctx) const {
auto step_scopes = GetStepScopes(scope);
......@@ -202,24 +179,6 @@ void RecurrentGradientAlgorithm::LinkBootMemoryGradients(
}
}
void RecurrentGradientAlgorithm::InferShape(const Scope& scope) const {
seq_len_ =
scope.FindVar(arg_->inlinks[0])->GetMutable<LoDTensor>()->dims()[0];
auto step_scopes = GetStepScopes(scope);
rnn::SegmentInputs(step_scopes, arg_->inlinks, seq_len_,
true /*infer_shape_mode*/);
for (int step_id = seq_len_ - 1; step_id >= 0; --step_id) {
if (static_cast<size_t>(step_id) != seq_len_ - 1) {
rnn::LinkMemories(step_scopes, arg_->memories, step_id, 1,
true /*infer_shape_mode*/);
}
(*stepnet_)->InferShape(*step_scopes[step_id]);
}
rnn::ConcatOutputs(step_scopes, arg_->outlinks, seq_len_,
true /*infer_shape_mode*/);
LinkBootMemoryGradients(step_scopes[0], true /*infer_shape_mode*/);
}
RecurrentGradientOp::RecurrentGradientOp(
const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
......
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册