提交 b3cb2e8e 编写于 作者: Y Yu Yang

Merge branch 'develop' of github.com:baidu/Paddle into feature/remove_int_pair

# How to write a new operator
- [Background](#Background)
- [Implementing C++ Types](#Implementing_C++_Types)
- [Defining ProtoMaker](#Defining_ProtoMaker)
- [Defining Operator](#Defining_Operator)
- [Registering Operator](#Registering_Operator)
- [Compilation](#Compilation)
- [Python Binding](#Python_Binding)
- [Unit Tests](#Unit_Tests)
## Background
Here are the base types needed. For details, please refer to the design docs.
- `framework::OperatorBase`: Operator (Op)base class.
- `framework::OpKernel`: Base class for Op computation.
- `framework::OperatorWithKernel`: Inherited from OperatorBase, describing an operator with computation.
- `class OpProtoAndCheckerMaker`: Describes an Operator's input, output, attributes and description, mainly used to interface with Python API.
An operator can be differentiated by whether in has kernel methods. An operator with kernel inherits from `OperatorWithKernel` while the ones without inherit from `OperatorBase`. This tutorial focuses on implementing operators with kernels. In short, an operator includes the following information:
Information | Where is it defined
-------------- | :----------------------
OpProtoMake definition | `.cc`files, Backward Op does not need an OpProtoMake interface.
Op definition | `.cc` files
Kernel implementation | The kernel methods shared between CPU and GPU are defined in `.h` files. CPU-specific kernels live in `.cc` files, while GPU-specific kernels are implemented in `.cu`files.
Registering the Op | Ops are registered in `.cc` files; For Kernel registration, `.cc` files contain the CPU implementation, while `.cu` files contain the GPU implementation.
New Operator implementations are added to the list [paddle/operators](https://github.com/PaddlePaddle/Paddle/tree/develop/paddle/operators), with file names in the format `*_op.h` (if applicable), `*_op.cc`, `*_op.cu` (if applicable).** The system will use the naming scheme to automatically build operators and their corresponding Python extensions. **
Let's take matrix multiplication operator, [MulOp](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc), as an example to introduce the writing of an Operator with Kernel.
## Implementing C++ Types
### 1. Defining Class ProtoMaker
Matrix Multiplication can be written as $Out = X * Y$, meaning that the operation consists of two inputs and pne output.
First, define `ProtoMaker` to describe the Operator's input, output, and additional comments:
```cpp
class MulOpMaker : public framework::OpProtoAndCheckerMaker {
public:
MulOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor), 2D tensor of size (M x K)");
AddInput("Y", "(Tensor), 2D tensor of size (K x N)");
AddOutput("Out", "(Tensor), 2D tensor of size (M x N)");
AddComment(R"DOC(
Two Element Mul Operator.
The equation is: Out = X * Y
)DOC");
}
};
```
[`MulOpMaker`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L43)is inherited from`framework::OpProtoAndCheckerMaker`, consisting of 2 variables in the constructor:
- `framework::OpProto` stores Operator input and variable attribute, used for generating Python API interfaces.
- `framework::OpAttrChecker` is used to validate variable attributes.
The constructor utilizes `AddInput`, `AddOutput`, and `AddComment`, so that the corresponding information will be added to `OpProto`.
The code above adds two inputs `X` and `Y` to `MulOp`, an output `Out`, and their corresponding descriptions, in accordance to Paddle's [naming convention](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/name_convention.md).
An additional example [`ScaleOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/scale_op.cc#L37) is implemented as follows:
```cpp
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of scale operator.").NotInGradient();
AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
AddComment(R"DOC(Scale operator
The equation is: Out = scale*X
)DOC");
AddAttr<AttrType>("scale", "scale of scale operator.").SetDefault(1.0);
}
};
```
There are two changes in this example:
- `AddInput("X","...").NotInGradient()` expresses that input `X` is not involved in `ScaleOp`'s corresponding computation. If an input to an operator is not participating in back-propagation, please explicitly set `.NotInGradient()`.
- `AddAttr<AttrType>("scale", "...").SetDefault(1.0);` adds `scale`constant as an attribute, and sets the default value to 1.0.
### 2. Defining Operator
The following code defines the interface for MulOp:
```cpp
class MulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim1 = ctx.Input<Tensor>("Y")->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X"));
PADDLE_ENFORCE_EQ(dim1.size(), 2,
"input Y(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("Y"));
PADDLE_ENFORCE_EQ(
dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height.");
ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
}
};
```
[`MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/mul_op.cc#L22) is inherited from `OperatorWithKernel`. Its `public` member
```cpp
using framework::OperatorWithKernel::OperatorWithKernel;
```
expresses an operator constructor using base class `OperatorWithKernel`, alternatively written as
```cpp
MulOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
```
`InferShape` interface needs to be re-written.`InferShape` is a constant method and cannot modify Op's member variables, its constant member `const framework::InferShapeContext &ctx` can be used to extract input, output, and attributes. It functions to
- 1). validate and error out early: it checks input data dimensions and types.
- 2). configures the tensor shape in the output.
Usually `OpProtoMaker` and `Op`'s type definitions are written in `.cc` files, which also include the registration methods introduced later.
### 3. Defining OpKernel
`MulKernel` inherits `framework::OpKernel`, which includes the following templates:
- `typename Place` denotes device type. When different devices, namely the CPU and the GPU, share the same kernel, this template needs to be added. If they don't share kernels, this must not be added. An example of a non-sharing kernel is [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
- `typename T` denotes data type, such as `float` or `double`.
`MulKernel` types need to rewrite the interface for `Compute`.
- `Compute` takes one input variable `const framework::ExecutionContext& context`.
- Compared with `InferShapeContext`, `ExecutionContext` includes device types, and can similarly extract input, output, and attribute variables.
- `Compute` implements the computation logics of an `OpKernel`.
`MulKernel`'s implementation of `Compute` is as follows:
```cpp
template <typename Place, typename T>
class MulKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Input<Tensor>("Y");
auto* Z = context.Output<Tensor>("Out");
Z->mutable_data<T>(context.GetPlace());
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
math::matmul<Place, T>(*X, false, *Y, false, 1, Z, 0, device_context);
}
};
```
Note that **different devices (CPU, GPU)share an Op definition; whether or not they share the same `OpKernel` depends on whether `Compute` calls functions that support both devices.**
`MulOp`'s CPU and GPU share the same `Kernel`. A non-sharing `OpKernel` example can be seen in [`OnehotCrossEntropyOpKernel`](https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/cross_entropy_op.h#L43).
To ease the writing of `OpKernel` compute, and for reusing code cross-device, `Eigen unsupported Tensor` module is used to implement `Compute` interface. To learn about how the Eigen library is used in PaddlePaddle, please see [usage document](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/howto/dev/use_eigen_cn.md).
This concludes the forward implementation of an operator. Next its operation and kernel need to be registered in a `.cc` file.
The definition of its corresponding backward operator, if applicable, is similar to that of an forward operator. **Note that a backward operator does not include a `ProtoMaker`**.
### 4. Registering Operator
- In `.cc` files, register forward and backward operator classes and the CPU kernel.
```cpp
namespace ops = paddle::operators;
REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker, mul_grad, ops::MulOpGrad);
REGISTER_OP_CPU_KERNEL(mul, ops::MulKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::CPUPlace, float>);
```
In that code block,
- `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`.
- `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient.
- `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`.
- Registering GPU Kernel in `.cu` files
- Note that if GPU Kernel is implemented using the `Eigen unsupported` module, then on top of `.cu`, a macro definition `#define EIGEN_USE_GPU` is needed, such as
```cpp
// if use Eigen unsupported module before include head files
#define EIGEN_USE_GPU
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(mul_grad,
ops::MulGradKernel<paddle::platform::GPUPlace, float>);
```
### 5. Compilation
Run the following commands to compile.
```
make mul_op
```
## Python Binding
The system will automatically bind to Python and link it to a generated library.
## Unit Tests
Unit tests include comparing a forward operator's implementations on different devices, comparing a backward operator's implementation on different devices, and a scaling test for the backward operator. Here, we introduce the [unit tests for `MulOp`](https://github.com/PaddlePaddle/Paddle/blob/develop/python/paddle/v2/framework/tests/test_mul_op.py).
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/lstm_unit_op.h"
namespace paddle {
namespace operators {
class LstmUnitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"),
"Input(C_prev) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"),
"Output(C) of LSTM should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"),
"Output(H) of LSTM should not be null.");
auto *x = ctx.Input<framework::Tensor>("X");
auto *c_prev = ctx.Input<framework::Tensor>("C_prev");
PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2.");
PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0],
"Batch size of inputs and states must be equal");
PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4,
"Dimension of FC should equal to prev state * 4");
int b_size = c_prev->dims()[0]; // batch size
int s_dim = c_prev->dims()[1]; // state dim
ctx.Output<framework::LoDTensor>("C")->Resize({b_size, s_dim});
ctx.Output<framework::LoDTensor>("H")->Resize({b_size, s_dim});
}
};
template <typename AttrType>
class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LstmUnitOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "FC input before the non-linear activation.");
AddInput(
"C_prev",
"The cell state tensor of last time-step in the Lstm Unit operator.");
AddOutput("C", "The cell tensor of Lstm Unit operator.");
AddOutput("H", "The hidden state tensor of Lstm Unit operator.");
AddComment(R"DOC(Lstm-Unit Operator
Equation:
i, f, o, j = split(X)
C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j)
H = C * sigm(o)
)DOC");
AddAttr<AttrType>("forget_bias", "The forget bias of Lstm Unit.")
.SetDefault(0.0);
}
};
class LstmUnitGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")),
"Input(C@GRAD) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")),
"Input(H@GRAD) should not be null");
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("X")->dims());
ctx.Output<framework::LoDTensor>(framework::GradVarName("C_prev"))
->Resize(ctx.Input<Tensor>("C_prev")->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker<float>,
lstm_unit_grad, ops::LstmUnitGradOp);
REGISTER_OP_CPU_KERNEL(lstm_unit,
ops::LstmUnitKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
lstm_unit_grad, ops::LstmUnitGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/operators/cross_entropy_op.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle {
namespace operators {
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename Dtype>
__device__ Dtype cuda_sigmoid(const Dtype x) {
return Dtype(1) / (Dtype(1) + exp(-x));
}
template <typename Dtype>
__device__ Dtype cuda_tanh(const Dtype x) {
return Dtype(1 - exp(-2. * x)) / (Dtype(1) + exp(-2. * x));
}
template <typename T>
__global__ void LSTMUnitKernel(const int nthreads, const int dim,
const T* C_prev, const T* X, T* C, T* H,
const T forget_bias) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int n = index / dim;
const int d = index % dim;
const T* X_offset = X + 4 * dim * n;
const T i = cuda_sigmoid(X_offset[d]);
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
const T g = cuda_tanh(X_offset[3 * dim + d]);
const T c_prev = C_prev[index];
const T c = f * c_prev + i * g;
C[index] = c;
const T tanh_c = cuda_tanh(c);
H[index] = o * tanh_c;
}
}
template <typename T>
__global__ void LSTMUnitGradientKernel(const int nthreads, const int dim,
const T* C_prev, const T* X, const T* C,
const T* H, const T* C_diff,
const T* H_diff, T* C_prev_diff,
T* X_diff, const T forget_bias) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int n = index / dim;
const int d = index % dim;
const T* X_offset = X + 4 * dim * n;
T* c_prev_diff = C_prev_diff + index;
T* X_diff_offset = X_diff + 4 * dim * n;
T* i_diff = X_diff_offset + d;
T* f_diff = X_diff_offset + 1 * dim + d;
T* o_diff = X_diff_offset + 2 * dim + d;
T* g_diff = X_diff_offset + 3 * dim + d;
const T i = cuda_sigmoid(X_offset[d]);
const T f = cuda_sigmoid(X_offset[1 * dim + d] + forget_bias);
const T o = cuda_sigmoid(X_offset[2 * dim + d]);
const T g = cuda_tanh(X_offset[3 * dim + d]);
const T c_prev = C_prev[index];
const T c = C[index];
const T tanh_c = cuda_tanh(c);
const T c_term_diff =
C_diff[index] + H_diff[index] * o * (1 - tanh_c * tanh_c);
*c_prev_diff = c_term_diff * f;
*i_diff = c_term_diff * g * i * (1 - i);
*f_diff = c_term_diff * c_prev * f * (1 - f);
*o_diff = H_diff[index] * tanh_c * o * (1 - o);
*g_diff = c_term_diff * i * (1 - g * g);
}
}
template <typename T, typename AttrType = T>
class LstmUnitOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto* x_tensor = ctx.Input<framework::Tensor>("X");
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
auto* c_tensor = ctx.Output<framework::Tensor>("C");
auto* h_tensor = ctx.Output<framework::Tensor>("H");
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
int b_size = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
const T* X = x_tensor->data<T>();
const T* C_prev = c_prev_tensor->data<T>();
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
int block = 512;
int n = b_size * D;
int grid = (n + block - 1) / block;
LSTMUnitKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, forget_bias);
}
};
template <typename T, typename AttrType = T>
class LstmUnitGradOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto x_tensor = ctx.Input<Tensor>("X");
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
auto c_tensor = ctx.Input<Tensor>("C");
auto h_tensor = ctx.Input<Tensor>("H");
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
auto c_prev_diff_tensor =
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
auto* X = x_tensor->data<T>();
auto* C_prev = c_prev_tensor->data<T>();
auto* C = c_tensor->data<T>();
auto* H = h_tensor->data<T>();
auto* H_diff = hdiff_tensor->data<T>();
auto* C_diff = cdiff_tensor->data<T>();
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
int N = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
int block = 512;
int n = N * D;
int grid = (n + block - 1) / block;
LSTMUnitGradientKernel<T><<<grid, block>>>(n, D, C_prev, X, C, H, C_diff,
H_diff, C_prev_diff, X_diff,
forget_bias);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(lstm_unit, ops::LstmUnitOpCUDAKernel<float>);
REGISTER_OP_GPU_KERNEL(lstm_unit_grad, ops::LstmUnitGradOpCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "glog/logging.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
template <typename T>
inline T sigmoid(T x) {
return 1. / (1. + exp(-x));
}
template <typename T>
inline T tanh(T x) {
return 2. * sigmoid(2. * x) - 1.;
}
template <typename Place, typename T, typename AttrType = T>
class LstmUnitKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto* x_tensor = ctx.Input<framework::Tensor>("X");
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
auto* c_tensor = ctx.Output<framework::Tensor>("C");
auto* h_tensor = ctx.Output<framework::Tensor>("H");
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
int b_size = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
T* C = c_tensor->mutable_data<T>(ctx.GetPlace());
T* H = h_tensor->mutable_data<T>(ctx.GetPlace());
const T* X = x_tensor->data<T>();
const T* C_prev = c_prev_tensor->data<T>();
for (int n = 0; n < b_size; ++n) {
for (int d = 0; d < D; ++d) {
const T i = sigmoid(X[d]);
const T f = sigmoid(X[1 * D + d] + forget_bias);
const T o = sigmoid(X[2 * D + d]);
const T g = tanh(X[3 * D + d]);
const T c_prev = C_prev[d];
const T c = f * c_prev + i * g;
C[d] = c;
const T tanh_c = tanh(c);
H[d] = o * tanh_c;
}
C_prev += D;
X += 4 * D;
C += D;
H += D;
}
}
};
template <typename Place, typename T, typename AttrType = T>
class LstmUnitGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
auto x_tensor = ctx.Input<Tensor>("X");
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
auto c_tensor = ctx.Input<Tensor>("C");
auto h_tensor = ctx.Input<Tensor>("H");
auto hdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("H"));
auto cdiff_tensor = ctx.Input<Tensor>(framework::GradVarName("C"));
auto xdiff_tensor = ctx.Output<Tensor>(framework::GradVarName("X"));
auto c_prev_diff_tensor =
ctx.Output<Tensor>(framework::GradVarName("C_prev"));
auto* X = x_tensor->data<T>();
auto* C_prev = c_prev_tensor->data<T>();
auto* C = c_tensor->data<T>();
auto* H = h_tensor->data<T>();
auto* H_diff = hdiff_tensor->data<T>();
auto* C_diff = cdiff_tensor->data<T>();
auto* C_prev_diff = c_prev_diff_tensor->mutable_data<T>(ctx.GetPlace());
auto* X_diff = xdiff_tensor->mutable_data<T>(ctx.GetPlace());
int N = c_tensor->dims()[0];
int D = c_tensor->dims()[1];
auto forget_bias = static_cast<T>(ctx.Attr<AttrType>("forget_bias"));
for (int n = 0; n < N; ++n) {
for (int d = 0; d < D; ++d) {
T* c_prev_diff = C_prev_diff + d;
T* i_diff = X_diff + d;
T* f_diff = X_diff + 1 * D + d;
T* o_diff = X_diff + 2 * D + d;
T* g_diff = X_diff + 3 * D + d;
const T i = sigmoid(X[d]);
const T f = sigmoid(X[1 * D + d] + forget_bias);
const T o = sigmoid(X[2 * D + d]);
const T g = tanh(X[3 * D + d]);
const T c_prev = C_prev[d];
const T c = C[d];
const T tanh_c = tanh(c);
const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c);
*c_prev_diff = c_term_diff * f;
*i_diff = c_term_diff * g * i * (1 - i);
*f_diff = c_term_diff * c_prev * f * (1 - f);
*o_diff = H_diff[d] * tanh_c * o * (1 - o);
*g_diff = c_term_diff * i * (1 - g * g);
}
C_prev += D;
X += 4 * D;
C += D;
H += D;
C_diff += D;
H_diff += D;
X_diff += 4 * D;
C_prev_diff += D;
}
}
};
} // namespace operators
} // namespace paddle
......@@ -34,13 +34,14 @@ class DeviceContext {
template <typename DeviceType>
DeviceType* get_eigen_device() const;
virtual void Wait() const {}
};
class CPUDeviceContext : public DeviceContext {
public:
CPUDeviceContext();
explicit CPUDeviceContext(CPUPlace place);
virtual ~CPUDeviceContext() {}
Eigen::DefaultDevice* eigen_device() const;
......@@ -59,7 +60,7 @@ class CUDADeviceContext : public DeviceContext {
virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */
void Wait() const;
void Wait() const override;
/*! \brief Return place in the device context. */
Place GetPlace() const override;
......
......@@ -237,7 +237,13 @@ All parameter, weight, gradient are variables in Paddle.
return Backward(forwardOp, no_grad_vars).release();
})
.def("infer_shape", &OperatorBase::InferShape)
.def("run", &OperatorBase::Run)
.def("run",
[](OperatorBase &self,
const Scope &scope,
const platform::DeviceContext &dev_ctx) {
self.Run(scope, dev_ctx);
dev_ctx.Wait();
})
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
......
import unittest
import numpy as np
from op_test import OpTest
def sigmoid_np(x):
return 1. / (1. + np.exp(-x))
def tanh_np(x):
return 2 * sigmoid_np(2. * x) - 1.
class LstmUnitTest(OpTest):
def setUp(self):
self.op_type = "lstm_unit"
x_np = np.random.normal(size=(5, 16)).astype("float32")
c_np = np.random.normal(size=(5, 4)).astype("float32")
i_np, f_np, o_np, j_np = np.split(x_np, 4, axis=1)
forget_bias_np = 0.
self.attrs = {'forget_bias': 0.}
new_c = c_np * sigmoid_np(f_np + forget_bias_np) + sigmoid_np(
i_np) * tanh_np(j_np)
new_h = tanh_np(new_c) * sigmoid_np(o_np)
self.inputs = {'X': x_np, 'C_prev': c_np}
self.outputs = {'C': new_c, 'H': new_h}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X', 'C_prev'], ['C', 'H'], max_relative_error=0.01)
if __name__ == "__main__":
unittest.main()
......@@ -7,6 +7,14 @@ class PReluTest(OpTest):
def setUp(self):
self.op_type = "prelu"
x_np = np.random.normal(size=(10, 10)).astype("float32")
for pos, val in np.ndenumerate(x_np):
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
while abs(val) < 1e-3:
x_np[pos] = np.random.normal()
val = x_np[pos]
x_np_sign = np.sign(x_np)
x_np = x_np_sign * np.maximum(x_np, .005)
alpha_np = np.array([.1])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册