提交 cd7d0f85 编写于 作者: L Liu Yiqun

Merge branch 'develop' into core_inference_example

......@@ -93,6 +93,15 @@ Test on batch size 1, 2, 4, 8, 16 on Intel(R) Xeon(R) Gold 6148 CPU @ 2.40GHz
| MKLML | 22.74 | 41.56 | 81.22 | 133.47 | 210.53 |
| MKL-DNN | 175.10 | 272.92 | 450.70 | 512.00 | 600.94 |
- Alexnet
| BatchSize | 1 | 2 | 4 | 8 | 16 |
|-----------|--------|--------|--------|--------|--------|
| OpenBLAS | | | | | |
| MKLML | 21.32 | 36.55 | 73.06 | 131.15 | 192.77 |
| MKL-DNN | 442.91 | 656.41 | 719.10 | 847.68 | 850.51 |
chart TBD
### Laptop
TBD
......@@ -19,7 +19,11 @@ args = {
'num_samples': num_samples
}
define_py_data_sources2(
"train.list", None, module="provider", obj="process", args=args)
"train.list" if not is_infer else None,
"test.list" if is_infer else None,
module="provider",
obj="process",
args=args)
settings(
batch_size=batch_size,
......
......@@ -8,15 +8,19 @@ function clock_to_seconds() {
}
function infer() {
unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY
topology=$1
layer_num=$2
bs=$3
thread=`nproc`
if [ $thread -gt $bs ]; then
thread=$bs
trainers=`nproc`
if [ $trainers -gt $bs ]; then
trainers=$bs
fi
log="logs/infer-${topology}-${layer_num}-${thread}openblas-${bs}.log"
log="logs/infer-${topology}-${layer_num}-${trainers}openblas-${bs}.log"
threads=$((`nproc` / trainers))
if [ $threads -eq 0 ]; then
threads=1
fi
export OPENBLAS_NUM_THREADS=$threads
models_in="models/${topology}-${layer_num}/pass-00000/"
if [ ! -d $models_in ]; then
......@@ -28,7 +32,7 @@ function infer() {
--config="${topology}.py" \
--use_mkldnn=False \
--use_gpu=False \
--trainer_count=$thread \
--trainer_count=$trainers \
--log_period=$log_period \
--config_args="batch_size=${bs},layer_num=${layer_num},is_infer=True,num_samples=256" \
--init_model_path=$models_in \
......
set -e
function train() {
unset OMP_NUM_THREADS MKL_NUM_THREADS OMP_DYNAMIC KMP_AFFINITY
export OPENBLAS_NUM_THREADS=1
topology=$1
layer_num=$2
bs=$3
......
# Backward Building
## Motivation
In Neural Network, most models are solved by the backpropagation algorithm(known as **BP**) at present. Technically, BP calculates the gradient of the loss function, then propagates it back through the networks following the chain rule. However, when configuring the model structure, users do not need to define the backward part. So a mechanism is required by the framework which can complete the model's backward part automatically according to the given forward part.
When implementing a specific `op`, the developer is also asked to implement its backward version, called `grad_op`. A `grad_op` takes gradients of its corresponding `op`'s outputs, and calculate gradients of the `op`'s inputs. During the building of a model's backward part, the framework creates each forward `op`'s `grad_op`, and then string them together in reverse order of forwarding part. In this way, gradients spread from the end to the beginning of the model, in another word, from the loss to parameters.
## Challenges
The motivation of backward building is apparent. However, implementation it correctly is not so easy. In the **Fluid** design, a deep learning model is described by `Program`, `Block`, `Op` and `Variable`. The `Block` itself can be nested. It means that the `op`s and `variable`s are scattered across different blocks rather than all be gathered in a single graph. Our backward building algorithm shall visit blocks in recursive order and be able to insert `grad_op`s and new created `variable`s into the right place.
## Usage
Although the whole algorithm is comprised of many functions, only one is exposed as API:
```python
def append_backward(loss, parameter_list=None, no_grad_set=None):
"""
Append backward part to main_program
Args:
loss(Variable): The variable generated by the cost function.
parameter_list(list): Parameters that need to be updated by optimizers.
If None, it means all parameters need to be updated.
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
contains all variables with `step_gradient=True` from all blocks.
Return:
(list[Variable]): list of (parameters, gradients) pair.
"""
```
By invoking this API, the framework appends backward part of the program where the `loss` is. It takes three arguments. `loss` means the final loss value. It must be a scalar and is usually the output of the loss layer. It is also where the gradient generated and backpropagation starts. `parameter_list` marks all parameters needs updating. If it's `None`, all parameter will be updated by optimizers. `no_grad_set` marks variables without gradient. if all outputs of some `grad_op` are in `no_grad_set`, the `grad_op` will not be run.
This API will be invoked automatically before optimizer building.
As a result, in most cases, users do not need to invoke the API by themselves to append backward part.
## Implementation
The implementation of backward building algorithm is in `backward.py` file. The whole algorithm can be divided into two independent parts: creating `grad_op`s and creating new variables.
### Creating `grad_op`s
The creating of `grad_op`s is implemented by:
```python
def _append_backward_ops_(target,
block,
target_block,
no_grad_dict,
grad_to_var):
"""
Create all grad ops, and insert them into given block
Args:
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
key(int) block index
val(set) a set of varibale names. These varibales have no gradient
grad_to_var(dict)(output argument):
key(str): grad variable name
val(str): corresponding forward variable name
"""
```
Given a `block`, the function will traverses all `op`s in this block in reverse order, gets corresponding `grad_op` from the C++ core via `core.get_grad_op_desc()`, then append it to `target_block`.
However, some specific `op`(e.g. `while_op`, `if_else_op`) can hold its own sub-block. For these sub-blocks contains `op`s as well, the `grad_op` creating should be recursive.
During the reverse traversal, we check each `op` whether it has an attribute named `sub_block`. If so, it means there is a sub-block and we need to deal with it first. After creating a new block whose father is the one in `op`'s attribute, we invoke `_append_backward_ops_()` recursively, assigning the new block to parameter `target_block` and the one in `op`'s attribute to `block`. The *pseudo-code* shows this process:
```
******* pseudo-code ********
for op in reversed(block.ops):
if op has an attribute named 'sub_block':
Get the sub-block(`s_block`) from op's attribute.
Create a new block(`grad_s_block`), whose father is `s_block`.
Invoke _append_backward_ops_(), with `block=s_block` and `target_block=grad_s_block`
Invoke `core.get_grad_op_desc()` to get op's grad_op.
Insert name correspondings between variables and their gradients of the grad_op to grad_to_var
Assign grad_s_block to grad_op as it's 'sub_block' attribute.
Append grad_op to current target_block.
```
The first invoking of `_append_backward_ops_()` is initiated by `append_backward()`, in which parameters `block` and `target_block` are all assigned with root block(the block with index 0).
### Corner Cases of `grad_op` Creating
In the previous section, we show the regular process of `grad_op` creating. However, in some corner cases, the conventional algorithm is not enough to get the correct result and appending handling is required. These additional processes run after the algorithm mentioned above and do some special adjusts on its output `grad_op`s.
#### Shared Variables
If a variable is read by more than one `op` in the forward pass, its gradient is likely to be written by more than one `grad_op`s in the next backward pass. To make the gradient result being the sum of all `grad_op`s' outputs instead of the last running one, we assign each output with a temporary variable and then add a `sum_op` to add them up.
For the debug convenience, if the final gradient name is `w@GRAD`, it's corresponding temporary variables will be named as `w@GRAD@RENAME@0`, `w@GRAD@RENAME@1`...
See function `_addup_repetitive_outputs_` in `backward.py` for implementation details.
#### No Gradient Variables
In our framework, variables can be marked as *no_gradient*, it means that the gradient of this variable is unnecessary and can be considered as zero in model training. Apparently, when all the outputs of some `grad_op` are marked as *no_gradient*, the `grad_op` itself can be skipped in backward pass.
But these unnecessary gradients still need to be creating and initialized by something, otherwise following `grad_op`s who take these gradients as inputs take the risk of using uninitialized memory. In our code, we employ `fill_zeros_like_op` to initialize them as all zeros.
This features are implemented in function `_remove_no_grad_branch_`. It checks new created `grad_op`s one-by-one, removes whose outputs are all in `no_grad_set` or inserts `fill_zeros_like_op` when its necessary. We can get the `no_grad_set` from the `_append_backward_ops_` argument `no_grad_dict` or generate it on the fly by scanning all variables' `no_gradient` attribute(True or False).
### Creating Backward Variables
Up to now, we have completed all creating and adjusting jobs of `grad_op`s. However, backward variables have not been created. Now they are only represented by `grad_op`'s input and output arguments. The backward variable creating job will be done by:
```python
def _append_backward_vars_(block,
start_op_idx,
grad_to_var,
grad_info_map):
"""
Create new variables required by backward pass.
Args:
block(Block): the block where new variables will be created
start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
grad_to_var(dict):
key(str): grad variable name
val(str): corresponding forward variable name
In most cases, this dict is generated by _append_backward_ops_()
grad_info_map(dict)(output argument):
key(str): forward variable name
val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
"""
```
Given a `block`, this function traverses all the `grad_op`s in it(The argument `start_op_idx` indicates where the grad_op sequence starts.) and creates all the uncreated outputs. The *pseudo-code* shows this process:
```
for op in block.ops[start_op_idx : ]:
if op has an attribute named 'sub_block':
Get the sub-block(`s_block`) from op's attribute.
Invoke _append_backward_vars_(), with `block=s_block`
for var_name in op.all_output_names():
if block.has_var_recursive(var_name) or var_name is the name of empty variable:
continue
create a new variable named 'var_name' in block
if grad_to_var.has_key(var_name):
set grad_info_map[grad_to_var[var_name]] as a tuple of (var_name. block)
do op's var type inference
do op's shape inference
```
......@@ -15,7 +15,7 @@
获取PaddlePaddle的Docker镜像
------------------------------
执行下面的命令获取最新的PaddlePaddle Docker镜像
执行下面的命令获取最新的PaddlePaddle Docker镜像,版本为cpu_avx_mkl:
.. code-block:: bash
......@@ -27,7 +27,7 @@
docker pull docker.paddlepaddle.org/paddle
下载GPU版本的Docker镜像:
下载GPU版本(cuda8.0_cudnn5_avx_mkl)的Docker镜像:
.. code-block:: bash
......@@ -54,7 +54,7 @@
.. _docker_run:
在Docker中执行PaddlePaddle训练程序
------------------------------
----------------------------------
假设您已经在当前目录(比如在/home/work)编写了一个PaddlePaddle的程序 :code:`train.py` (可以参考
`PaddlePaddleBook <http://www.paddlepaddle.org/docs/develop/book/01.fit_a_line/index.cn.html>`_
......@@ -82,7 +82,7 @@
.. _docker_run_book:
使用Docker启动PaddlePaddle Book教程
------------------------------
-----------------------------------
使用Docker可以快速在本地启动一个包含了PaddlePaddle官方Book教程的Jupyter Notebook,可以通过网页浏览。
PaddlePaddle Book是为用户和开发者制作的一个交互式的Jupyter Notebook。
......
......@@ -16,7 +16,7 @@ After you've read above tutorials you may proceed the following steps.
Pull PaddlePaddle Docker Image
------------------------------
Run the following command to download the latest Docker images:
Run the following command to download the latest Docker images, the version is cpu_avx_mkl:
.. code-block:: bash
......@@ -28,7 +28,7 @@ For users in China, we provide a faster mirror:
docker pull docker.paddlepaddle.org/paddle
Download GPU version images:
Download GPU version (cuda8.0_cudnn5_avx_mkl) images:
.. code-block:: bash
......@@ -58,7 +58,7 @@ and run:
.. _docker_run:
Launch your training program in Docker
------------------------------
--------------------------------------
Assume that you have already written a PaddlePaddle program
named :code:`train.py` under directory :code:`/home/work` (refer to
......
......@@ -11,14 +11,14 @@ PaddlePaddle可以使用常用的Python包管理工具
------------------------------
执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件。
执行下面的命令即可在当前机器上安装PaddlePaddle的运行时环境,并自动下载安装依赖软件,版本为cpu_avx_openblas
.. code-block:: bash
pip install paddlepaddle
如果需要安装支持GPU的版本,需要执行:
如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行:
.. code-block:: bash
......
......@@ -12,14 +12,14 @@ Install Using pip
------------------------------
Run the following command to install PaddlePaddle on the current
machine, it will also download requirements.
machine, it will also download requirements, the version is cpu_avx_openblas.
.. code-block:: bash
pip install paddlepaddle
If you wish to install GPU version, just run:
If you wish to install GPU version (cuda7.5_cudnn5_avx_openblas), just run:
.. code-block:: bash
......
......@@ -7,13 +7,13 @@
++++++++
PaddlePaddle支持使用pip快速安装,目前支持CentOS 6以上, Ubuntu 14.04以及MacOS 10.12,并安装有Python2.7。
执行下面的命令完成快速安装:
执行下面的命令完成快速安装,版本为cpu_avx_openblas
.. code-block:: bash
pip install paddlepaddle
如果需要安装支持GPU的版本,需要执行:
如果需要安装支持GPU的版本(cuda7.5_cudnn5_avx_openblas),需要执行:
.. code-block:: bash
......
......@@ -8,13 +8,13 @@ Quick Install
You can use pip to install PaddlePaddle with a single command, supports
CentOS 6 above, Ubuntu 14.04 above or MacOS 10.12, with Python 2.7 installed.
Simply run the following command to install:
Simply run the following command to install, the version is cpu_avx_openblas:
.. code-block:: bash
pip install paddlepaddle
If you need to install GPU version, run:
If you need to install GPU version (cuda7.5_cudnn5_avx_openblas), run:
.. code-block:: bash
......
......@@ -5,10 +5,18 @@ cc_library(ddim SRCS ddim.cc DEPS eigen3)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
nv_test(dim_test SRCS dim_test.cu DEPS ddim)
cc_library(tensor SRCS tensor.cc DEPS ddim place paddle_memory device_context)
if (WITH_GPU)
nv_library(tensor SRCS tensor.cc tensor_util.cu DEPS ddim place paddle_memory device_context framework_proto)
else()
cc_library(tensor SRCS tensor.cc tensor_util.cc DEPS ddim place paddle_memory device_context framework_proto)
endif ()
cc_test(tensor_test SRCS tensor_test.cc DEPS tensor)
cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
if (WITH_GPU)
nv_test(tensor_util_test SRCS tensor_util_test.cc tensor_util_test.cu DEPS tensor)
else()
cc_test(tensor_util_test SRCS tensor_util_test.cc DEPS tensor)
endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
......
# Operator/expression 's Backward
## Motivation
In Neural Network, most models are solved by the backpropagation algorithm(known as **BP**) at present. Technically, BP calculates the gradient of the loss function, then propagates it back through the networks following the chain rule. Hence we need a module that chains the gradient operators/expressions together to construct the backward pass. Every forward network needs a backward network to construct the full computation graph. The operator/expression's backward pass will be generated with respect to the forward pass.
## Implementation
In this design doc, we exported only one API for generating the backward pass.
```c++
std::unique_ptr<OperatorBase> Backward(const OperatorBase& forwardOp,
const std::unordered_set<std::string>& no_grad_vars);
```
The implementation behind it can be divided into two parts, **Backward Operator Creating** and **Backward Operator Building**.
### Backward Operator Registry
A backward network is built up with several backward operators. Backward operators take forward operators' inputs, outputs, and output gradients and then calculate its input gradients.
| | forward operator | backward operator
| ---------------------- | ---------------- |------------------------- |
| **Operator::inputs_** | Inputs | Inputs, Outputs, OutputGradients |
| **Operator::outputs_** | Outputs | InputGradients |
In most cases, there is a one-to-one relation between the forward and backward operators. These relations are recorded by a global hash map(`OpInfoMap`). To follow the philosophy of minimum core and to make operators pluggable, the registry mechanism is introduced.
For example, we have `mul_op`, and we can register its information and corresponding backward operator by the following macro:
```cpp
REGISTER_OP(mul, MulOp, MulOpMaker, mul_grad, MulOpGrad);
```
`mul` is the operator's type. `MulOp` and `MulOpMaker` are the operator class and the operator maker class respectively.
`mul_grad` is the type of backward operator, and `MulOpGrad` is its class name.
### Backward Opeartor Creating
Given a certain forward operator, we can get its corresponding backward operator by calling:
```cpp
OperatorBase* bwd_op = BuildGradOp(const OperatorBase* fwd_op);
```
The function `BuildGradOp` will sequentially execute following processes:
1. Get the `type_` of given forward operator, and then get the corresponding backward operator's type by looking up the `OpInfoMap`.
2. Build two maps named `inputs` and `outputs` to temporarily store backward operator's inputs and outputs. Copy forward operator's `inputs_` and `outputs_` to map `inputs`, except these, are not necessary for gradient computing.
3. Add forward inputs' gradient variables into map `output`, adding forward outputs' gradient variables into map `input`.
4. Building backward operator with `inputs`, `outputs` and forward operator's attributes.
### Backward Network Building
A backward network is a series of backward operators. The main idea of building a backward network is creating backward operators in the inverted sequence and appending them together one by one. There are some corner cases that need special processing.
1. Op
When the input forward network is an Op, return its gradient Operator immediately. If all of its outputs are in no gradient set, then return a special `NOP`.
2. NetOp
In our design, the network itself is also a kind of operator(**NetOp**). So the operators contained by a big network may be some small network. When the input forward network is a NetOp, it needs to call the sub NetOp/Operators backward function recursively. During the process, we need to collect the `OutputGradients` name according to the forward NetOp.
3. RnnOp
RnnOp is a nested stepnet operator. Backward module needs to recusively call `Backward` for every stepnet.
4. Sharing Variables
As illustrated in the figure 1 and figure 2, two operators share the same variable name **W@GRAD**, which will overwrite their shared input variable.
<p align="center">
<img src="./images/duplicate_op.png" width="50%" ><br/>
​ Figure 1. Sharing variables in operators.
</p>
​ Sharing variable between operators or same input variable used in multiple operators can lead to duplicate gradient variables. As illustrated in figure 2, we need to rename the gradient names recursively and add a generic add operator to prevent overwriting.
<p align="center">
<img src="images/duplicate_op2.png" width="40%" ><br/>
​ Figure 2. Replace sharing variable's gradient with `Add` operator.
</p>
​ Because the framework finds variables according to their names, we need to rename the output links. We add an integer suffix to represent its position in the clockwise direction.
5. Part of the Gradient is Zero.
In the whole graph, there is some case of that one operator's gradient is not needed, but its input's gradient is a dependency link of other operator, we need to fill a same shape gradient matrix in the position. In our implementation, we insert a special `fillZeroLike` operator.
Follow these rules above, then collect the sub graph `OutputGradients`/`InputGradients` as the NetOp's and return it.
......@@ -27,9 +27,8 @@ limitations under the License. */
namespace paddle {
namespace framework {
using DataTransformFn =
std::function<void(const std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out)>;
using DataTransformFn = std::function<void(const platform::DeviceContext* ctx,
const Variable& in, Variable* out)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
struct KernelTypePairHash {
......
......@@ -54,18 +54,18 @@ auto kernel1 = GenFromBit({0, 0, 0, 1});
auto kernel2 = GenFromBit({0, 0, 1, 0});
auto kernel3 = GenFromBit({0, 0, 1, 1});
void TransDataType_t(std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out) {
void TransDataType_t(const platform::DeviceContext* ctx, const Variable& in,
Variable* out) {
test_value++;
}
void TransDataLayout_t(std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out) {
void TransDataLayout_t(const platform::DeviceContext* ctx, const Variable& in,
Variable* out) {
test_value--;
}
void TransLibraryType_t(std::vector<platform::DeviceContext*> ctx,
const Variable& in, Variable* out) {
void TransLibraryType_t(const platform::DeviceContext* ctx, const Variable& in,
Variable* out) {
test_value += 2;
}
......@@ -83,7 +83,8 @@ TEST(DataTransform, Register) {
using namespace paddle::platform;
auto& instance = DataTransformFnMap::Instance();
std::vector<DeviceContext*> ctx;
ASSERT_EQ(instance.Map().size(), 3UL);
DeviceContext* ctx = nullptr;
paddle::framework::Variable in;
paddle::framework::Variable out;
......
......@@ -14,18 +14,17 @@ limitations under the License. */
#include "paddle/framework/executor.h"
#include <algorithm>
#include <iostream>
#include <memory>
#include <set>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/scope.h"
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
"extremely slow so please use this flag wisely.");
namespace paddle {
namespace framework {
......@@ -58,6 +57,19 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
}
}
static void CheckTensorNANOrInf(const std::string& name,
const framework::Tensor& tensor) {
if (tensor.memory_size() == 0) {
return;
}
if (tensor.type().hash_code() != typeid(float).hash_code() &&
tensor.type().hash_code() != typeid(double).hash_code()) {
return;
}
PADDLE_ENFORCE(!framework::HasInf(tensor), "Tensor %s has Inf", name);
PADDLE_ENFORCE(!framework::HasNAN(tensor), "Tensor %s has NAN", name);
}
void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
......@@ -101,6 +113,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
VLOG(3) << op->DebugString();
op->Run(*local_scope, place_);
if (FLAGS_check_nan_inf) {
for (auto& vname : op->OutputVars(true)) {
auto* var = local_scope->FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
}
}
}
}
if (create_vars && create_local_scope) {
scope->DeleteScope(local_scope);
......
......@@ -189,62 +189,16 @@ void AppendLoD(LoD *lod, const LoD &lod_length) {
void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
const platform::DeviceContext &dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
{ // the 1st field, uint32_t version for LoDTensor
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char *>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto *pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto *data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto &gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext &>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void *>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
os.write(static_cast<const char *>(data_ptr),
static_cast<std::streamsize>(size));
}
}
{ // the 4th field, lod information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
{
// the 2st field, LoD information
// uint64_t lod_level
// uint64_t lod_level_1 size in byte.
// int* lod_level_1 data
// ...
auto lod = tensor.lod();
uint64_t size = lod.size();
os.write(reinterpret_cast<const char *>(&size), sizeof(size));
......@@ -256,49 +210,19 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
static_cast<std::streamsize>(size));
}
}
// the 3st field, Tensor
SerializeToStream(os, static_cast<Tensor>(tensor), dev_ctx);
}
void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char *>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char *>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void *buf;
platform::Place cpu = platform::CPUPlace();
switch (desc.data_type()) {
case proto::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case proto::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case proto::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case proto::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
PADDLE_THROW("DataType %d not supported", desc.data_type());
}
is.read(static_cast<char *>(buf), tensor->memory_size());
}
{ // read lod
{
// the 1st field, unit32_t version for SelectedRows
uint32_t version;
is.read(reinterpret_cast<char *>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
}
{
// the 2st field, LoD information
uint64_t lod_level;
is.read(reinterpret_cast<char *>(&lod_level), sizeof(lod_level));
auto &lod = *tensor->mutable_lod();
......@@ -312,6 +236,8 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor) {
lod[i] = tmp;
}
}
// the 3st filed, Tensor
DeserializeFromStream(is, static_cast<Tensor *>(tensor));
}
} // namespace framework
......
......@@ -126,6 +126,20 @@ TEST_F(LoDTensorTester, ShrinkInLevel) {
EXPECT_NE(t1.data<float>(), lod_tensor_.data<float>());
}
TEST_F(LoDTensorTester, SerializeAndDeserialize) {
LoDTensor dst_tensor;
platform::CPUDeviceContext cpu_ctx((platform::CPUPlace()));
std::ostringstream oss;
SerializeToStream(oss, lod_tensor_, cpu_ctx);
std::istringstream iss(oss.str());
DeserializeFromStream(iss, &dst_tensor);
float* dst_ptr = dst_tensor.mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < kLodTensorSize; ++i) {
EXPECT_EQ(dst_ptr[i], i);
}
EXPECT_EQ(dst_tensor.lod(), lod_tensor_.lod());
}
TEST(LodExpand, test) {
LoD lod{{0, 2}};
LoDTensor tensor;
......
......@@ -384,6 +384,24 @@ class RuntimeInferShapeContext : public InferShapeContext {
const Scope& scope_;
};
const platform::DeviceContext* GetDeviceContext(
framework::KernelTypePair& kernel_pair) {
auto& actual_kernel_key = kernel_pair.first;
auto& expected_kernel_key = kernel_pair.second;
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
if (platform::is_gpu_place(actual_kernel_key.place_) &&
platform::is_cpu_place(expected_kernel_key.place_)) {
return pool.Get(actual_kernel_key.place_);
} else if (platform::is_cpu_place(actual_kernel_key.place_) &&
platform::is_gpu_place(expected_kernel_key.place_)) {
return pool.Get(expected_kernel_key.place_);
} else {
PADDLE_THROW(
"Currently, model parallelism is only supported between CPU and CUDA");
}
}
void OperatorWithKernel::Run(const Scope& scope,
const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope);
......@@ -418,9 +436,9 @@ void OperatorWithKernel::Run(const Scope& scope,
"CPU and other devices. For example, multi-GPU model "
"parallelism will failed.");
} else {
auto kernel_pair = std::make_pair(actual_kernel_key, expected_kernel_key);
const DataTransformFn* trans_fun =
DataTransformFnMap::Instance().GetNullable(
std::make_pair(actual_kernel_key, expected_kernel_key));
DataTransformFnMap::Instance().GetNullable(kernel_pair);
if (trans_fun) {
auto input_vars = this->InputVars();
// TODO(qijun) filter the input vars that do not need to be transformed
......@@ -437,22 +455,18 @@ void OperatorWithKernel::Run(const Scope& scope,
}
if (!need_trans.empty()) {
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr;
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
auto trans_dev_ctx = GetDeviceContext(kernel_pair);
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : need_trans) {
(*trans_fun)(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
(*trans_fun)(trans_dev_ctx, *(scope.FindVar(var_name)),
scope.FindVar(var_name + framework::KernelTypeToString(
expected_kernel_key)));
}
// Wait for data transform finishing
for (auto ctx : trans_dev_ctx_vec) {
ctx->Wait();
}
trans_dev_ctx->Wait();
}
}
}
......
......@@ -12,5 +12,58 @@ limitations under the License. */
#include "paddle/framework/selected_rows.h"
namespace paddle {
namespace framework {} // namespace framework
namespace framework {
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx) {
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char*>(&version), sizeof(version));
}
{
// the 2st field, rows information
auto& rows = selected_rows.rows();
uint64_t size = rows.size();
os.write(reinterpret_cast<const char*>(&size), sizeof(size));
for (uint64_t i = 0; i < size; ++i) {
os.write(reinterpret_cast<const char*>(&rows[i]), sizeof(rows[i]));
}
}
{
// the 3st field, the height of SelectedRows
int64_t height = selected_rows.height();
os.write(reinterpret_cast<const char*>(&height), sizeof(height));
}
// the 4st field, Tensor data
SerializeToStream(os, selected_rows.value(), dev_ctx);
}
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows) {
auto tensor = *selected_rows->mutable_value();
{
// the 1st field, unit32_t version for SelectedRows
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
}
{
// the 2st field, rows information
uint64_t size;
is.read(reinterpret_cast<char*>(&size), sizeof(size));
auto& rows = *selected_rows->mutable_rows();
rows.resize(size);
for (uint64_t i = 0; i < size; ++i) {
is.read(reinterpret_cast<char*>(&rows[i]), sizeof(int64_t));
}
}
{
// the 3st field, the height of the SelectedRows
int64_t height;
is.read(reinterpret_cast<char*>(&height), sizeof(int64_t));
selected_rows->set_height(height);
}
// the 4st field, tensor which contains the data
DeserializeFromStream(is, &tensor);
}
} // namespace framework
} // namespace paddle
......@@ -59,5 +59,14 @@ class SelectedRows {
int64_t height_;
};
/*
* Serialize/Desiralize SelectedRows to std::ostream
* You can pass ofstream or ostringstream to serilize to file
* or to a in memory string. GPU tensor will be copied to CPU.
*/
void SerializeToStream(std::ostream& os, const SelectedRows& selected_rows,
const platform::DeviceContext& dev_ctx);
void DeserializeFromStream(std::istream& is, SelectedRows* selected_rows);
} // namespace framework
} // namespace paddle
......@@ -43,5 +43,19 @@ TEST_F(SelectedRowsTester, complete_dims) {
ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100}));
}
TEST_F(SelectedRowsTester, SerializeAndDeseralize) {
SelectedRows dst_tensor;
platform::CPUDeviceContext cpu_ctx(place_);
std::ostringstream oss;
SerializeToStream(oss, *selected_rows_, cpu_ctx);
std::istringstream iss(oss.str());
DeserializeFromStream(iss, &dst_tensor);
ASSERT_EQ(selected_rows_->rows(), dst_tensor.rows());
ASSERT_EQ(selected_rows_->height(), dst_tensor.height());
}
} // namespace framework
} // namespace paddle
......@@ -15,12 +15,13 @@
#include <gtest/gtest.h>
#include <string>
namespace framework = paddle::framework;
namespace platform = paddle::platform;
TEST(Tensor, Dims) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor tt;
framework::Tensor tt;
tt.Resize({2, 3, 4});
DDim dims = tt.dims();
framework::DDim dims = tt.dims();
ASSERT_EQ(arity(dims), 3);
for (int i = 0; i < 3; ++i) {
EXPECT_EQ(i + 2, dims[i]);
......@@ -28,12 +29,12 @@ TEST(Tensor, Dims) {
}
TEST(Tensor, DataAssert) {
paddle::framework::Tensor src_tensor;
framework::Tensor src_tensor;
bool caught = false;
try {
src_tensor.data<double>();
} catch (paddle::platform::EnforceNotMet err) {
} catch (platform::EnforceNotMet err) {
caught = true;
std::string msg =
"holder_ should not be null\nTensor holds no memory. Call "
......@@ -50,61 +51,65 @@ TEST(Tensor, DataAssert) {
because Memory::Alloc() and Memory::Free() have not been ready.
*/
TEST(Tensor, MutableData) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CPUPlace());
p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
platform::CPUPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
platform::CPUPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CPUPlace());
p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
platform::CPUPlace());
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CPUPlace());
p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
platform::CPUPlace());
EXPECT_EQ(p1, p2);
}
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
framework::Tensor src_tensor;
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CUDAPlace());
p1 = src_tensor.mutable_data<float>(framework::make_ddim({1, 2, 3}),
platform::CUDAPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// momery is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CUDAPlace());
p2 = src_tensor.mutable_data<float>(framework::make_ddim({3, 4}),
platform::CUDAPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size
// momery block is supposed to be unchanged
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CUDAPlace());
p1 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2, 3}),
platform::CUDAPlace());
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// momery block is supposed to be unchanged
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CUDAPlace());
p2 = src_tensor.mutable_data<float>(framework::make_ddim({2, 2}),
platform::CUDAPlace());
EXPECT_EQ(p1, p2);
}
#endif
}
TEST(Tensor, ShareDataWith) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
Tensor dst_tensor;
framework::Tensor src_tensor;
framework::Tensor dst_tensor;
// Try to share data form uninitialized tensor
bool caught = false;
try {
......@@ -121,16 +126,18 @@ TEST(Tensor, ShareDataWith) {
}
ASSERT_TRUE(caught);
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CPUPlace());
src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
platform::CPUPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
Tensor dst_tensor;
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CUDAPlace());
framework::Tensor src_tensor;
framework::Tensor dst_tensor;
src_tensor.mutable_data<int>(framework::make_ddim({2, 3, 4}),
platform::CUDAPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
......@@ -138,13 +145,12 @@ TEST(Tensor, ShareDataWith) {
}
TEST(Tensor, Slice) {
using namespace paddle::framework;
using namespace paddle::platform;
{
Tensor src_tensor;
src_tensor.mutable_data<int>(make_ddim({5, 3, 4}), CPUPlace());
Tensor slice_tensor = src_tensor.Slice(1, 3);
DDim slice_dims = slice_tensor.dims();
framework::Tensor src_tensor;
src_tensor.mutable_data<int>(framework::make_ddim({5, 3, 4}),
platform::CPUPlace());
framework::Tensor slice_tensor = src_tensor.Slice(1, 3);
framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 3);
EXPECT_EQ(slice_dims[0], 2);
EXPECT_EQ(slice_dims[1], 3);
......@@ -153,11 +159,12 @@ TEST(Tensor, Slice) {
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<int>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<int>(src_tensor.dims(), CPUPlace()));
src_tensor.mutable_data<int>(src_tensor.dims(), platform::CPUPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<int>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<int>(slice_tensor.dims(), CPUPlace()));
uintptr_t slice_mutable_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<int>(
slice_tensor.dims(), platform::CPUPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 3 * 4 * 1 * sizeof(int), slice_data_address);
......@@ -165,22 +172,25 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), CUDAPlace());
Tensor slice_tensor = src_tensor.Slice(2, 6);
DDim slice_dims = slice_tensor.dims();
framework::Tensor src_tensor;
src_tensor.mutable_data<double>(framework::make_ddim({6, 9}),
platform::CUDAPlace());
framework::Tensor slice_tensor = src_tensor.Slice(2, 6);
framework::DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
EXPECT_EQ(slice_dims[0], 4);
EXPECT_EQ(slice_dims[1], 9);
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<double>(src_tensor.dims(), CUDAPlace()));
uintptr_t src_mutable_data_address =
reinterpret_cast<uintptr_t>(src_tensor.mutable_data<double>(
src_tensor.dims(), platform::CUDAPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<double>(slice_tensor.dims(), CUDAPlace()));
uintptr_t slice_mutable_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.mutable_data<double>(
slice_tensor.dims(), platform::CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
......@@ -189,23 +199,19 @@ TEST(Tensor, Slice) {
}
TEST(Tensor, ReshapeToMatrix) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, CPUPlace());
framework::Tensor src;
int* src_ptr = src.mutable_data<int>({2, 3, 4, 9}, platform::CPUPlace());
for (int i = 0; i < 2 * 3 * 4 * 9; ++i) {
src_ptr[i] = i;
}
Tensor res = ReshapeToMatrix(src, 2);
framework::Tensor res = framework::ReshapeToMatrix(src, 2);
ASSERT_EQ(res.dims()[0], 2 * 3);
ASSERT_EQ(res.dims()[1], 4 * 9);
}
TEST(Tensor, Layout) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
ASSERT_EQ(src.layout(), DataLayout::kNHWC);
src.set_layout(DataLayout::kAnyLayout);
ASSERT_EQ(src.layout(), DataLayout::kAnyLayout);
framework::Tensor src;
ASSERT_EQ(src.layout(), framework::DataLayout::kNHWC);
src.set_layout(framework::DataLayout::kAnyLayout);
ASSERT_EQ(src.layout(), framework::DataLayout::kAnyLayout);
}
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/tensor_util.h"
namespace paddle {
namespace framework {
template <typename Predicate, typename DevCtx>
struct AnyDTypeVisitor {
Predicate predicate_;
const Tensor& tensor_;
const DevCtx& ctx_;
Tensor* out_;
AnyDTypeVisitor(Predicate predicate, const Tensor& tensor, const DevCtx& ctx,
Tensor* out)
: predicate_(predicate), tensor_(tensor), ctx_(ctx), out_(out) {}
template <typename T>
void operator()() const {
auto t = EigenVector<T>::Flatten(tensor_);
auto o = EigenScalar<bool>::From(*out_);
// return any of predicate_(t) is true.
o.device(*ctx_.eigen_device()) = predicate_(t).any();
}
};
template <typename Predicate, typename DevCtx>
inline void AnyImpl(Predicate predicate, const framework::Tensor& tensor,
const DevCtx& ctx, framework::Tensor* out) {
VisitDataType(ToDataType(tensor.type()), AnyDTypeVisitor<Predicate, DevCtx>(
predicate, tensor, ctx, out));
}
template <typename Predicate>
struct AnyVisitor : public boost::static_visitor<bool> {
const framework::Tensor& tensor_;
Predicate predicate_;
AnyVisitor(const framework::Tensor& tensor, Predicate predicate)
: tensor_(tensor), predicate_(std::move(predicate)) {}
template <typename Place>
bool operator()(const Place& place) const {
framework::Tensor out;
out.Resize({1});
out.mutable_data<bool>(place);
auto* ctx = platform::DeviceContextPool::Instance().GetByPlace(place);
AnyImpl(predicate_, tensor_, *ctx, &out);
return this->GetResult(out, place);
}
bool GetResult(const framework::Tensor& out,
const platform::CUDAPlace& gpu) const {
platform::CPUPlace cpu;
framework::Tensor tmp;
tmp.Resize({1});
tmp.mutable_data<bool>(cpu);
auto gpuctx = platform::DeviceContextPool::Instance().Get(gpu);
gpuctx->Wait();
CopyFrom(out, cpu, *gpuctx, &tmp);
gpuctx->Wait();
return GetResult(tmp, cpu);
}
bool GetResult(const framework::Tensor& out,
const platform::CPUPlace& cpu) const {
return *out.data<bool>();
}
};
template <typename Predicate>
inline bool Any(const framework::Tensor& tensor, Predicate predicate) {
AnyVisitor<Predicate> visitor(tensor, predicate);
auto place = tensor.place();
return platform::VisitPlace(place, visitor);
}
struct HasNANPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isnan()) {
// Cast eigen_vector to vector of bool. true if is inf.
return eigen_vec.isnan();
}
};
bool HasNAN(const framework::Tensor& tensor) {
HasNANPredicate predicate;
return Any(tensor, predicate);
}
struct HasInfPredicate {
template <typename T>
auto operator()(const T& eigen_vec) const
-> decltype(std::declval<T>().isinf()) {
// Cast eigen_vector to vector of bool. true if is inf.
return eigen_vec.isinf();
}
};
bool HasInf(const framework::Tensor& tensor) {
HasInfPredicate predicate;
return Any(tensor, predicate);
}
} // namespace framework
} // namespace paddle
./tensor_util.cc
\ No newline at end of file
......@@ -13,7 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/data_type.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace framework {
......@@ -205,5 +209,109 @@ inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
src_ptr, size);
}
// Returns true if a tensor contains NAN, i.e., Not A Number.
bool HasNAN(const framework::Tensor& tensor);
// Returns true if a tensor contains Inf, i.e., Infinity.
bool HasInf(const framework::Tensor& tensor);
inline void SerializeToStream(std::ostream& os, const Tensor& tensor,
const platform::DeviceContext& dev_ctx) {
// TODO(typhoonzero): serialize to ostream
{ // the 1st field, uint32_t version
constexpr uint32_t version = 0;
os.write(reinterpret_cast<const char*>(&version), sizeof(version));
}
{ // the 2nd field, tensor description
// int32_t size
// void* protobuf message
proto::TensorDesc desc;
desc.set_data_type(framework::ToDataType(tensor.type()));
auto dims = framework::vectorize(tensor.dims());
auto* pb_dims = desc.mutable_dims();
pb_dims->Resize(static_cast<int>(dims.size()), 0);
std::copy(dims.begin(), dims.end(), pb_dims->begin());
int32_t size = desc.ByteSize();
os.write(reinterpret_cast<const char*>(&size), sizeof(size));
auto out = desc.SerializeAsString();
os.write(out.data(), size);
}
{ // the 3rd field, tensor data
uint64_t size = tensor.memory_size();
auto* data_ptr = tensor.data<void>();
PADDLE_ENFORCE(size < std::numeric_limits<std::streamsize>::max(),
"Index overflow when writing tensor");
if (platform::is_gpu_place(tensor.place())) {
#ifdef PADDLE_WITH_CUDA
constexpr size_t kBufSize = 1024 * 1024 * 64; // 64MB
std::unique_ptr<char[]> buf(new char[kBufSize]);
auto& gpu_dev_ctx =
static_cast<const platform::CUDADeviceContext&>(dev_ctx);
platform::CPUPlace cpu;
uintptr_t data = reinterpret_cast<uintptr_t>(data_ptr);
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void*>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
os.write(buf.get(), size_to_write);
data += size_to_write;
size -= size_to_write;
}
#else
PADDLE_THROW("Unexpected branch");
#endif
} else {
os.write(static_cast<const char*>(data_ptr),
static_cast<std::streamsize>(size));
}
}
}
inline void DeserializeFromStream(std::istream& is, Tensor* tensor) {
uint32_t version;
is.read(reinterpret_cast<char*>(&version), sizeof(version));
PADDLE_ENFORCE_EQ(version, 0U, "Only version 0 is supported");
proto::TensorDesc desc;
{ // int32_t size
// proto buffer
int32_t size;
is.read(reinterpret_cast<char*>(&size), sizeof(size));
std::unique_ptr<char[]> buf(new char[size]);
is.read(reinterpret_cast<char*>(buf.get()), size);
PADDLE_ENFORCE(desc.ParseFromArray(buf.get(), size),
"Cannot parse tensor desc");
}
{ // read tensor
std::vector<int64_t> dims;
dims.reserve(static_cast<size_t>(desc.dims().size()));
std::copy(desc.dims().begin(), desc.dims().end(), std::back_inserter(dims));
tensor->Resize(framework::make_ddim(dims));
void* buf;
platform::Place cpu = platform::CPUPlace();
// TODO(Yancey1989): use VisiterDataType instead of DataType switch
switch (desc.data_type()) {
case proto::FP32:
buf = tensor->mutable_data<float>(cpu);
break;
case proto::FP64:
buf = tensor->mutable_data<double>(cpu);
break;
case proto::INT32:
buf = tensor->mutable_data<int>(cpu);
break;
case proto::INT64:
buf = tensor->mutable_data<int64_t>(cpu);
break;
default:
PADDLE_THROW("DataType %d not supported", desc.data_type());
}
is.read(static_cast<char*>(buf), tensor->memory_size());
}
}
} // namespace framework
} // namespace paddle
......@@ -13,6 +13,7 @@
#include "paddle/framework/tensor_util.h"
#include <gtest/gtest.h>
#include <cmath>
#include <string>
namespace paddle {
......@@ -230,5 +231,78 @@ TEST(CopyToVector, Tensor) {
#endif
}
TEST(HasNAN, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
float* buf = src.mutable_data<float>({3}, CPUPlace());
buf[0] = 0.0;
buf[1] = NAN;
buf[2] = 0.0;
ASSERT_TRUE(HasNAN(src));
}
TEST(HasInf, CPU) {
using namespace paddle::framework;
using namespace paddle::platform;
Tensor src;
double* buf = src.mutable_data<double>({3}, CPUPlace());
buf[0] = 1.0;
buf[1] = INFINITY;
buf[2] = 0.0;
ASSERT_TRUE(HasInf(src));
}
TEST(Tensor, SerializeAndDeserialize) {
framework::Tensor src_tensor;
int array[6] = {1, 2, 3, 4, 5, 6};
src_tensor.Resize({2, 3});
int* src_ptr = src_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 6; ++i) {
src_ptr[i] = array[i];
}
{
framework::Tensor dst_tensor;
auto place = new platform::CPUPlace();
platform::CPUDeviceContext cpu_ctx(*place);
std::ostringstream oss;
SerializeToStream(oss, src_tensor, cpu_ctx);
std::istringstream iss(oss.str());
DeserializeFromStream(iss, &dst_tensor);
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 5; ++i) {
ASSERT_EQ(dst_ptr[i], array[i]);
}
delete place;
}
#ifdef PADDLE_WITH_CUDA
{
Tensor gpu_tensor;
gpu_tensor.Resize({2, 3});
Tensor dst_tensor;
auto gpu_place = new platform::CUDAPlace();
platform::CUDADeviceContext gpu_ctx(*gpu_place);
CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
std::ostringstream oss;
SerializeToStream(oss, gpu_tensor, gpu_ctx);
std::istringstream iss(oss.str());
DeserializeFromStream(iss, &dst_tensor);
int* dst_ptr = dst_tensor.mutable_data<int>(platform::CPUPlace());
for (int i = 0; i < 6; ++i) {
ASSERT_EQ(dst_ptr[i], array[i]);
}
delete gpu_place;
}
#endif
}
} // namespace framework
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/framework/tensor_util.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace framework {
static __global__ void FillNAN(float* buf) {
buf[0] = 0.0;
buf[1] = 0.1;
buf[2] = NAN;
}
static __global__ void FillInf(float* buf) {
buf[0] = 0.0;
buf[1] = INFINITY;
buf[2] = 0.5;
}
TEST(HasNAN, GPU) {
Tensor tensor;
platform::CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu);
float* buf = tensor.mutable_data<float>({3}, gpu);
FillNAN<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(HasNAN(tensor));
}
TEST(HasInf, GPU) {
Tensor tensor;
platform::CUDAPlace gpu(0);
auto& pool = platform::DeviceContextPool::Instance();
auto* cuda_ctx = pool.GetByPlace(gpu);
float* buf = tensor.mutable_data<float>({3}, gpu);
FillInf<<<1, 1, 0, cuda_ctx->stream()>>>(buf);
cuda_ctx->Wait();
ASSERT_TRUE(HasInf(tensor));
}
} // namespace framework
} // namespace paddle
......@@ -16,6 +16,7 @@ limitations under the License. */
#include <condition_variable>
#include <functional>
#include <future>
#include <mutex>
#include <queue>
#include <thread>
......@@ -25,10 +26,11 @@ limitations under the License. */
namespace paddle {
namespace framework {
typedef std::function<void()> Task;
class ThreadPool {
public:
typedef std::packaged_task<void()> Task;
typedef std::function<void()> Fun;
/**
* @brief Get a instance of threadpool, the thread number will
* be specified as the number of hardware thread contexts
......@@ -61,13 +63,18 @@ class ThreadPool {
/**
* @brief Push a function to the queue, and will be scheduled and
* executed if a thread is available.
* @param[in] Task will be pushed to the task queue.
* @param[in] Task, will be pushed to the task queue.
* @return std::future<void>, we could wait for the task finished by
* f.wait().
*/
void Run(const Task& fn) {
std::future<void> Run(const Fun& fn) {
std::unique_lock<std::mutex> lock(mutex_);
tasks_.push(fn);
Task task(std::bind(fn));
std::future<void> f = task.get_future();
tasks_.push(std::move(task));
lock.unlock();
scheduled_.notify_one();
return f;
}
/**
......@@ -110,7 +117,7 @@ class ThreadPool {
break;
}
// pop a task from the task queue
auto task = tasks_.front();
auto task = std::move(tasks_.front());
tasks_.pop();
--available_;
......
......@@ -20,16 +20,21 @@ limitations under the License. */
namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
std::vector<std::future<void>> fs;
for (int i = 0; i < cnt; ++i) {
pool->Run([&sum]() { sum.fetch_add(1); });
auto f = pool->Run([&sum]() { sum.fetch_add(1); });
fs.push_back(std::move(f));
}
for (auto& f : fs) {
f.wait();
}
}
TEST(ThreadPool, ConcurrentInit) {
framework::ThreadPool* pool;
int concurrent_cnt = 50;
int n = 50;
std::vector<std::thread> threads;
for (int i = 0; i < concurrent_cnt; ++i) {
for (int i = 0; i < n; ++i) {
std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
threads.push_back(std::move(t));
}
......@@ -38,13 +43,13 @@ TEST(ThreadPool, ConcurrentInit) {
}
}
TEST(ThreadPool, ConcurrentStart) {
TEST(ThreadPool, ConcurrentRun) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0);
std::vector<std::thread> threads;
int concurrent_cnt = 50;
int n = 50;
// sum = (n * (n + 1)) / 2
for (int i = 1; i <= concurrent_cnt; ++i) {
for (int i = 1; i <= n; ++i) {
std::thread t(do_sum, pool, std::ref(sum), i);
threads.push_back(std::move(t));
}
......@@ -52,5 +57,5 @@ TEST(ThreadPool, ConcurrentStart) {
t.join();
}
pool->Wait();
EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
EXPECT_EQ(sum, ((n + 1) * n) / 2);
}
......@@ -126,14 +126,165 @@ public:
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
}
};
#ifdef PADDLE_MOBILE_INFERENCE
if (Device == DEVICE_TYPE_CPU) {
memory_.reset();
/*
* \brief Forward calculation of convolution, optimized for mobile.
*/
template <DeviceType Device>
class GemmConvMobileFunction : public ConvFunctionBase {
public:
void init(const FuncConfig& config) override {
ConvFunctionBase::init(config);
}
void check(const BufferArgs& inputs, const BufferArgs& outputs) override {
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
checkShape(input, filter, output);
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(numInputs_, inputs.size());
CHECK_EQ(numOutputs_, outputs.size());
check(inputs, outputs);
// TODO(hedaoyuan): Need to define some index macros,
// to avoid useing 0 and 1.
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& output = outputs[0].shape();
real beta;
if (outputs[0].getArgType() == ADD_TO) {
beta = 1.0;
} else {
beta = 0.0;
}
#endif
size_t batchSize = input[0];
size_t inputChannels = input[1];
size_t inputHeight = input[2];
size_t inputWidth = input[3];
size_t filterHeight = getFilterHeight(filter);
size_t filterWidth = getFilterWidth(filter);
size_t outputChannels = output[1];
size_t outputHeight = output[2];
size_t outputWidth = output[3];
real* inputData = inputs[0].data<real>();
real* filterData = inputs[1].data<real>();
real* outputData = outputs[0].data<real>();
bool needIm2col = isNeedIm2col(filter);
TensorShape imShape =
TensorShape({inputChannels / groups_, inputHeight, inputWidth});
TensorShape colShape;
real* colData = NULL;
size_t colHeight = inputChannels / groups_ * filterHeight * filterWidth;
size_t colWidth = outputHeight * outputWidth;
// Max col matrix height 256, Max col matrix width 1024
size_t stepColHeight = std::min(colHeight, static_cast<size_t>(256));
size_t stepColWidth = std::min(colWidth, static_cast<size_t>(2048));
if (needIm2col) {
colShape = TensorShape({inputChannels / groups_,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
resizeBuffer<Device>(stepColHeight * stepColWidth * sizeof(real));
colData = reinterpret_cast<real*>(memory_->getBuf());
}
Im2ColMobileFunctor<real> im2col;
size_t inputOffset = imShape.getElements();
size_t outputOffset =
(outputChannels / groups_) * outputHeight * outputWidth;
size_t filterOffset = filter.getElements() / groups_;
int nStride = colWidth;
int kStride = colHeight;
for (size_t i = 0; i < batchSize; i++) {
for (size_t g = 0; g < groups_; g++) {
if (needIm2col) {
real beta_ = beta;
for (size_t colHeightStart = 0; colHeightStart < colHeight;
colHeightStart += stepColHeight) {
for (size_t colWidthStart = 0; colWidthStart < colWidth;
colWidthStart += stepColWidth) {
int N = std::min(colWidth - colWidthStart, stepColWidth);
int K = std::min(colHeight - colHeightStart, stepColHeight);
// im2col
im2col(inputData + g * inputOffset,
imShape,
colData,
colShape,
strideH(),
strideW(),
paddingH(),
paddingW(),
dilationH(),
dilationW(),
colHeightStart,
K,
colWidthStart,
N);
// gemm
int M = outputChannels / groups_;
BlasGemm<Device, real>::compute(
false,
false,
M,
N,
K,
1.0f,
filterData + g * filterOffset + colHeightStart,
kStride,
colData,
N,
beta_,
outputData + g * outputOffset + colWidthStart,
nStride);
}
beta_ = 1.0;
}
} else {
int M = outputChannels / groups_;
int N = outputHeight * outputWidth;
int K = inputChannels / groups_ * filterHeight * filterWidth;
BlasGemm<Device, real>::compute(false,
false,
M,
N,
K,
1.0f,
filterData + g * filterOffset,
K,
inputData + g * inputOffset,
N,
beta,
outputData + g * outputOffset,
N);
}
}
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
memory_.reset();
}
};
#endif
/*
* \brief Backward input calculation of convolution.
*/
......@@ -348,7 +499,11 @@ public:
}
};
#ifdef PADDLE_MOBILE_INFERENCE
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvMobileFunction);
#else
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
#endif
REGISTER_TYPED_FUNC(GemmConvGradInput, CPU, GemmConvGradInputFunction);
REGISTER_TYPED_FUNC(GemmConvGradFilter, CPU, GemmConvGradFilterFunction);
#ifdef PADDLE_WITH_CUDA
......
......@@ -98,4 +98,54 @@ public:
int dilationWidth = 1);
};
template <class T>
class Im2ColMobileFunctor {
public:
void operator()(const T* imData,
const TensorShape& imShape,
T* colData,
const TensorShape& colShape,
int strideHeight,
int strideWidth,
int paddingHeight,
int paddingWidth,
int dilationHeight,
int dilationWidth,
int colHeightStart,
int colHeightSize,
int colWidthStart,
int colWidthSize) {
int inputHeight = imShape[1];
int inputWidth = imShape[2];
int filterHeight = colShape[1];
int filterWidth = colShape[2];
int outputWidth = colShape[4];
for (int colh = 0; colh < colHeightSize; colh++) {
int wOffset = (colHeightStart + colh) % filterWidth;
int hOffset = ((colHeightStart + colh) / filterWidth) % filterHeight;
int c_im = (colHeightStart + colh) / filterWidth / filterHeight;
for (int colw = 0; colw < colWidthSize; colw++) {
int h = (colWidthStart + colw) / outputWidth;
int w = (colWidthStart + colw) % outputWidth;
int imRowIdx = h * strideHeight + hOffset * dilationHeight;
int imColIdx = w * strideWidth + wOffset * dilationWidth;
if ((imRowIdx - paddingHeight) < 0 ||
(imRowIdx - paddingHeight) >= inputHeight ||
(imColIdx - paddingWidth) < 0 ||
(imColIdx - paddingWidth) >= inputWidth) {
colData[colh * colWidthSize + colw] = static_cast<T>(0);
} else {
imRowIdx += c_im * inputHeight - paddingHeight;
imColIdx -= paddingWidth;
colData[colh * colWidthSize + colw] =
imData[imRowIdx * inputWidth + imColIdx];
}
}
}
}
};
} // namespace paddle
......@@ -138,4 +138,86 @@ TEST(Im2ColFunctor, GPU) { TestIm2ColFunctor<DEVICE_TYPE_GPU, float>(); }
#endif
template <class T>
void TestIm2ColMobileFunctor() {
for (size_t channels : {32}) {
for (size_t inputHeight : {33, 100}) {
for (size_t inputWidth : {32, 96}) {
for (size_t filterHeight : {5}) {
for (size_t filterWidth : {7}) {
for (size_t stride : {2}) {
for (size_t padding : {1}) {
for (size_t dilation : {1, 3}) {
size_t filterSizeH = (filterHeight - 1) * dilation + 1;
size_t filterSizeW = (filterWidth - 1) * dilation + 1;
if (inputHeight + 2 * padding < filterSizeH ||
inputWidth + 2 * padding < filterSizeW)
break;
if (padding >= filterSizeH || padding >= filterSizeW) break;
size_t outputHeight =
(inputHeight - filterSizeH + 2 * padding) / stride + 1;
size_t outputWidth =
(inputWidth - filterSizeW + 2 * padding) / stride + 1;
TensorShape imShape =
TensorShape({channels, inputHeight, inputWidth});
TensorShape colShape1 = TensorShape({channels,
filterHeight,
filterWidth,
outputHeight,
outputWidth});
size_t height = channels * filterHeight * filterWidth;
size_t width = outputHeight * outputWidth;
VectorPtr input1 =
Vector::create(imShape.getElements(), false);
VectorPtr input2 =
Vector::create(imShape.getElements(), false);
MatrixPtr output1 =
Matrix::create(height, width, false, false);
MatrixPtr output2 =
Matrix::create(height, width, false, false);
input1->uniform(0.001, 1);
input2->copyFrom(*input1);
Im2ColFunctor<kCFO, DEVICE_TYPE_CPU, T> im2Col1;
Im2ColMobileFunctor<T> im2Col2;
im2Col1(input1->getData(),
imShape,
output1->getData(),
colShape1,
stride,
stride,
padding,
padding,
dilation,
dilation);
im2Col2(input2->getData(),
imShape,
output2->getData(),
colShape1,
stride,
stride,
padding,
padding,
dilation,
dilation,
0,
height,
0,
width);
autotest::TensorCheckEqual(*output1, *output2);
}
}
}
}
}
}
}
}
}
TEST(Im2ColFunctor, Mobile) { TestIm2ColMobileFunctor<float>(); }
} // namespace paddle
......@@ -53,7 +53,6 @@ function(op_library TARGET)
if (${op_library_DEPS_len} GREATER 0)
set(DEPS_OPS ${TARGET} ${DEPS_OPS} PARENT_SCOPE)
endif()
if (WITH_GPU)
nv_library(${TARGET} SRCS ${cc_srcs} ${cu_cc_srcs} ${cu_srcs} DEPS ${op_library_DEPS}
${op_common_deps})
......@@ -187,6 +186,36 @@ endfunction()
add_subdirectory(math)
add_subdirectory(nccl)
set(DEPS_OPS
cond_op
cross_entropy_op
recurrent_op
softmax_with_cross_entropy_op
softmax_op
sequence_softmax_op
sum_op
pool_op
maxout_op
unpool_op
pool_with_index_op
conv_op
conv_transpose_op
nccl_op
sequence_conv_op
sequence_pool_op
lod_rank_table_op
lod_tensor_to_array_op
array_to_lod_tensor_op
max_sequence_len_op
lstm_op
gru_op
adagrad_op
sgd_op
save_op
load_op
send_op
recv_op
detection_output_op)
if(WITH_GPU)
op_library(nccl_op DEPS nccl_common)
else()
......@@ -210,6 +239,7 @@ op_library(cond_op DEPS framework_proto tensor net_op)
op_library(cross_entropy_op DEPS cross_entropy)
op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library(softmax_op DEPS softmax)
op_library(detection_output_op DEPS softmax)
op_library(sequence_softmax_op DEPS softmax)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
......@@ -229,6 +259,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(conv_transpose_op DEPS vol2col)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(recurrent_op DEPS executor)
op_library(cos_sim_op DEPS cos_sim_functor)
# FIXME(typhoonzero): save/load depends lodtensor serialization functions
op_library(save_op DEPS lod_tensor)
op_library(load_op DEPS lod_tensor)
......
......@@ -105,48 +105,18 @@ struct SparseAdagradFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
// 1. g_m.rows = set(g.rows)
auto grad_rows = grad.rows();
std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto grad_width = grad.value().dims()[1];
std::unique_ptr<framework::SelectedRows> grad_merge{
new framework::SelectedRows()};
grad_merge->set_rows(merge_rows);
grad_merge->set_height(grad.height());
grad_merge->mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
auto* grad_data = grad.value().data<T>();
for (size_t i = 0; i < grad_rows.size(); i++) {
size_t grad_merge_i = FindPos(merge_rows, grad_rows[i]);
for (int64_t j = 0; j < grad_width; j++) {
grad_merge_data[grad_merge_i * grad_width + j] +=
grad_data[i * grad_width + j];
}
}
math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
auto grad_merge = merge_func(context, grad);
auto& merge_rows = grad_merge.rows();
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
// 2. m += g_m * g_m
std::unique_ptr<framework::SelectedRows> grad_square{
new framework::SelectedRows()};
grad_square->set_rows(grad_merge->rows());
grad_square->set_height(grad_merge->height());
grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.eigen_device()) = gm * gm;
math::scatter::Mul<platform::CPUDeviceContext, T> sqare_func;
auto grad_square = sqare_func(context, grad_merge, grad_merge);
math::SelectedRowsAddToTensor<platform::CPUDeviceContext, T> functor;
functor(context, *grad_square, moment);
functor(context, grad_square, moment);
// 3. update parameter
auto* lr = learning_rate.data<T>();
......
......@@ -78,62 +78,30 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
const framework::Tensor& learning_rate, T epsilon,
framework::Tensor* moment, framework::Tensor* param) {
// 1. g_m.rows = set(g.rows)
auto grad_rows = grad.rows();
std::set<int64_t> row_set(grad_rows.begin(), grad_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto grad_width = grad.value().dims()[1];
std::unique_ptr<framework::SelectedRows> grad_merge{
new framework::SelectedRows()};
grad_merge->set_rows(merge_rows);
grad_merge->set_height(grad.height());
grad_merge->mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), grad_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, grad_merge->mutable_value(), 0.0);
auto* grad_merge_data = grad_merge->mutable_value()->data<T>();
auto* grad_data = grad.value().data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid1(1, grad_rows.size());
MergeGradKernel<
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_data, grad.rows().data(),
grad_merge_data, grad_merge->rows().data(),
grad_merge->rows().size(), grad_width);
math::scatter::MergeAdd<platform::CUDADeviceContext, T> merge_func;
auto grad_merge = merge_func(context, grad);
auto* grad_merge_data = grad_merge.mutable_value()->template data<T>();
auto& merge_rows = grad_merge.rows();
// 2. m += g_m * g_m
std::unique_ptr<framework::SelectedRows> grad_square{
new framework::SelectedRows()};
grad_square->set_rows(grad_merge->rows());
grad_square->set_height(grad_merge->height());
grad_square->mutable_value()->mutable_data<T>(grad_merge->value().dims(),
context.GetPlace());
auto gs =
framework::EigenVector<T>::Flatten(*(grad_square->mutable_value()));
auto gm = framework::EigenVector<T>::Flatten(grad_merge->value());
gs.device(*context.eigen_device()) = gm * gm;
math::scatter::Mul<platform::CUDADeviceContext, T> sqare_func;
auto grad_square = sqare_func(context, grad_merge, grad_merge);
math::SelectedRowsAddToTensor<platform::CUDADeviceContext, T> functor;
functor(context, *grad_square, moment);
functor(context, grad_square, moment);
// 3. update parameter
auto* lr = learning_rate.data<T>();
auto* param_data = param->data<T>();
auto* moment_data = moment->data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid2(1, merge_rows.size());
SparseAdagradFunctorKernel<
T, 256><<<grid2, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(grad_merge_data, grad_merge->rows().data(),
.stream()>>>(grad_merge_data, grad_merge.rows().data(),
lr, param_data, moment_data, grad_width,
epsilon);
}
......
......@@ -16,11 +16,14 @@ limitations under the License. */
#include <math.h> // for sqrt in CPU and CUDA
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/for_range.h"
namespace paddle {
namespace operators {
namespace scatter = paddle::operators::math::scatter;
template <typename T>
struct AdamFunctor {
T beta1_;
......@@ -79,6 +82,69 @@ struct AdamFunctor {
}
};
template <typename T>
struct SparseAdamFunctor {
T beta1_;
T beta2_;
T epsilon_;
const T* beta1_pow_;
const T* beta2_pow_;
const T* moment1_;
T* moment1_out_;
const T* moment2_;
T* moment2_out_;
const T* lr_;
const T* grad_;
const T* param_;
T* param_out_;
const int64_t* rows_;
int64_t row_numel_;
SparseAdamFunctor(T beta1, T beta2, T epsilon, const T* beta1_pow,
const T* beta2_pow, const T* mom1, T* mom1_out,
const T* mom2, T* mom2_out, const T* lr, const T* grad,
const T* param, T* param_out, const int64_t* rows,
int64_t row_numel)
: beta1_(beta1),
beta2_(beta2),
epsilon_(epsilon),
beta1_pow_(beta1_pow),
beta2_pow_(beta2_pow),
moment1_(mom1),
moment1_out_(mom1_out),
moment2_(mom2),
moment2_out_(mom2_out),
lr_(lr),
grad_(grad),
param_(param),
param_out_(param_out),
rows_(rows),
row_numel_(row_numel) {}
inline HOSTDEVICE void operator()(size_t i) const {
T beta1_pow = *beta1_pow_;
T beta2_pow = *beta2_pow_;
for (int64_t j = 0; j < row_numel_; ++j) {
T g = grad_[i * row_numel_ + j];
T mom1 = moment1_[rows_[i] * row_numel_ + j];
T mom2 = moment2_[rows_[i] * row_numel_ + j];
T lr = *lr_;
T p = param_[rows_[i] * row_numel_ + j];
lr *= sqrt(1 - beta2_pow) / (1 - beta1_pow);
mom1 = beta1_ * mom1 + (1 - beta1_) * g;
mom2 = beta2_ * mom2 + (1 - beta2_) * g * g;
p -= lr * (mom1 / (sqrt(mom2) + epsilon_));
moment1_out_[rows_[i] * row_numel_ + j] = mom1;
moment2_out_[rows_[i] * row_numel_ + j] = mom2;
param_out_[rows_[i] * row_numel_ + j] = p;
} // for col id
}
};
template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
......@@ -90,7 +156,8 @@ class AdamOpKernel : public framework::OpKernel<T> {
T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto& param = Ref(ctx.Input<LoDTensor>("Param"), "Must set Param");
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
// auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
auto* grad_var = ctx.InputVar("Grad");
auto& mom1 = Ref(ctx.Input<LoDTensor>("Moment1"), "Must set Moment1");
auto& mom2 = Ref(ctx.Input<LoDTensor>("Moment2"), "Must set Moment2");
auto& lr =
......@@ -108,18 +175,48 @@ class AdamOpKernel : public framework::OpKernel<T> {
auto& mom2_out =
Ref(ctx.Output<LoDTensor>("Moment2Out"), "Must set Moment1Out");
AdamFunctor<T> functor(beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(),
mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(),
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()), param.numel());
for_range(functor);
if (grad_var->IsType<framework::LoDTensor>()) {
auto& grad = Ref(ctx.Input<LoDTensor>("Grad"), "Must set Grad");
AdamFunctor<T> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad.template data<T>(),
param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()));
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
param.numel());
for_range(functor);
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto& grad =
Ref(ctx.Input<framework::SelectedRows>("Grad"), "Must set Grad");
// merge duplicated rows if any.
scatter::MergeAdd<DeviceContext, T> merge_func;
auto grad_merge =
merge_func(ctx.template device_context<DeviceContext>(), grad);
auto& grad_tensor = grad_merge.value();
const T* grad_data = grad_tensor.template data<T>();
auto* rows = grad_merge.rows().data();
auto row_numel = grad_tensor.numel() / grad_merge.rows().size();
SparseAdamFunctor<T> functor(
beta1, beta2, epsilon, beta1_pow.template data<T>(),
beta2_pow.template data<T>(), mom1.template data<T>(),
mom1_out.template mutable_data<T>(ctx.GetPlace()),
mom2.template data<T>(),
mom2_out.template mutable_data<T>(ctx.GetPlace()),
lr.template data<T>(), grad_data, param.template data<T>(),
param_out.template mutable_data<T>(ctx.GetPlace()), rows, row_numel);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(ctx.device_context()),
grad_merge.rows().size());
for_range(functor);
} else {
PADDLE_THROW("Variable type not supported by adam_op");
}
}
};
......
......@@ -31,8 +31,6 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
int input_channels = in_dims[1];
int output_channels = filter_dims[0];
PADDLE_ENFORCE(in_dims.size() == 4 || in_dims.size() == 5,
"Conv intput should be 4-D or 5-D tensor.");
......@@ -45,9 +43,13 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE_EQ(
paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension should be the same.");
int input_channels = in_dims[1];
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
int output_channels = filter_dims[0];
PADDLE_ENFORCE_EQ(
output_channels % groups, 0,
"The number of output channels should be divided by groups.");
......
......@@ -13,19 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/for_range.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class CosSimKernel : public framework::OpKernel<T> {
......@@ -41,28 +37,25 @@ class CosSimKernel : public framework::OpKernel<T> {
out_x_norm->mutable_data<T>(context.GetPlace());
out_y_norm->mutable_data<T>(context.GetPlace());
// convert Tensor to Eigen Tensor
int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
auto z = EigenVector<T>::Flatten(*out_z);
auto x_norm = EigenVector<T>::Flatten(*out_x_norm);
auto y_norm = EigenVector<T>::Flatten(*out_y_norm);
// compute
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto row_along = Eigen::array<int, 1>({{1}});
x_norm.device(place) = x.square().sum(row_along).sqrt();
y_norm.device(place) = y.square().sum(row_along).sqrt();
int cols = framework::product(in_x->dims()) / rows_x;
if (rows_x == rows_y) {
auto xy = (x * y).sum(Eigen::array<int, 1>({{1}}));
z.device(place) = xy / x_norm / y_norm;
math::CosSimFunctor<T, true> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
} else {
Eigen::DSizes<int, 2> bcast(rows_x, 1);
auto xy = (x * y.broadcast(bcast)).sum(row_along);
z.device(place) = xy / x_norm / y_norm.broadcast(bcast);
math::CosSimFunctor<T, false> functor(
in_x->data<T>(), in_y->data<T>(), out_x_norm->data<T>(),
out_y_norm->data<T>(), out_z->data<T>(), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()), rows_x);
for_range(functor);
}
}
};
......@@ -81,62 +74,54 @@ class CosSimGradKernel : public framework::OpKernel<T> {
auto* out_grad_y = context.Output<Tensor>(framework::GradVarName("Y"));
auto* in_grad_z = context.Input<Tensor>(framework::GradVarName("Out"));
// convert Tensor to Eigen Tensor
auto x = EigenMatrix<T>::Reshape(*in_x, 1);
auto y = EigenMatrix<T>::Reshape(*in_y, 1);
auto z = EigenMatrix<T>::Reshape(*in_z, 1);
auto x_norm = EigenMatrix<T>::Reshape(*in_x_norm, 1);
auto y_norm = EigenMatrix<T>::Reshape(*in_y_norm, 1);
auto dz = EigenMatrix<T>::Reshape(*in_grad_z, 1);
// compute gradident
int rows_x = in_x->dims()[0];
int rows_y = in_y->dims()[0];
int cols = framework::product(in_x->dims()) / rows_x;
Eigen::DSizes<int, 2> bcast_cols(1, cols);
auto z_bcast = z.broadcast(bcast_cols);
auto dz_bcast = dz.broadcast(bcast_cols);
auto x_snorm_bcast = x_norm.square().eval().broadcast(bcast_cols);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
if (rows_x == rows_y) {
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_cols);
auto norm_prod_bcast = (x_norm * y_norm).eval().broadcast(bcast_cols);
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
auto grad = y / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
math::CosSimGradFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenMatrix<T>::Reshape(*out_grad_y, 1);
auto grad = x / norm_prod_bcast - z_bcast * y / y_snorm_bcast;
dy.device(place) = dz_bcast * grad;
math::CosSimGradFunctor<T> functor(
in_y_norm->data<T>(), in_x_norm->data<T>(), in_y->data<T>(),
in_x->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_y->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
} else {
Eigen::DSizes<int, 2> bcast_rows(rows_x, 1);
Eigen::DSizes<int, 2> bcast_rows_cols(rows_x, cols);
auto y_bcast = y.broadcast(bcast_rows);
auto y_snorm_bcast = y_norm.square().eval().broadcast(bcast_rows_cols);
auto norm_prod_bcast = (x_norm * y_norm.eval().broadcast(bcast_rows))
.eval()
.broadcast(bcast_cols);
// compute dx
if (out_grad_x) {
out_grad_x->mutable_data<T>(context.GetPlace());
auto dx = EigenMatrix<T>::Reshape(*out_grad_x, 1);
auto grad = y_bcast / norm_prod_bcast - z_bcast * x / x_snorm_bcast;
dx.device(place) = dz_bcast * grad;
math::CosSimDxFunctor<T> functor(
in_x_norm->data<T>(), in_y_norm->data<T>(), in_x->data<T>(),
in_y->data<T>(), in_z->data<T>(), in_grad_z->data<T>(),
out_grad_x->mutable_data<T>(context.GetPlace()), cols);
platform::ForRange<DeviceContext> for_range(
static_cast<const DeviceContext&>(context.device_context()),
rows_x);
for_range(functor);
}
// compute dy
if (out_grad_y) {
out_grad_y->mutable_data<T>(context.GetPlace());
auto dy = EigenVector<T>::Flatten(*out_grad_y);
auto grad = x / norm_prod_bcast - z_bcast * y_bcast / y_snorm_bcast;
dy.device(place) = (dz_bcast * grad).sum(Eigen::array<int, 1>({{0}}));
math::SetConstant<DeviceContext, T> set_zero;
auto& dev_ctx = context.template device_context<DeviceContext>();
set_zero(dev_ctx, out_grad_y, static_cast<T>(0));
math::CosSimDyFunctor<DeviceContext, T> functor;
functor(dev_ctx, in_x_norm->data<T>(), in_y_norm->data<T>(),
in_x->data<T>(), in_y->data<T>(), in_z->data<T>(),
in_grad_z->data<T>(), static_cast<size_t>(rows_x),
static_cast<size_t>(cols), out_grad_y->data<T>());
}
}
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detection_output_op.h"
namespace paddle {
namespace operators {
class DetectionOutputOpMaker : public framework::OpProtoAndCheckerMaker {
public:
DetectionOutputOpMaker(OpProto* proto, OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Loc",
"(Tensor) The input tensor of detection_output operator."
"The input predict locations"
"The format of input tensor is kNCHW. Where K is priorbox point "
"numbers,"
"N is How many boxes are there on each point, "
"C is 4, H and W both are 1.");
AddInput("Conf",
"(Tensor) The input tensor of detection_output operator."
"The input priorbox confidence."
"The format of input tensor is kNCHW. Where K is priorbox point "
"numbers,"
"N is How many boxes are there on each point, "
"C is the number of classes, H and W both are 1.");
AddInput("PriorBox",
"(Tensor) The input tensor of detection_output operator."
"The format of input tensor is the position and variance "
"of the boxes");
AddOutput("Out",
"(Tensor) The output tensor of detection_output operator.");
AddAttr<int>("background_label_id", "(int), The background class index.");
AddAttr<int>("num_classes", "(int), The number of the classification.");
AddAttr<float>("nms_threshold",
"(float), The Non-maximum suppression threshold.");
AddAttr<float>("confidence_threshold",
"(float), The classification confidence threshold.");
AddAttr<int>("top_k", "(int), The bbox number kept of the layer’s output.");
AddAttr<int>("nms_top_k",
"(int), The bbox number kept of the NMS’s output.");
AddComment(R"DOC(
detection output for SSD(single shot multibox detector)
Apply the NMS to the output of network and compute the predict
bounding box location. The output’s shape of this layer could
be zero if there is no valid bounding box.
)DOC");
}
};
class DetectionOutputOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Loc"),
"Input(X) of DetectionOutputOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Conf"),
"Input(X) of DetectionOutputOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("PriorBox"),
"Input(X) of DetectionOutputOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of DetectionOutputOp should not be null.");
std::vector<int64_t> output_shape({1, 7});
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(detection_output, ops::DetectionOutputOp,
ops::DetectionOutputOpMaker);
REGISTER_OP_CPU_KERNEL(
detection_output,
ops::DetectionOutputKernel<paddle::platform::CPUDeviceContext, float>,
ops::DetectionOutputKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/detection_output_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
detection_output,
ops::DetectionOutputKernel<paddle::platform::CUDADeviceContext, float>,
ops::DetectionOutputKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/detection_util.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/softmax.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
inline void transpose_fun(const framework::ExecutionContext& context,
const framework::Tensor& src,
framework::Tensor* dst) {
int input_nums = src.dims()[0];
int offset = 0;
for (int j = 0; j < input_nums; ++j) {
framework::Tensor in_p_tensor = src.Slice(j, j + 1);
std::vector<int64_t> shape_vec(
{in_p_tensor.dims()[0], in_p_tensor.dims()[1], in_p_tensor.dims()[3],
in_p_tensor.dims()[4], in_p_tensor.dims()[2]});
framework::DDim shape(framework::make_ddim(shape_vec));
framework::Tensor in_p_tensor_transpose;
in_p_tensor_transpose.mutable_data<T>(shape, context.GetPlace());
std::vector<int> shape_axis({0, 1, 3, 4, 2});
math::Transpose<DeviceContext, T, 5> trans5;
trans5(context.template device_context<DeviceContext>(), in_p_tensor,
&in_p_tensor_transpose, shape_axis);
auto dst_stride = framework::stride(dst->dims());
auto src_stride = framework::stride(in_p_tensor_transpose.dims());
StridedMemcpy<T>(context.device_context(), in_p_tensor_transpose.data<T>(),
src_stride, in_p_tensor_transpose.dims(), dst_stride,
dst->data<T>() + offset);
offset += in_p_tensor_transpose.dims()[4] * src_stride[4];
}
}
template <typename DeviceContext, typename T>
class DetectionOutputKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_loc = context.Input<framework::Tensor>("Loc");
const framework::Tensor* in_conf = context.Input<framework::Tensor>("Conf");
const framework::Tensor* in_priorbox =
context.Input<framework::Tensor>("PriorBox");
auto* out = context.Output<framework::Tensor>("Out");
int num_classes = context.template Attr<int>("num_classes");
int top_k = context.template Attr<int>("top_k");
int nms_top_k = context.template Attr<int>("nms_top_k");
int background_label_id = context.template Attr<int>("background_label_id");
float nms_threshold = context.template Attr<float>("nms_threshold");
float confidence_threshold =
context.template Attr<float>("confidence_threshold");
size_t batch_size = in_conf->dims()[1];
int conf_sum_size = in_conf->numel();
// for softmax
std::vector<int64_t> conf_shape_softmax_vec(
{conf_sum_size / num_classes, num_classes});
framework::DDim conf_shape_softmax(
framework::make_ddim(conf_shape_softmax_vec));
// for knchw => nhwc
std::vector<int64_t> loc_shape_vec({1, in_loc->dims()[1], in_loc->dims()[3],
in_loc->dims()[4],
in_loc->dims()[2] * in_loc->dims()[0]});
std::vector<int64_t> conf_shape_vec(
{1, in_conf->dims()[1], in_conf->dims()[3], in_conf->dims()[4],
in_conf->dims()[2] * in_conf->dims()[0]});
framework::DDim loc_shape(framework::make_ddim(loc_shape_vec));
framework::DDim conf_shape(framework::make_ddim(conf_shape_vec));
framework::Tensor loc_tensor;
framework::Tensor conf_tensor;
loc_tensor.mutable_data<T>(loc_shape, context.GetPlace());
conf_tensor.mutable_data<T>(conf_shape, context.GetPlace());
// for cpu
framework::Tensor loc_cpu;
framework::Tensor conf_cpu;
framework::Tensor priorbox_cpu;
const T* priorbox_data = in_priorbox->data<T>();
transpose_fun<DeviceContext, T>(context, *in_loc, &loc_tensor);
transpose_fun<DeviceContext, T>(context, *in_conf, &conf_tensor);
conf_tensor.Resize(conf_shape_softmax);
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), &conf_tensor,
&conf_tensor);
T* loc_data = loc_tensor.data<T>();
T* conf_data = conf_tensor.data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
loc_cpu.mutable_data<T>(loc_tensor.dims(), platform::CPUPlace());
framework::CopyFrom(loc_tensor, platform::CPUPlace(),
context.device_context(), &loc_cpu);
loc_data = loc_cpu.data<T>();
conf_cpu.mutable_data<T>(conf_tensor.dims(), platform::CPUPlace());
framework::CopyFrom(conf_tensor, platform::CPUPlace(),
context.device_context(), &conf_cpu);
conf_data = conf_cpu.data<T>();
priorbox_cpu.mutable_data<T>(in_priorbox->dims(), platform::CPUPlace());
framework::CopyFrom(*in_priorbox, platform::CPUPlace(),
context.device_context(), &priorbox_cpu);
priorbox_data = priorbox_cpu.data<T>();
}
// get decode bboxes
size_t num_priors = in_priorbox->numel() / 8;
std::vector<std::vector<operators::math::BBox<T>>> all_decoded_bboxes;
for (size_t n = 0; n < batch_size; ++n) {
std::vector<operators::math::BBox<T>> decoded_bboxes;
for (size_t i = 0; i < num_priors; ++i) {
size_t prior_offset = i * 8;
size_t loc_pred_offset = n * num_priors * 4 + i * 4;
std::vector<math::BBox<T>> prior_bbox_vec;
math::GetBBoxFromPriorData<T>(priorbox_data + prior_offset, 1,
prior_bbox_vec);
std::vector<std::vector<T>> prior_bbox_var;
math::GetBBoxVarFromPriorData<T>(priorbox_data + prior_offset, 1,
prior_bbox_var);
std::vector<T> loc_pred_data;
for (size_t j = 0; j < 4; ++j)
loc_pred_data.push_back(*(loc_data + loc_pred_offset + j));
math::BBox<T> bbox = math::DecodeBBoxWithVar<T>(
prior_bbox_vec[0], prior_bbox_var[0], loc_pred_data);
decoded_bboxes.push_back(bbox);
}
all_decoded_bboxes.push_back(decoded_bboxes);
}
std::vector<std::map<size_t, std::vector<size_t>>> all_indices;
int num_kept = math::GetDetectionIndices<T>(
conf_data, num_priors, num_classes, background_label_id, batch_size,
confidence_threshold, nms_top_k, nms_threshold, top_k,
all_decoded_bboxes, &all_indices);
if (num_kept <= 0) {
std::vector<int64_t> out_shape_vec({0, 0});
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out->Resize(out_shape);
return;
}
std::vector<int64_t> out_shape_vec({num_kept, 7});
framework::DDim out_shape(framework::make_ddim(out_shape_vec));
out->mutable_data<T>(out_shape, context.GetPlace());
framework::Tensor out_cpu;
T* out_data = out->data<T>();
if (platform::is_gpu_place(context.GetPlace())) {
out_cpu.mutable_data<T>(out->dims(), platform::CPUPlace());
out_data = out_cpu.data<T>();
}
math::GetDetectionOutput<T>(conf_data, num_kept, num_priors, num_classes,
batch_size, all_indices, all_decoded_bboxes,
out_data);
if (platform::is_gpu_place(context.GetPlace())) {
framework::CopyFrom(out_cpu, platform::CUDAPlace(),
context.device_context(), out);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
......@@ -70,7 +71,7 @@ class GRUKernel : public framework::OpKernel<T> {
}
int frame_size = hidden_dims[1];
math::hl_gru_value<T> gru_value;
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
......@@ -89,6 +90,10 @@ class GRUKernel : public framework::OpKernel<T> {
}
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType(
context.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
......@@ -101,9 +106,8 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.gate_value = gate_t.data<T>();
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
gru_value.prev_out_value = gru_value.output_value;
}
......@@ -170,12 +174,12 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
math::hl_gru_value<T> gru_value;
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::hl_gru_grad<T> gru_grad;
math::GRUMetaGrad<T> gru_grad;
if (weight_grad) {
gru_grad.gate_weight_grad =
weight_grad->mutable_data<T>(context.GetPlace());
......@@ -189,6 +193,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
auto batch_starts = batch_hidden_grad.lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto active_node = math::detail::GetActivationType(
context.Attr<std::string>("activation"));
auto active_gate = math::detail::GetActivationType(
context.Attr<std::string>("gate_activation"));
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
......@@ -219,9 +227,8 @@ class GRUGradKernel : public framework::OpKernel<T> {
}
math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
active_gate);
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......
......@@ -38,7 +38,7 @@ class LoadOp : public framework::OperatorBase {
out_var_name);
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
framework::DeserializeFromStream(fin, tensor);
DeserializeFromStream(fin, tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
......
......@@ -9,13 +9,14 @@ if(WITH_GPU)
nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS device_context)
nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context)
nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context tensor)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context math_function)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context tensor)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context)
nv_library(unpooling SRCS unpooling.cc unpooling.cu DEPS device_context)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function)
nv_library(cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -23,13 +24,14 @@ else()
cc_library(cross_entropy SRCS cross_entropy.cc DEPS device_context)
cc_library(pooling SRCS pooling.cc DEPS device_context)
cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function)
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(vol2col SRCS vol2col.cc DEPS device_context tensor)
cc_library(context_project SRCS context_project.cc DEPS device_context math_function)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context tensor)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(maxouting SRCS maxouting.cc DEPS device_context)
cc_library(unpooling SRCS unpooling.cc DEPS device_context)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
cc_library(cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct CosSimDyFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
for (size_t row_id = 0; row_id < rows; ++row_id) {
auto xy_norm_prod = x_norm[row_id] * y_norm[0];
auto dz_data = dz[row_id];
auto z_data = z[row_id];
auto* x_data = x + cols * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto y_norm_square = y_norm[0] * y_norm[0];
auto reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
dy[i] += dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
}
}
}
};
template struct CosSimDyFunctor<platform::CPUDeviceContext, float>;
template struct CosSimDyFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/cos_sim_functor.h"
#include "paddle/platform/cuda_helper.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
__global__ void CosSimDyKernel(const T* x_norm, const T* y_norm, const T* x,
const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) {
int grid_size = blockDim.x * gridDim.x;
T y_norm_data = y_norm[0];
for (int row_id = blockIdx.x * blockDim.x + threadIdx.x; row_id < rows;
row_id += grid_size) {
T xy_norm_prod = x_norm[row_id] * y_norm_data;
T dz_data = dz[row_id];
T z_data = z[row_id];
const T* x_data = x + cols * row_id;
T reciprocal_xy_norm_prod = 1 / xy_norm_prod;
T y_norm_square = y_norm_data * y_norm_data;
T reciprocal_y_norm_square = 1 / y_norm_square;
for (size_t i = 0; i < cols; ++i) {
T dy_data = dz_data * (x_data[i] * reciprocal_xy_norm_prod -
z_data * y[i] * reciprocal_y_norm_square);
platform::CudaAtomicAdd(dy + i, dy_data);
}
}
}
template <typename T>
struct CosSimDyFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& ctx, const T* x_norm,
const T* y_norm, const T* x, const T* y, const T* z,
const T* dz, const size_t rows, const size_t cols,
T* dy) const {
const int block_size = 512;
dim3 threads(block_size, 1);
dim3 grid(1, (rows + block_size - 1) / block_size);
CosSimDyKernel<T><<<grid, threads, 0, ctx.stream()>>>(
x_norm, y_norm, x, y, z, dz, rows, cols, dy);
}
};
template struct CosSimDyFunctor<platform::CUDADeviceContext, float>;
template struct CosSimDyFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <stdlib.h>
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T, bool same_row>
struct CosSimFunctor {
CosSimFunctor(const T* x, const T* y, T* x_norm, T* y_norm, T* z, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto* x = x_ + cols_ * row_id;
T xx = 0, xy = 0, yy = 0;
if (same_row) {
auto* y = y_ + cols_ * row_id;
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
y_norm_[row_id] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
} else { // This can be wrote in a better way.
T tep_x, tep_y;
for (size_t i = 0; i < cols_; ++i) {
tep_x = x[i];
tep_y = y_[i];
xx += tep_x * tep_x;
yy += tep_y * tep_y;
xy += tep_x * tep_y;
}
xx = sqrt(xx);
yy = sqrt(yy);
if (row_id == 0) y_norm_[0] = yy;
x_norm_[row_id] = xx;
z_[row_id] = xy / (xx * yy);
}
}
T* x_norm_;
T* y_norm_;
const T* x_;
const T* y_;
T* z_;
const size_t cols_;
};
template <typename T>
struct CosSimGradFunctor {
CosSimGradFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto xy_norm_prod = x_norm_[row_id] * y_norm_[row_id];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto* x = x_ + cols_ * row_id;
auto* y = y_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename T>
struct CosSimDxFunctor {
CosSimDxFunctor(const T* x_norm, const T* y_norm, const T* x, const T* y,
const T* z, const T* dz, T* dx, int cols)
: x_norm_(x_norm),
y_norm_(y_norm),
x_(x),
y_(y),
z_(z),
dz_(dz),
dx_(dx),
cols_(static_cast<size_t>(cols)) {}
inline HOSTDEVICE void operator()(size_t row_id) const {
auto xy_norm_prod = x_norm_[row_id] * y_norm_[0];
auto dz = dz_[row_id];
auto z = z_[row_id];
auto* x = x_ + cols_ * row_id;
auto reciprocal_xy_norm_prod = 1 / xy_norm_prod;
auto x_norm_square = x_norm_[row_id] * x_norm_[row_id];
auto* dx = dx_ + cols_ * row_id;
auto reciprocal_x_norm_square = 1 / x_norm_square;
for (size_t i = 0; i < cols_; ++i) {
dx[i] = dz * (y_[i] * reciprocal_xy_norm_prod -
z * x[i] * reciprocal_x_norm_square);
}
}
const T* x_norm_;
const T* y_norm_;
const T* x_;
const T* y_;
const T* z_;
const T* dz_;
T* dx_;
const size_t cols_;
};
template <typename DeviceContext, typename T>
struct CosSimDyFunctor {
void operator()(const DeviceContext& ctx, const T* x_norm, const T* y_norm,
const T* x, const T* y, const T* z, const T* dz,
const size_t rows, const size_t cols, T* dy) const;
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -28,7 +28,7 @@ template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
T r_value_update_gate;
T r_value_reset_gate;
T r_value_reset_output;
......@@ -56,7 +56,7 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
T r_value_update_gate;
T r_value_frame_state;
T r_prev_out = 0;
......@@ -83,7 +83,7 @@ template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
#ifdef __AVX__
__m256 r_value_update_gate;
__m256 r_value_reset_gate;
......@@ -113,7 +113,7 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
#ifdef __AVX__
__m256 r_value_update_gate;
__m256 r_value_frame_state;
......@@ -140,9 +140,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
hl_gru_value<T> value, int frame_size,
int batch_size,
activation_mode_t active_gate) {
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_gate) {
for (int b = 0; b < batch_size; b++) {
if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_reset_output(
......@@ -164,9 +163,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
hl_gru_value<T> value, int frame_size,
int batch_size,
activation_mode_t active_node) {
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
......@@ -191,7 +189,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
T r_update_gate_value;
T r_update_gate_grad;
T r_frame_state_value;
......@@ -232,7 +230,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
T r_update_gate_value;
T r_update_gate_grad;
T r_reset_gate_value;
......@@ -277,7 +275,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
......@@ -320,7 +318,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
......@@ -364,9 +362,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node) {
ActivationType active_node) {
for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad(
......@@ -393,9 +391,9 @@ inline void backward_state_grad(OpStateGrad op_state_grad,
template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
for (int b = 0; b < batch_size; b++) {
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_reset_grad(
......
......@@ -19,8 +19,6 @@ limitations under the License. */
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace paddle {
namespace operators {
namespace math {
......@@ -35,7 +33,7 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
int batch_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
......@@ -74,7 +72,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
int batch_size,
activation_mode_t active_node) {
ActivationType active_node) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -111,7 +109,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size, int batch_size,
activation_mode_t active_node) {
ActivationType active_node) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -159,7 +157,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size, int batch_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......
......@@ -30,7 +30,7 @@ class gru_resetOutput {
public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
T &prev_out, T &value_reset_output,
activation_mode_t act_gate) {
ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = prev_out * value_reset_gate;
......@@ -43,7 +43,7 @@ class gru_resetOutput {
HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &value_reset_gate, __m256 &prev_out,
__m256 &value_reset_output,
activation_mode_t act_gate) {
ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
......@@ -57,7 +57,7 @@ class gru_finalOutput {
public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
T &prev_out, T &value_output,
activation_mode_t act_input) {
ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input);
value_output = prev_out - (value_update_gate * prev_out) +
(value_update_gate * value_frame_state);
......@@ -69,8 +69,7 @@ class gru_finalOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &value_frame_state, __m256 &prev_out,
__m256 &value_output,
activation_mode_t act_input) {
__m256 &value_output, ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input);
value_output = _mm256_add_ps(
_mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
......@@ -89,7 +88,7 @@ class gru_stateGrad {
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &value_frame_state, T &grad_frame_state,
T &value_prev_out, T &grad_prev_out,
T &grad_output, activation_mode_t act_input) {
T &grad_output, ActivationType act_input) {
grad_update_gate = (grad_output * value_frame_state);
grad_update_gate -= (grad_output * value_prev_out);
grad_prev_out -= (grad_output * value_update_gate);
......@@ -107,7 +106,7 @@ class gru_stateGrad {
__m256 &value_frame_state,
__m256 &grad_frame_state, __m256 &value_prev_out,
__m256 &grad_prev_out, __m256 &grad_output,
activation_mode_t act_input) {
ActivationType act_input) {
grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
grad_update_gate = _mm256_sub_ps(
grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
......@@ -128,7 +127,7 @@ class gru_resetGrad {
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &value_reset_gate, T &grad_reset_gate,
T &value_prev_out, T &grad_prev_out,
T &grad_reset_output, activation_mode_t act_gate) {
T &grad_reset_output, ActivationType act_gate) {
grad_reset_gate = (grad_reset_output * value_prev_out);
grad_prev_out += (grad_reset_output * value_reset_gate);
grad_update_gate =
......@@ -144,7 +143,7 @@ class gru_resetGrad {
__m256 &grad_update_gate, __m256 &value_reset_gate,
__m256 &grad_reset_gate, __m256 &value_prev_out,
__m256 &grad_prev_out, __m256 &grad_reset_output,
activation_mode_t act_gate) {
ActivationType act_gate) {
grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
grad_prev_out = _mm256_add_ps(
grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <map>
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct BBox {
BBox(T x_min, T y_min, T x_max, T y_max)
: x_min(x_min),
y_min(y_min),
x_max(x_max),
y_max(y_max),
is_difficult(false) {}
BBox() {}
T get_width() const { return x_max - x_min; }
T get_height() const { return y_max - y_min; }
T get_center_x() const { return (x_min + x_max) / 2; }
T get_center_y() const { return (y_min + y_max) / 2; }
T get_area() const { return get_width() * get_height(); }
// coordinate of bounding box
T x_min;
T y_min;
T x_max;
T y_max;
// whether difficult object (e.g. object with heavy occlusion is difficult)
bool is_difficult;
};
// KNCHW ==> NHWC
// template <typename T>
template <typename T>
void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
std::vector<BBox<T>>& bbox_vec);
template <typename T>
void GetBBoxVarFromPriorData(const T* prior_data, const size_t num,
std::vector<std::vector<T>>& var_vec);
template <typename T>
BBox<T> DecodeBBoxWithVar(BBox<T>& prior_bbox,
const std::vector<T>& prior_bbox_var,
const std::vector<T>& loc_pred_data);
template <typename T1, typename T2>
bool SortScorePairDescend(const std::pair<T1, T2>& pair1,
const std::pair<T1, T2>& pair2);
template <typename T>
bool SortScorePairDescend(const std::pair<T, BBox<T>>& pair1,
const std::pair<T, BBox<T>>& pair2);
template <typename T>
T jaccard_overlap(const BBox<T>& bbox1, const BBox<T>& bbox2);
template <typename T>
void ApplyNmsFast(const std::vector<BBox<T>>& bboxes, const T* conf_score_data,
size_t class_idx, size_t top_k, T conf_threshold,
T nms_threshold, size_t num_priors, size_t num_classes,
std::vector<size_t>* indices);
template <typename T>
int GetDetectionIndices(
const T* conf_data, const size_t num_priors, const size_t num_classes,
const size_t background_label_id, const size_t batch_size,
const T conf_threshold, const size_t nms_top_k, const T nms_threshold,
const size_t top_k,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes,
std::vector<std::map<size_t, std::vector<size_t>>>* all_detection_indices);
template <typename T>
BBox<T> ClipBBox(const BBox<T>& bbox);
template <typename T>
void GetDetectionOutput(
const T* conf_data, const size_t num_kept, const size_t num_priors,
const size_t num_classes, const size_t batch_size,
const std::vector<std::map<size_t, std::vector<size_t>>>& all_indices,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes, T* out_data);
template <typename T>
void GetBBoxFromPriorData(const T* prior_data, const size_t num_bboxes,
std::vector<BBox<T>>& bbox_vec) {
size_t out_offset = bbox_vec.size();
bbox_vec.resize(bbox_vec.size() + num_bboxes);
for (size_t i = 0; i < num_bboxes; ++i) {
BBox<T> bbox;
bbox.x_min = *(prior_data + i * 8);
bbox.y_min = *(prior_data + i * 8 + 1);
bbox.x_max = *(prior_data + i * 8 + 2);
bbox.y_max = *(prior_data + i * 8 + 3);
bbox_vec[out_offset + i] = bbox;
}
}
template <typename T>
void GetBBoxVarFromPriorData(const T* prior_data, const size_t num,
std::vector<std::vector<T>>& var_vec) {
size_t out_offset = var_vec.size();
var_vec.resize(var_vec.size() + num);
for (size_t i = 0; i < num; ++i) {
std::vector<T> var;
var.push_back(*(prior_data + i * 8 + 4));
var.push_back(*(prior_data + i * 8 + 5));
var.push_back(*(prior_data + i * 8 + 6));
var.push_back(*(prior_data + i * 8 + 7));
var_vec[out_offset + i] = var;
}
}
template <typename T>
BBox<T> DecodeBBoxWithVar(BBox<T>& prior_bbox,
const std::vector<T>& prior_bbox_var,
const std::vector<T>& loc_pred_data) {
T prior_bbox_width = prior_bbox.get_width();
T prior_bbox_height = prior_bbox.get_height();
T prior_bbox_center_x = prior_bbox.get_center_x();
T prior_bbox_center_y = prior_bbox.get_center_y();
T decoded_bbox_center_x =
prior_bbox_var[0] * loc_pred_data[0] * prior_bbox_width +
prior_bbox_center_x;
T decoded_bbox_center_y =
prior_bbox_var[1] * loc_pred_data[1] * prior_bbox_height +
prior_bbox_center_y;
T decoded_bbox_width =
std::exp(prior_bbox_var[2] * loc_pred_data[2]) * prior_bbox_width;
T decoded_bbox_height =
std::exp(prior_bbox_var[3] * loc_pred_data[3]) * prior_bbox_height;
BBox<T> decoded_bbox;
decoded_bbox.x_min = decoded_bbox_center_x - decoded_bbox_width / 2;
decoded_bbox.y_min = decoded_bbox_center_y - decoded_bbox_height / 2;
decoded_bbox.x_max = decoded_bbox_center_x + decoded_bbox_width / 2;
decoded_bbox.y_max = decoded_bbox_center_y + decoded_bbox_height / 2;
return decoded_bbox;
}
template <typename T1, typename T2>
bool SortScorePairDescend(const std::pair<T1, T2>& pair1,
const std::pair<T1, T2>& pair2) {
return pair1.first > pair2.first;
}
template <typename T>
T jaccard_overlap(const BBox<T>& bbox1, const BBox<T>& bbox2) {
if (bbox2.x_min > bbox1.x_max || bbox2.x_max < bbox1.x_min ||
bbox2.y_min > bbox1.y_max || bbox2.y_max < bbox1.y_min) {
return 0.0;
} else {
T inter_x_min = std::max(bbox1.x_min, bbox2.x_min);
T inter_y_min = std::max(bbox1.y_min, bbox2.y_min);
T interX_max = std::min(bbox1.x_max, bbox2.x_max);
T interY_max = std::min(bbox1.y_max, bbox2.y_max);
T inter_width = interX_max - inter_x_min;
T inter_height = interY_max - inter_y_min;
T inter_area = inter_width * inter_height;
T bbox_area1 = bbox1.get_area();
T bbox_area2 = bbox2.get_area();
return inter_area / (bbox_area1 + bbox_area2 - inter_area);
}
}
template <typename T>
void ApplyNmsFast(const std::vector<BBox<T>>& bboxes, const T* conf_score_data,
size_t class_idx, size_t top_k, T conf_threshold,
T nms_threshold, size_t num_priors, size_t num_classes,
std::vector<size_t>* indices) {
std::vector<std::pair<T, size_t>> scores;
for (size_t i = 0; i < num_priors; ++i) {
size_t conf_offset = i * num_classes + class_idx;
if (conf_score_data[conf_offset] > conf_threshold)
scores.push_back(std::make_pair(conf_score_data[conf_offset], i));
}
std::stable_sort(scores.begin(), scores.end(),
SortScorePairDescend<T, size_t>);
if (top_k > 0 && top_k < scores.size()) scores.resize(top_k);
while (scores.size() > 0) {
const size_t idx = scores.front().second;
bool keep = true;
for (size_t i = 0; i < indices->size(); ++i) {
if (keep) {
const size_t saved_idx = (*indices)[i];
T overlap = jaccard_overlap<T>(bboxes[idx], bboxes[saved_idx]);
keep = overlap <= nms_threshold;
} else {
break;
}
}
if (keep) indices->push_back(idx);
scores.erase(scores.begin());
}
}
template <typename T>
int GetDetectionIndices(
const T* conf_data, const size_t num_priors, const size_t num_classes,
const size_t background_label_id, const size_t batch_size,
const T conf_threshold, const size_t nms_top_k, const T nms_threshold,
const size_t top_k,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes,
std::vector<std::map<size_t, std::vector<size_t>>>* all_detection_indices) {
int total_keep_num = 0;
for (size_t n = 0; n < batch_size; ++n) {
const std::vector<BBox<T>>& decoded_bboxes = all_decoded_bboxes[n];
size_t num_detected = 0;
std::map<size_t, std::vector<size_t>> indices;
size_t conf_offset = n * num_priors * num_classes;
for (size_t c = 0; c < num_classes; ++c) {
if (c == background_label_id) continue;
ApplyNmsFast<T>(decoded_bboxes, conf_data + conf_offset, c, nms_top_k,
conf_threshold, nms_threshold, num_priors, num_classes,
&(indices[c]));
num_detected += indices[c].size();
}
if (top_k > 0 && num_detected > top_k) {
// std::vector<pair<T,T>> score_index_pairs;
std::vector<std::pair<T, std::pair<size_t, size_t>>> score_index_pairs;
for (size_t c = 0; c < num_classes; ++c) {
const std::vector<size_t>& label_indices = indices[c];
for (size_t i = 0; i < label_indices.size(); ++i) {
size_t idx = label_indices[i];
score_index_pairs.push_back(
std::make_pair((conf_data + conf_offset)[idx * num_classes + c],
std::make_pair(c, idx)));
}
}
std::sort(score_index_pairs.begin(), score_index_pairs.end(),
SortScorePairDescend<T, std::pair<size_t, size_t>>);
score_index_pairs.resize(top_k);
std::map<size_t, std::vector<size_t>> new_indices;
for (size_t i = 0; i < score_index_pairs.size(); ++i) {
size_t label = score_index_pairs[i].second.first;
size_t idx = score_index_pairs[i].second.second;
new_indices[label].push_back(idx);
}
all_detection_indices->push_back(new_indices);
total_keep_num += top_k;
} else {
all_detection_indices->push_back(indices);
total_keep_num += num_detected;
}
}
return total_keep_num;
}
template <typename T>
BBox<T> ClipBBox(const BBox<T>& bbox) {
T one = static_cast<T>(1.0);
T zero = static_cast<T>(0.0);
BBox<T> clipped_bbox;
clipped_bbox.x_min = std::max(std::min(bbox.x_min, one), zero);
clipped_bbox.y_min = std::max(std::min(bbox.y_min, one), zero);
clipped_bbox.x_max = std::max(std::min(bbox.x_max, one), zero);
clipped_bbox.y_max = std::max(std::min(bbox.y_max, one), zero);
return clipped_bbox;
}
template <typename T>
void GetDetectionOutput(
const T* conf_data, const size_t num_kept, const size_t num_priors,
const size_t num_classes, const size_t batch_size,
const std::vector<std::map<size_t, std::vector<size_t>>>& all_indices,
const std::vector<std::vector<BBox<T>>>& all_decoded_bboxes, T* out_data) {
size_t count = 0;
for (size_t n = 0; n < batch_size; ++n) {
for (std::map<size_t, std::vector<size_t>>::const_iterator it =
all_indices[n].begin();
it != all_indices[n].end(); ++it) {
size_t label = it->first;
const std::vector<size_t>& indices = it->second;
const std::vector<BBox<T>>& decoded_bboxes = all_decoded_bboxes[n];
for (size_t i = 0; i < indices.size(); ++i) {
size_t idx = indices[i];
size_t conf_offset = n * num_priors * num_classes + idx * num_classes;
out_data[count * 7] = n;
out_data[count * 7 + 1] = label;
out_data[count * 7 + 2] = (conf_data + conf_offset)[label];
BBox<T> clipped_bbox = ClipBBox<T>(decoded_bboxes[idx]);
out_data[count * 7 + 3] = clipped_bbox.x_min;
out_data[count * 7 + 4] = clipped_bbox.y_min;
out_data[count * 7 + 5] = clipped_bbox.x_max;
out_data[count * 7 + 6] = clipped_bbox.y_max;
++count;
}
}
}
}
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -21,9 +21,9 @@ namespace math {
template <typename T>
struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
#ifndef __NVCC__
if (value.prev_out_value) {
math::gemm<platform::CPUDeviceContext, T>(
......@@ -51,10 +51,10 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
template <typename T>
struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node);
......
......@@ -21,9 +21,9 @@ namespace math {
template <typename T>
struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -88,10 +88,10 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
template <typename T>
struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......
......@@ -11,7 +11,7 @@ limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
......@@ -19,9 +19,8 @@ namespace paddle {
namespace operators {
namespace math {
// TODO(guosheng): refine code style in gru_compute
template <typename T>
struct hl_gru_value {
struct GRUMetaValue {
T *gate_weight;
T *state_weight;
T *gate_value;
......@@ -31,7 +30,7 @@ struct hl_gru_value {
};
template <typename T>
struct hl_gru_grad {
struct GRUMetaGrad {
T *gate_weight_grad;
T *state_weight_grad;
T *gate_grad;
......@@ -42,18 +41,18 @@ struct hl_gru_grad {
template <typename DeviceContext, typename T>
struct GRUUnitFunctor {
static void compute(const DeviceContext &context, hl_gru_value<T> value,
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate);
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
template <typename DeviceContext, typename T>
struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, hl_gru_value<T> value,
hl_gru_grad<T> grad, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate);
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
} // namespace math
......
......@@ -22,14 +22,6 @@ namespace paddle {
namespace operators {
namespace math {
typedef enum {
HL_ACTIVATION_SIGMOID = 0,
HL_ACTIVATION_RELU = 1,
HL_ACTIVATION_TANH = 2,
HL_ACTIVATION_LINEAR = 3,
HL_ACTIVATION_END
} activation_mode_t;
template <class T>
struct LstmMetaValue {
T *gate_value;
......@@ -54,20 +46,6 @@ struct LstmMetaGrad {
T *check_og_grad;
};
inline activation_mode_t ActiveType(const std::string &type) {
if (type == "sigmoid") {
return HL_ACTIVATION_SIGMOID;
} else if (type == "relu") {
return HL_ACTIVATION_RELU;
} else if (type == "tanh") {
return HL_ACTIVATION_TANH;
} else if (type == "linear" || type == "identity" || type == "") {
return HL_ACTIVATION_LINEAR;
} else {
PADDLE_THROW("Do not support activation type.");
}
}
template <typename DeviceContext, typename T>
class LstmUnitFunctor {
public:
......
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/selected_rows_functor.h"
#include <set>
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
......@@ -179,6 +181,118 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
// This is a separated namespace for manipulate SelectedRows typed
// data. Like merge duplicated rows, adding two SelectedRows etc.
//
// Another group of functors is called "scatter updates", which means
// use SelectedRows to update a dense tensor with different Ops, like
// add or mul.
namespace scatter {
size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
return std::find(rows.begin(), rows.end(), value) - rows.begin();
}
template <typename T>
struct MergeAdd<platform::CPUDeviceContext, T> {
framework::SelectedRows operator()(const platform::CPUDeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
for (size_t i = 0; i < input_rows.size(); i++) {
size_t out_i = FindPos(merge_rows, input_rows[i]);
for (int64_t j = 0; j < input_width; j++) {
out_data[out_i * input_width + j] += input_data[i * input_width + j];
}
}
return out;
}
};
template struct MergeAdd<platform::CPUDeviceContext, float>;
template struct MergeAdd<platform::CPUDeviceContext, double>;
template struct MergeAdd<platform::CPUDeviceContext, int>;
template struct MergeAdd<platform::CPUDeviceContext, int64_t>;
template <typename T>
struct UpdateToTensor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const ScatterOps& op, const framework::SelectedRows& input1,
framework::Tensor* input2) {
auto in1_height = input1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = input1.value();
auto& in1_rows = input1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.data<T>();
auto* input2_data = input2->data<T>();
// FIXME(typhoonzero): use macro fix the below messy code.
switch (op) {
case ScatterOps::ASSIGN:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::ADD:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] +=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::SUB:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] -=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::SUBBY:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j] -
input2_data[in1_rows[i] * in1_row_numel + j];
break;
case ScatterOps::MUL:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] *=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::DIV:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] /=
in1_data[i * in1_row_numel + j];
break;
case ScatterOps::DIVBY:
INLINE_FOR2(in1_rows.size(), in1_row_numel)
input2_data[in1_rows[i] * in1_row_numel + j] =
in1_data[i * in1_row_numel + j] /
input2_data[in1_rows[i] * in1_row_numel + j];
break;
}
}
};
} // namespace scatter
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <set>
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/selected_rows_functor.h"
#include "paddle/platform/cuda_helper.h"
......@@ -222,6 +224,157 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
namespace scatter {
template <typename T, int block_size>
__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
T* out, const int64_t* out_rows,
size_t out_rows_size, int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
__shared__ size_t out_idx;
if (tid == 0) {
for (size_t i = 0; i < out_rows_size; i++) {
if (input_rows[ty] == out_rows[i]) {
out_idx = i;
}
}
}
__syncthreads();
input += ty * row_numel;
out += out_idx * row_numel;
for (int index = tid; index < row_numel; index += block_size) {
paddle::platform::CudaAtomicAdd(out + index, input[index]);
}
}
template <typename T>
struct MergeAdd<platform::CUDADeviceContext, T> {
framework::SelectedRows operator()(const platform::CUDADeviceContext& context,
const framework::SelectedRows& input) {
framework::SelectedRows out;
auto input_rows = input.rows();
std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
auto input_width = input.value().dims()[1];
out.set_rows(merge_rows);
out.set_height(input.height());
out.mutable_value()->mutable_data<T>(
framework::make_ddim(
{static_cast<int64_t>(merge_rows.size()), input_width}),
context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> constant_functor;
constant_functor(context, out.mutable_value(), 0.0);
auto* out_data = out.mutable_value()->data<T>();
auto* input_data = input.value().data<T>();
const int block_size = 256;
dim3 threads(block_size, 1);
dim3 grid1(1, input_rows.size());
MergeAddKernel<
T, 256><<<grid1, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(input_data, input.rows().data(), out_data,
out.rows().data(), out.rows().size(),
input_width);
return out;
}
};
template struct MergeAdd<platform::CUDADeviceContext, float>;
template struct MergeAdd<platform::CUDADeviceContext, double>;
template struct MergeAdd<platform::CUDADeviceContext, int>;
template struct MergeAdd<platform::CUDADeviceContext, int64_t>;
template <typename T, int block_size>
__global__ void UpdateToTensorKernel(const T* selected_rows,
const int64_t* rows, const ScatterOps& op,
T* tensor_out, int64_t row_numel) {
const int ty = blockIdx.y;
int tid = threadIdx.x;
selected_rows += ty * row_numel;
tensor_out += rows[ty] * row_numel;
// FIXME(typhoonzero): use macro fix the below messy code.
switch (op) {
case ScatterOps::ASSIGN:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index];
}
break;
case ScatterOps::ADD:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] += selected_rows[index];
}
break;
case ScatterOps::SUB:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] -= selected_rows[index];
}
break;
case ScatterOps::SUBBY:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index] - tensor_out[index];
}
break;
case ScatterOps::MUL:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] *= selected_rows[index];
}
break;
case ScatterOps::DIV:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] /= selected_rows[index];
}
break;
case ScatterOps::DIVBY:
for (int index = tid; index < row_numel; index += block_size) {
tensor_out[index] = selected_rows[index] / tensor_out[index];
}
break;
}
}
template <typename T>
struct UpdateToTensor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const ScatterOps& op, const framework::SelectedRows& input1,
framework::Tensor* input2) {
// NOTE: Use SelectedRowsAddToTensor for better performance
// no additional MergeAdd called.
MergeAdd<platform::CUDADeviceContext, T> merge_func;
auto merged_in1 = merge_func(context, input1);
auto in1_height = merged_in1.height();
auto in2_dims = input2->dims();
PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
auto& in1_value = merged_in1.value();
auto& in1_rows = merged_in1.rows();
int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);
auto* in1_data = in1_value.template data<T>();
auto* in2_data = input2->data<T>();
dim3 threads(platform::PADDLE_CUDA_NUM_THREADS, 1);
dim3 grid(1, in1_rows.size());
UpdateToTensorKernel<T, platform::PADDLE_CUDA_NUM_THREADS><<<
grid, threads, 0, context.stream()>>>(in1_data, in1_rows.data(), op,
in2_data, in1_row_numel);
}
};
} // namespace scatter
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -12,9 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/platform/device_context.h"
#define INLINE_FOR2(sizei, sizej) \
for (int64_t i = 0; i < sizei; i++) \
for (int64_t j = 0; j < sizej; j++)
namespace paddle {
namespace operators {
namespace math {
......@@ -52,6 +57,78 @@ struct SelectedRowsAddToTensor {
framework::Tensor* input2);
};
namespace scatter {
// functors for manuplating SelectedRows data
template <typename DeviceContext, typename T>
struct MergeAdd {
// unary functor, merge by adding duplicated rows in
// the input SelectedRows object.
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input);
};
template <typename DeviceContext, typename T>
struct Add {
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 + e_in2;
return out;
}
};
template <typename DeviceContext, typename T>
struct Mul {
// multiply two SelectedRows
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const framework::SelectedRows& input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
e_out.device(*context.eigen_device()) = e_in1 * e_in2;
return out;
}
// multiply scalar to SelectedRows
framework::SelectedRows operator()(const DeviceContext& context,
const framework::SelectedRows& input1,
const T input2) {
framework::SelectedRows out;
out.set_rows(input1.rows());
out.set_height(input1.height());
out.mutable_value()->mutable_data<T>(input1.value().dims(),
context.GetPlace());
auto e_out = framework::EigenVector<T>::Flatten(*(out.mutable_value()));
auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
e_out.device(*context.eigen_device()) = input2 * e_in1;
return out;
}
};
enum class ScatterOps { ASSIGN, ADD, SUB, SUBBY, MUL, DIV, DIVBY };
// out = seleted_rows_in / tensor
template <typename DeviceContext, typename T>
struct UpdateToTensor {
void operator()(const DeviceContext& context, const ScatterOps& op,
const framework::SelectedRows& input1,
framework::Tensor* input2);
};
} // namespace scatter
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -116,9 +116,9 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
auto height = dout_tensor.dims()[0];
auto slice = dx_tensor.Slice(0, static_cast<int>(height));
framework::CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
if (dx_tensor.dims()[0] < height) {
if (dx_tensor.dims()[0] > height) {
auto rest_tensor = dx_tensor.Slice(
static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
......
......@@ -37,11 +37,11 @@ class SumKernel : public framework::OpKernel<T> {
bool in_place = out_var == in_vars[0];
if (out_var->IsType<framework::LoDTensor>()) {
auto *out = context.Output<Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto *out = context.Output<LoDTensor>("Out");
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
}
auto result = EigenVector<T>::Flatten(*out);
if (!in_place) {
math::SetConstant<DeviceContext, T> constant_functor;
constant_functor(context.template device_context<DeviceContext>(), out,
......
......@@ -130,9 +130,9 @@ class ReadFromArrayOp : public ArrayOp {
auto &x_array = x->Get<framework::LoDTensorArray>();
auto *out = scope.FindVar(Output("Out"));
PADDLE_ENFORCE(out != nullptr, "Out must be set");
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
size_t offset = GetOffset(scope, place);
if (offset < x_array.size()) {
auto *out_tensor = out->GetMutable<framework::LoDTensor>();
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
......
......@@ -52,6 +52,14 @@ class CPUDeviceContext : public DeviceContext {
std::unique_ptr<Eigen::DefaultDevice> eigen_device_;
};
template <typename Place>
struct DefaultDeviceContextType;
template <>
struct DefaultDeviceContextType<platform::CPUPlace> {
using TYPE = CPUDeviceContext;
};
#ifdef PADDLE_WITH_CUDA
class EigenCudaStreamDevice;
......@@ -90,6 +98,11 @@ class CUDADeviceContext : public DeviceContext {
cublasHandle_t cublas_handle_;
};
template <>
struct DefaultDeviceContextType<platform::CUDAPlace> {
using TYPE = CUDADeviceContext;
};
class CUDNNDeviceContext : public CUDADeviceContext {
public:
explicit CUDNNDeviceContext(CUDAPlace place);
......@@ -125,6 +138,13 @@ class DeviceContextPool {
/*! \brief Return handle of single device context. */
const platform::DeviceContext* Get(const platform::Place& place);
template <typename Place>
const typename DefaultDeviceContextType<Place>::TYPE* GetByPlace(
const Place& place) {
return reinterpret_cast<
const typename DefaultDeviceContextType<Place>::TYPE*>(Get(place));
}
private:
static DeviceContextPool* pool;
constexpr static int LEFT_SHIFT = 8;
......
......@@ -62,7 +62,7 @@ struct ForRange<CUDADeviceContext> {
template <typename Function>
inline void operator()(Function func) const {
constexpr size_t num_threads = 1024;
constexpr int num_threads = 1024;
int block_size = limit_ <= num_threads ? limit_ : num_threads;
int grid_size = (limit_ + num_threads - 1) / num_threads;
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#pragma once
#include <iostream>
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
namespace paddle {
......@@ -64,5 +64,31 @@ bool places_are_same_class(const Place &, const Place &);
std::ostream &operator<<(std::ostream &, const Place &);
template <typename Visitor>
struct PlaceVisitorWrapper
: public boost::static_visitor<typename Visitor::result_type> {
const Visitor &visitor_;
explicit PlaceVisitorWrapper(const Visitor &visitor) : visitor_(visitor) {}
typename Visitor::result_type operator()(const CPUPlace &cpu) const {
return visitor_(cpu);
}
typename Visitor::result_type operator()(const CUDAPlace &cuda) const {
#ifdef PADDLE_WITH_CUDA
return visitor_(cuda);
#else
PADDLE_THROW("Paddle is not compiled with CUDA. Cannot visit cuda device");
return typename Visitor::result_type();
#endif
}
};
template <typename Visitor>
typename Visitor::result_type VisitPlace(const Place &place,
const Visitor &visitor) {
return boost::apply_visitor(PlaceVisitorWrapper<Visitor>(visitor), place);
}
} // namespace platform
} // namespace paddle
......@@ -3,6 +3,9 @@ if(WITH_PYTHON)
SRCS pybind.cc exception.cc protobuf.cc const_value.cc
DEPS pybind python backward proto_desc paddle_memory executor prune init
${GLOB_OP_LIB})
if(NOT APPLE AND NOT ANDROID)
target_link_libraries(paddle_pybind rt)
endif(NOT APPLE AND NOT ANDROID)
endif(WITH_PYTHON)
if(WITH_DOC)
......
......@@ -77,10 +77,10 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
} else if (paddle::platform::is_cpu_place(tensor.place())) {
dst_tensor = tensor;
}
return py::buffer_info(
dst_tensor.mutable_data<CUR_TYPE>(dst_tensor.place()),
sizeof(CUR_TYPE), py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(dst_tensor.dims()), dims_outside, strides);
return py::buffer_info(dst_tensor.data<CUR_TYPE>(), sizeof(CUR_TYPE),
py::format_descriptor<CUR_TYPE>::format(),
(size_t)framework::arity(dst_tensor.dims()),
dims_outside, strides);
} else {
constexpr bool less = I + 1 < std::tuple_size<std::tuple<ARGS...>>::value;
return CastToPyBufferImpl<less, I + 1, ARGS...>()(tensor);
......
......@@ -178,7 +178,7 @@ EOF
# run paddle version to install python packages first
RUN apt-get update &&\
${NCCL_DEPS}\
apt-get install -y wget python-pip dmidecode && pip install -U pip && \
apt-get install -y wget python-pip dmidecode python-tk && pip install -U pip && \
pip install /*.whl; apt-get install -f -y && \
apt-get clean -y && \
rm -f /*.whl && \
......
......@@ -71,9 +71,7 @@ function threads_config() {
# auto set OMP_NUM_THREADS and MKL_NUM_THREADS
# according to trainer_count and total processors
# only when MKL enabled
if [ "@WITH_MKL@" == "OFF" ]; then
return 0
fi
# auto set OPENBLAS_NUM_THREADS when do not use MKL
processors=`grep "processor" /proc/cpuinfo|sort -u|wc -l`
trainers=`grep -Eo 'trainer_count.[0-9]+' <<< "$@" |grep -Eo '[0-9]+'|xargs`
if [ -z $trainers ]; then
......@@ -83,12 +81,19 @@ function threads_config() {
if [ $threads -eq 0 ]; then
threads=1
fi
if [ -z "$OMP_NUM_THREADS" ]; then
export OMP_NUM_THREADS=$threads
fi
if [ -z "$MKL_NUM_THREADS" ]; then
export MKL_NUM_THREADS=$threads
if [ "@WITH_MKL@" == "ON" ]; then
if [ -z "$OMP_NUM_THREADS" ]; then
export OMP_NUM_THREADS=$threads
fi
if [ -z "$MKL_NUM_THREADS" ]; then
export MKL_NUM_THREADS=$threads
fi
else
if [ -z "$OPENBLAS_NUM_THREADS" ]; then
export OPENBLAS_NUM_THREADS=$threads
fi
fi
}
PADDLE_CONF_HOME="$HOME/.config/paddle"
......@@ -150,7 +155,7 @@ fi
case "$1" in
"train")
threads_config $@
# echo $OMP_NUM_THREADS $MKL_NUM_THREADS
# echo $OMP_NUM_THREADS $MKL_NUM_THREADS $OPENBLAS_NUM_THREADS
${DEBUGGER} $PADDLE_BIN_PATH/paddle_trainer ${@:2}
;;
"merge_model")
......
......@@ -44,7 +44,7 @@ __all__ = ['train', 'test', 'valid']
DATA_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/102flowers.tgz'
LABEL_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/imagelabels.mat'
SETID_URL = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/setid.mat'
DATA_MD5 = '52808999861908f626f3c1f4e79d11fa'
DATA_MD5 = '33bfc11892f1e405ca193ae9a9f2a118'
LABEL_MD5 = 'e0620be6f572b9609742df49c70aed4d'
SETID_MD5 = 'a5357ecc9cb78c4bef273ce3793fc85c'
# In official 'readme', tstid is the flag of test data
......
......@@ -36,7 +36,7 @@ def __read_gflags_from_env__():
"""
import sys
import core
read_env_flags = ['use_pinned_memory']
read_env_flags = ['use_pinned_memory', 'check_nan_inf']
if core.is_compile_gpu():
read_env_flags.append('fraction_of_gpu_memory_to_use')
core.init_gflags([sys.argv[0]] +
......
......@@ -5,14 +5,17 @@ import collections
__all__ = ['append_backward']
def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
end_idx=None):
def _rename_arg_(op_descs, old_name, new_name, begin_idx=None, end_idx=None):
"""
Traverse all ops in op_descs[begin_idx : end_idx],
if any op has inputs/outputs named "old_name", rename it as 'new_name'
"""
if begin_idx is None:
begin_idx = 0
if end_idx is None:
end_idx = len(op_desc_list)
end_idx = len(op_descs)
for i in range(begin_idx, end_idx):
op_desc = op_desc_list[i]
op_desc = op_descs[i]
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc.rename_input(old_name, new_name)
......@@ -20,6 +23,9 @@ def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
def _create_op_desc_(op_type, inputs, outputs, attrs):
"""
Create a C++ OpDesc object with specified inputs, outputs and attributes.
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for para, args in inputs.iteritems():
......@@ -34,9 +40,12 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
return op_desc
def _infer_var_data_type_(var_name, block):
grad_var = block.desc.find_var(var_name.encode("ascii"))
fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
def _infer_var_data_type_(grad_var_name, block):
"""
Infer the data type of given grad variable
"""
grad_var = block.desc.find_var(grad_var_name.encode("ascii"))
fwd_name = _strip_grad_suffix_(grad_var_name.encode("ascii"))
if block.desc.has_var_recursive(fwd_name):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var.set_dtype(fwd_var.dtype())
......@@ -45,6 +54,9 @@ def _infer_var_data_type_(var_name, block):
def _all_in_set_(cands, s):
"""
Test if all elements of 'cands' are in set 's'
"""
for c in cands:
if not c in s:
return False
......@@ -52,18 +64,29 @@ def _all_in_set_(cands, s):
def _strip_grad_suffix_(name):
"""
Strip the grad suffix from the given varibale name
e.g. x@GRAD ==> x
y@GRAD@RENAME@1 ==> y
"""
pos = name.find(core.grad_var_suffix())
return name[:pos] if pos != -1 else name
def _append_grad_suffix_(name):
"""
Append grad suffix to the given variable name
e.g. x ==> x@GRAD
"""
return name + core.grad_var_suffix()
def _addup_repetitive_outputs_(op_descs):
# In backward part, an variable my be the output of more than one ops.
# In this case, the variable should be the accumulation of all the outputs.
# We adopt adding `sum_op`s to implement the accumulate.
"""
In backward part, an variable may be the output of more than one ops.
In this case, the variable should be the accumulation of all the outputs.
`sum_op`s are added to implement the accumulate.
"""
pending_sum_ops = []
var_rename_count = collections.defaultdict(int)
renamed_vars = collections.defaultdict(list)
......@@ -109,6 +132,12 @@ def _addup_repetitive_outputs_(op_descs):
def _remove_no_grad_branch_(op_descs, no_grad_set):
"""
Remove unnecessary grad ops
A grad op can be removed in two cases:
1. all outputs of the grad op are in 'no_grad_set'
2. (TODO) all grad inputs of the grad op are in 'no_grad_set'
"""
# Remove ops whose outputs are all in no_grad_dict
op_descs = filter(
lambda op_desc: not _all_in_set_(op_desc.output_arg_names(), no_grad_set),
......@@ -133,6 +162,21 @@ def _append_backward_ops_(target,
no_grad_dict,
grad_to_var,
callback=None):
"""
Create all grad ops, and insert them into given block
Args:
target(Variable): the target variable of forward pass
block(Block): the block where forward ops are
target_block(Block): the block which is going to hold new generated grad ops
no_grad_dict(dict):
key(int) block index
val(set) a set of varibale names. These varibales have no gradient
grad_to_var(dict)(output argument):
key(str): grad variable name
val(str): corresponding forward variable name
"""
# grad_op_descs holds created grad_op, and will be appended to target_block
grad_op_descs = []
program = block.program
for op in reversed(block.ops):
......@@ -145,6 +189,7 @@ def _append_backward_ops_(target,
no_grad_dict, grad_to_var, callback)
grad_sub_block_list.append(grad_sub_block.desc)
# Getting op's corresponding grad_op
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, no_grad_dict[block.idx], grad_sub_block_list)
grad_op_descs.extend(grad_op_desc)
......@@ -170,6 +215,20 @@ def _append_backward_ops_(target,
def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
"""
Create new variables required by backward pass.
Args:
block(Block): the block where new variables will be created
start_op_idx(int): Only variables required by ops in block.ops[start_op_idx : ] will be created
grad_to_var(dict):
key(str): grad variable name
val(str): corresponding forward variable name
In most cases, this dict is generated by _append_backward_ops_()
grad_info_map(dict)(output argument):
key(str): forward variable name
val(tuple): a tuple of (str, int), str is the corresponding grad name, int is the block index
"""
for op_idx in range(start_op_idx, block.desc.op_size()):
op_desc = block.desc.op(op_idx)
if op_desc.has_attr("sub_block"):
......@@ -197,18 +256,18 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
def append_backward(loss, parameter_list=None, no_grad_set=None):
"""
Create and add gradient Operators in BlockDesc to compute
gradients of `loss` for parameters in parameter_list
:param loss: an variable generated by cost function.
:type loss: Variable
:param no_grad_dict: variable that should not create gradient
:type no_grad_dict: set
:param parameter_list: parameters that need to compute gradient and
update to optimize the lost.
:type: list
:return: list of (parameters, gradients) pair.
:rtype: list[Variable]
Append backward part to main_program
Args:
loss(Variable): The variable generated by cost function.
parameter_list(list): Parameters that need to be updated by optimizer.
If None, it means all parameters need to be updated.
no_grad_set(set): Variables that have no gradients in Block 0.
If None, the set will be generated inside the function and
contains all variables with `step_gradient=True` from all blocks.
Return:
(list[Variable]): list of (parameters, gradients) pair.
"""
assert isinstance(loss, framework.Variable)
......
......@@ -3,7 +3,7 @@ import core
import numpy
import six.moves as six
from framework import Variable
from framework import Variable, default_main_program
__all__ = ['DataFeeder']
......@@ -53,12 +53,16 @@ class DataToLoDTensorConverter(object):
class DataFeeder(object):
def __init__(self, feed_list, place):
def __init__(self, feed_list, place, program=None):
self.feed_dtypes = []
self.feed_names = []
self.feed_shapes = []
self.feed_lod_level = []
if program is None:
program = default_main_program()
for each_var in feed_list:
if isinstance(each_var, basestring):
each_var = program.block(0).var(each_var)
if not isinstance(each_var, Variable):
raise TypeError("Feed list should contain a list of variable")
self.feed_dtypes.append(each_var.dtype)
......
import numpy as np
import contextlib
from framework import Program, default_main_program
from . import core
from framework import Program, default_main_program, Parameter, Variable
__all__ = ['Executor', 'g_scope']
__all__ = ['Executor', 'global_scope', 'scope_guard', 'switch_scope']
g_scope = core.Scope()
def global_scope():
return g_scope
def switch_scope(scope):
global g_scope
ex = g_scope
g_scope = scope
return ex
@contextlib.contextmanager
def scope_guard(scope):
ex = switch_scope(scope)
yield
switch_scope(ex)
def as_numpy(tensor):
if isinstance(tensor, list):
return [as_numpy(t) for t in tensor]
......@@ -117,7 +136,7 @@ class Executor(object):
raise TypeError()
if scope is None:
scope = g_scope
scope = global_scope()
program = program.clone()
global_block = program.global_block()
......
......@@ -188,7 +188,7 @@ def save_inference_model(dirname,
raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
feeded_var_names = [feeded_var_names]
target_vars = [target_vars]
else:
if not (bool(target_vars) and all(
isinstance(var, Variable) for var in target_vars)):
......
......@@ -170,7 +170,7 @@ def main():
exe.run(fluid.default_startup_program())
embedding_param = fluid.g_scope.find_var(embedding_name).get_tensor()
embedding_param = fluid.global_scope().find_var(embedding_name).get_tensor()
embedding_param.set(
load_parameter(conll05.get_embedding(), word_dict_len, word_dim), place)
......
import paddle.v2.fluid as fluid
__all__ = ['many_times', 'prog_scope']
def many_times(times):
def __impl__(fn):
def __fn__(*args, **kwargs):
for _ in range(times):
fn(*args, **kwargs)
return __fn__
return __impl__
def prog_scope():
def __impl__(fn):
def __fn__(*args, **kwargs):
prog = fluid.Program()
startup_prog = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(prog, startup_prog):
fn(*args, **kwargs)
return __fn__
return __impl__
import unittest
import numpy as np
from op_test import OpTest
from paddle.v2.fluid import core
from paddle.v2.fluid.op import Operator
class TestAdamOp1(OpTest):
......@@ -176,5 +178,124 @@ def adam_step(inputs, attributes):
return param_out, moment1_out, moment2_out
def adam_step_sparse(inputs, attributes, height, rows, row_numel, np_grad):
'''
Simulate one step of the adam optimizer
:param inputs: dict of inputs
:param attributes: dict of attributes
:return tuple: tuple of output param, moment1, moment2,
beta1 power accumulator and beta2 power accumulator
'''
param = inputs['Param']
# grad = inputs['Grad']
moment1 = inputs['Moment1']
moment2 = inputs['Moment2']
lr = inputs['LearningRate']
beta1_pow = inputs['Beta1Pow']
beta2_pow = inputs['Beta2Pow']
beta1 = attributes['beta1']
beta2 = attributes['beta2']
epsilon = attributes['epsilon']
moment1_out = np.zeros(shape=[height, row_numel])
moment2_out = np.zeros(shape=[height, row_numel])
param_out = np.zeros(shape=[height, row_numel])
for idx, row_id in enumerate(rows):
moment1_out[row_id] = beta1 * moment1[row_id] + (1 - beta1
) * np_grad[idx]
moment2_out[row_id] = beta2 * moment2[row_id] + (
1 - beta2) * np.square(np_grad[idx])
lr_t = lr * np.sqrt(1 - beta2_pow) / (1 - beta1_pow)
param_out[row_id] = param[row_id] - lr_t * (moment1_out[row_id] / (
np.sqrt(moment2_out[row_id]) + epsilon))
return param_out, moment1_out, moment2_out
class TestSparseAdamOp(unittest.TestCase):
def setup(self, scope, place):
beta1 = 0.78
beta2 = 0.836
epsilon = 1e-4
height = 10
rows = [0, 4, 7]
self.rows = rows
row_numel = 12
self.row_numel = row_numel
self.dense_inputs = {
"Param": np.full((height, row_numel), 5.0).astype("float32"),
"Moment1": np.full((height, row_numel), 5.0).astype("float32"),
"Moment2": np.full((height, row_numel), 5.0).astype("float32"),
'Beta1Pow': np.array([beta1**10]).astype("float32"),
'Beta2Pow': np.array([beta2**10]).astype("float32"),
"LearningRate": np.full((1), 2.0).astype("float32")
}
self.attrs = {'epsilon': epsilon, 'beta1': beta1, 'beta2': beta2}
grad_selected_rows = scope.var('Grad').get_selected_rows()
grad_selected_rows.set_height(height)
grad_selected_rows.set_rows(rows)
np_array = np.ones((len(rows), row_numel)).astype("float32")
np_array[0, 0] = 2.0
np_array[2, 8] = 4.0
grad_tensor = grad_selected_rows.get_tensor()
grad_tensor.set(np_array, place)
self.sparse_inputs = ["Grad"]
param_out, mom1, mom2 = adam_step_sparse(
self.dense_inputs, self.attrs, height, rows, row_numel, np_array)
self.outputs = {
"ParamOut": param_out,
"Moment1Out": mom1,
"Moment2Out": mom2
}
def check_with_place(self, place):
scope = core.Scope()
self.setup(scope, place)
op_args = dict()
for key, np_array in self.dense_inputs.iteritems():
var = scope.var(key).get_tensor()
var.set(np_array, place)
op_args[key] = key
for s in self.sparse_inputs:
op_args[s] = s
for s in self.outputs:
var = scope.var(s).get_tensor()
var.set(self.outputs[s], place)
op_args[s] = s
for k in self.attrs:
op_args[k] = self.attrs[k]
# create and run sgd operator
adam_op = Operator("adam", **op_args)
adam_op.run(scope, place)
for key, np_array in self.outputs.iteritems():
out_var = scope.var(key).get_tensor()
actual = np.array(out_var)
actual = actual.reshape([actual.size])
np_array = np_array.reshape([np_array.size])
for idx, row_id in enumerate(self.rows):
j = 0
while j < self.row_numel:
pos = row_id * self.row_numel + j
self.assertLess((actual[pos] - np_array[pos]) / actual[pos],
0.00001)
j += 1
def test_sparse_sgd(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
if __name__ == "__main__":
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestUnpoolOp(OpTest):
def setUp(self):
self.op_type = "detection_output"
self.init_test_case()
#loc.shape ((1, 4, 4, 1, 1))
#conf.shape ((1, 4, 2, 1, 1))
loc = np.array([[[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
[[[0.1]], [[0.1]], [[0.1]], [[0.1]]],
[[[0.1]], [[0.1]], [[0.1]], [[0.1]]]]])
conf = np.array([[[[[0.1]], [[0.9]]], [[[0.2]], [[0.8]]],
[[[0.3]], [[0.7]]], [[[0.4]], [[0.6]]]]])
priorbox = np.array([
0.1, 0.1, 0.5, 0.5, 0.1, 0.1, 0.2, 0.2, 0.2, 0.2, 0.6, 0.6, 0.1,
0.1, 0.2, 0.2, 0.3, 0.3, 0.7, 0.7, 0.1, 0.1, 0.2, 0.2, 0.4, 0.4,
0.8, 0.8, 0.1, 0.1, 0.2, 0.2
])
output = np.array([
0, 1, 0.68997443, 0.099959746, 0.099959746, 0.50804031, 0.50804031
])
self.inputs = {
'Loc': loc.astype('float32'),
'Conf': conf.astype('float32'),
'PriorBox': priorbox.astype('float32')
}
self.attrs = {
'num_classes': self.num_classes,
'top_k': self.top_k,
'nms_top_k': self.nms_top_k,
'background_label_id': self.background_label_id,
'nms_threshold': self.nms_threshold,
'confidence_threshold': self.confidence_threshold,
}
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def init_test_case(self):
self.num_classes = 2
self.top_k = 10
self.nms_top_k = 20
self.background_label_id = 0
self.nms_threshold = 0.01
self.confidence_threshold = 0.01
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册