Unverified commit 1d7954fc, authored by Guo Sheng, committed by GitHub

Merge pull request #5255 from guoshengCS/add-GRUOp-dev

Add GRU Operator
...
@@ -142,7 +142,8 @@ set(DEPS_OPS
nccl_op
sequence_conv_op
lod_rank_table_op
lstm_op)
lstm_op
gru_op)
op_library(cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op)
op_library(cross_entropy_op DEPS cross_entropy)
...
@@ -156,6 +157,7 @@ op_library(nccl_op DEPS nccl_common)
endif()
op_library(sequence_conv_op DEPS context_project)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
op_library(dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array)
op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gru_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class GRUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(%s) of GRUOp should not be null.", "BatchGate");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(%s) of GRUOp should not be null.",
"BatchResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(%s) of GRUOp should not be null.", "BatchHidden");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
int frame_size = weight_dims[0];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
PADDLE_ENFORCE_EQ(
weight_dims[1], frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
ctx->ShareLoD("Input", "Hidden");
}
};
class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LoDTensor) The first input is a LodTensor, which supports "
"variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) The initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.")
.AsDispensable();
AddInput(
"Weight",
"(Tensor) The learnable hidden-hidden weight matrix with shape "
"(D x 3D), where D is the hidden size. Its elements, contiguous in "
"memory, can be divided into two parts: the first part holds the "
"weights of the update gate and reset gate with shape (D x 2D), and "
"the second part holds the weights of the output candidate with shape "
"(D x D).");
AddInput("Bias",
"(Tensor, optional) Bias vector with shape (1 x 3D) concating "
"bias of the update gate, reset gate and output candidate.")
.AsDispensable();
AddOutput("BatchGate",
"(LoDTensor) To compute with batches, sequence data will be "
"reorganized into several successive batches each containing "
"data from the same time step. The LoDTensor BatchGate contains "
"the update gate, reset gate and output candidate values "
"organized in batches. The LoD size is 2. The first LoD contains "
"the batch offsets and the second LoD contains the indexes in "
"the raw sequence data.")
.AsIntermediate();
AddOutput(
"BatchResetHiddenPrev",
"(LoDTensor) The reseted hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
.AsIntermediate();
AddOutput(
"BatchHidden",
"(LoDTensor) The hidden state LoDTensor organized in batches. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.")
.AsIntermediate();
AddOutput(
"Hidden",
"(LoDTensor) the hidden state LoDTensor organized in sequences. "
"This LoDTensor is a matrix with shape (T X D) and has the same LoD "
"with `BatchGate`.");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault("tanh");
AddAttr<std::string>(
"gate_activation",
"(string, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault("sigmoid");
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddComment(R"DOC(
GRU Operator implements part of the calculations of the complete GRU as follows:
\f[
update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
\f]
@note To implement the complete GRU, a fully-connected operator must be applied
beforehand to compute xu, xr and xc, which are fed as the Input of the GRU operator.
)DOC");
}
};
class GRUGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUGradOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUGradOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(%s) of GRUGradOp should not be null.", "BatchGate");
PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"),
"Input(%s) of GRUGradOp should not be null.",
"BatchResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("BatchHidden"),
"Input(%s) of GRUOp should not be null.", "BatchHidden");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUGradOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
int frame_size = weight_dims[0];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
if (ctx->HasInput("H0")) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
auto h0_grad_name = framework::GradVarName("H0");
if (ctx->HasOutput(h0_grad_name))
ctx->SetOutputDim(h0_grad_name, h0_dims);
}
if (ctx->HasInput("Bias")) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
}
auto input_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(input_grad_name))
ctx->SetOutputDim(input_grad_name, input_dims);
auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(gru, ops::GRUOp, ops::GRUOpMaker, gru_grad, ops::GRUGradOp);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUKernel<paddle::platform::CPUPlace, float>,
ops::GRUKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(gru_grad,
ops::GRUGradKernel<paddle::platform::CPUPlace, float>,
ops::GRUGradKernel<paddle::platform::CPUPlace, double>);
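For reference, a minimal NumPy sketch of the per-step recurrence described in the DOC string above (an illustration only, not part of this commit; it assumes the default sigmoid gate activation and tanh candidate activation, and takes the input projections xu, xr and xc as already computed by a preceding fully-connected operator):

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_g, h_prev, gate_w, state_w):
    # x_g:     (N, 3D) precomputed input projections [xu | xr | xc]
    # h_prev:  (N, D)  previous hidden state
    # gate_w:  (D, 2D) recurrent weights of the update and reset gates
    # state_w: (D, D)  recurrent weights of the output candidate
    D = h_prev.shape[1]
    u_r = sigmoid(x_g[:, :2 * D] + h_prev.dot(gate_w))       # u_t, r_t
    u, r = u_r[:, :D], u_r[:, D:]
    c = np.tanh(x_g[:, 2 * D:] + (r * h_prev).dot(state_w))  # candidate {h}_t
    return (1.0 - u) * h_prev + u * c                        # h_t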
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/gru_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru, ops::GRUKernel<paddle::platform::GPUPlace, float>,
ops::GRUKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(gru_grad,
ops::GRUGradKernel<paddle::platform::GPUPlace, float>,
ops::GRUGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
context.ShareLoD("Input", "Hidden");
auto hidden_dims = hidden->dims();
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(context.device_context(), *input, *batch_gate, true, is_reverse);
int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0];
auto g = EigenMatrix<T>::From(*batch_gate);
auto place = context.GetEigenDevice<Place>();
if (bias) {
auto b = EigenMatrix<T>::From(*bias);
g.device(place) = g +
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
.broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
}
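// The (D x 3D) Weight is laid out as [update/reset gate weights (D x 2D) | candidate weights (D x D)],
// so the candidate (state) weights start 2 * frame_size * frame_size elements into the buffer.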
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
gru_value.prevOutValue = const_cast<T*>(h0_data);
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
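// Process the time-step batches in order; each step's output becomes prevOutValue for the next step.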
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.outputValue = hidden_t.data<T>();
gru_value.gateValue = gate_t.data<T>();
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<Place, T>::compute(
context.device_context(), gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
gru_value.prevOutValue = gru_value.outputValue;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_hidden->set_lod(batch_gate->lod());
to_seq(context.device_context(), *batch_hidden, *hidden);
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
template <typename Place, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
auto* batch_reset_hidden_prev =
context.Input<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = context.Input<LoDTensor>("BatchHidden");
auto* hidden = context.Input<LoDTensor>("Hidden");
auto* hidden_grad =
context.Input<LoDTensor>(framework::GradVarName("Hidden"));
auto* input_grad =
context.Output<LoDTensor>(framework::GradVarName("Input"));
auto* h0_grad = context.Output<Tensor>(framework::GradVarName("H0"));
auto* weight_grad =
context.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
auto gate_dims = batch_gate->dims();
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
math::LoDTensor2BatchFunctor<Place, T> to_batch;
LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
context.GetPlace());
math::SetConstant<Place, T> zero;
zero(context.device_context(), &batch_hidden_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_gate_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_reset_hidden_prev_grad,
static_cast<T>(0.0));
bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, false,
is_reverse);
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::hl_gru_grad<T> gru_grad;
if (weight_grad) {
gru_grad.gateWeightGrad =
weight_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), weight_grad, static_cast<T>(0.0));
gru_grad.stateWeightGrad =
weight_grad->data<T>() + 2 * frame_size * frame_size;
} else {
gru_grad.gateWeightGrad = nullptr;
gru_grad.stateWeightGrad = nullptr;
}
auto batch_starts = batch_hidden_grad.lod()[0];
size_t num_batch = batch_starts.size() - 1;
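// Walk the time-step batches in reverse so the gradient of each hidden state is complete
// before it is propagated back to the previous step.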
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
gru_value.gateValue = gate_t.data<T>();
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
gru_grad.outputGrad = hidden_grad_t.data<T>();
Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
gru_grad.gateGrad = gate_grad_t.data<T>();
Tensor reset_hidden_prev_grad_t =
batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) {
gru_value.prevOutValue = const_cast<T*>(h0_data);
if (h0_grad) {
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
} else {
gru_grad.prevOutGrad = nullptr;
}
} else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
gru_value.prevOutValue = hidden_prev_t.data<T>();
Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
gru_grad.prevOutGrad = hidden_prev_grad_t.data<T>();
}
math::GRUUnitGradFunctor<Place, T>::compute(
context.device_context(), gru_value, gru_grad, frame_size,
cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_gate_grad.set_lod(batch_gate->lod());
to_seq(context.device_context(), batch_gate_grad, *input_grad);
}
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
auto d_g = EigenMatrix<T>::From(batch_gate_grad);
auto place = context.GetEigenDevice<Place>();
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
}
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
} // namespace operators
} // namespace paddle
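As an illustration of the batch reorganization performed by LoDTensor2BatchFunctor (the ordering is inferred from the Python seq_to_batch reference in the unit test further below, so treat the exact indices as an assumption): with three sequences of lengths 2, 4 and 3 and is_reverse = False,

lod = [[0, 2, 6, 9]]            # sequence start offsets; lengths 2, 4, 3
# sequences sorted by length, descending: seq 1 (len 4), seq 2 (len 3), seq 0 (len 2)
idx_in_seq_list = [[2, 6, 0],   # time-step batch 0: first element of every sequence
                   [3, 7, 1],   # time-step batch 1
                   [4, 8],      # time-step batch 2: seq 0 (length 2) has ended
                   [5]]         # time-step batch 3: only seq 1 (length 4) remains

so the first level of BatchGate's LoD would hold the batch offsets [0, 3, 6, 8, 9], and each step's output becomes prevOutValue for the following step.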
...
@@ -12,6 +12,7 @@ if(WITH_GPU)
nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
...
@@ -22,6 +23,7 @@ else()
cc_library(context_project SRCS context_project.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {
#ifndef __NVCC__
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
T *gateValue, T *resetOutputValue,
T *prevOutputValue, int frameSize,
activation_mode_t active_gate) {
T rValueUpdateGate;
T rValueResetGate;
T rValueResetOutput;
T rPrevOut = 0;
T *updateGate = gateValue;
T *resetGate = gateValue + frameSize;
for (int i = 0; i < frameSize; i++) {
rValueUpdateGate = updateGate[i];
rValueResetGate = resetGate[i];
if (prevOutputValue) {
rPrevOut = prevOutputValue[i];
}
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, active_gate);
updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate;
resetOutputValue[i] = rValueResetOutput;
}
}
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
T *gateValue, T *prevOutputValue,
T *outputValue, int frameSize,
activation_mode_t active_node) {
T rValueUpdateGate;
T rValueFrameState;
T rPrevOut = 0;
T rOutput;
T *updateGate = gateValue;
T *frameState = gateValue + frameSize * 2;
for (int i = 0; i < frameSize; i++) {
rValueUpdateGate = updateGate[i];
rValueFrameState = frameState[i];
if (prevOutputValue) {
rPrevOut = prevOutputValue[i];
}
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
active_node);
frameState[i] = rValueFrameState;
outputValue[i] = rOutput;
}
}
template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
T *resetOutputValue, T *prevOutputValue,
int frameSize,
activation_mode_t active_gate) {
#ifdef __AVX__
__m256 rValueUpdateGate;
__m256 rValueResetGate;
__m256 rValueResetOutput;
__m256 rPrevOut = _mm256_set1_ps(0.0f);
__m256 *updateGate = (__m256 *)gateValue;
__m256 *resetGate = (__m256 *)(gateValue + frameSize);
for (int i = 0; i < frameSize / 8; i++) {
rValueUpdateGate = updateGate[i];
rValueResetGate = resetGate[i];
if (prevOutputValue) {
rPrevOut = ((__m256 *)prevOutputValue)[i];
}
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, active_gate);
updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate;
((__m256 *)resetOutputValue)[i] = rValueResetOutput;
}
#endif
}
template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
T *prevOutputValue, T *outputValue,
int frameSize,
activation_mode_t active_node) {
#ifdef __AVX__
__m256 rValueUpdateGate;
__m256 rValueFrameState;
__m256 rPrevOut = _mm256_set1_ps(0.0f);
__m256 rOutput;
__m256 *updateGate = (__m256 *)gateValue;
__m256 *frameState = (__m256 *)(gateValue + frameSize * 2);
for (int i = 0; i < frameSize / 8; i++) {
rValueUpdateGate = updateGate[i];
rValueFrameState = frameState[i];
if (prevOutputValue) {
rPrevOut = ((__m256 *)prevOutputValue)[i];
}
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
active_node);
frameState[i] = rValueFrameState;
((__m256 *)outputValue)[i] = rOutput;
}
#endif
}
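// Dispatch helpers: the AVX variants above are used only when frameSize is a multiple of 8
// (one __m256 holds eight packed floats) and T is a 4-byte float; otherwise the scalar loops run.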
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput opResetOutput,
hl_gru_value<T> value, int frameSize,
int batchSize, activation_mode_t active_gate) {
for (int b = 0; b < batchSize; b++) {
if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_reset_output(
opResetOutput, value.gateValue, value.resetOutputValue,
value.prevOutValue, frameSize, active_gate);
} else {
hl_naive_gru_forward_reset_output(
opResetOutput, value.gateValue, value.resetOutputValue,
value.prevOutValue, frameSize, active_gate);
}
value.gateValue += frameSize * 3;
value.resetOutputValue += frameSize;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
}
}
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput opFinalOutput,
hl_gru_value<T> value, int frameSize,
int batchSize, activation_mode_t active_node) {
for (int b = 0; b < batchSize; b++) {
if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue,
value.prevOutValue, value.outputValue,
frameSize, active_node);
} else {
hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue,
value.prevOutValue, value.outputValue,
frameSize, active_node);
}
value.gateValue += frameSize * 3;
value.outputValue += frameSize;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
}
}
template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *outputGrad,
int frameSize,
activation_mode_t active_node) {
T rUpdateGateValue;
T rUpdateGateGrad;
T rFrameStateValue;
T rFrameStateGrad;
T rOutGrad;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T *updateGateValue = gateValue;
T *updateGateGrad = gateGrad;
T *frameStateValue = gateValue + frameSize * 2;
T *frameStateGrad = gateGrad + frameSize * 2;
for (int i = 0; i < frameSize; i++) {
rUpdateGateValue = updateGateValue[i];
rFrameStateValue = frameStateValue[i];
rOutGrad = outputGrad[i];
if (prevOutValue) {
rPrevOutValue = prevOutValue[i];
}
if (prevOutGrad) {
rPrevOutGrad = prevOutGrad[i];
}
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
active_node);
updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad;
if (prevOutGrad) {
prevOutGrad[i] = rPrevOutGrad;
}
}
}
template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *resetOutputGrad,
int frameSize,
activation_mode_t active_gate) {
T rUpdateGateValue;
T rUpdateGateGrad;
T rResetGateValue;
T rResetGateGrad;
T rResetOutputGrad = 0;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T *updateGateValue = gateValue;
T *updateGateGrad = gateGrad;
T *resetGateValue = gateValue + frameSize;
T *resetGateGrad = gateGrad + frameSize;
for (int i = 0; i < frameSize; i++) {
rUpdateGateValue = updateGateValue[i];
rUpdateGateGrad = updateGateGrad[i];
rResetGateValue = resetGateValue[i];
if (prevOutValue && prevOutGrad) {
rResetOutputGrad = resetOutputGrad[i];
}
if (prevOutValue) {
rPrevOutValue = prevOutValue[i];
}
if (prevOutGrad) {
rPrevOutGrad = prevOutGrad[i];
}
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
active_gate);
updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad;
if (prevOutGrad) {
prevOutGrad[i] = rPrevOutGrad;
}
}
}
template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *outputGrad,
int frameSize,
activation_mode_t active_node) {
#ifdef __AVX__
__m256 rUpdateGateValue;
__m256 rUpdateGateGrad;
__m256 rFrameStateValue;
__m256 rFrameStateGrad;
__m256 rOutGrad;
__m256 rPrevOutValue = _mm256_set1_ps(0.0f);
__m256 rPrevOutGrad = _mm256_set1_ps(0.0f);
__m256 *updateGateValue = (__m256 *)gateValue;
__m256 *updateGateGrad = (__m256 *)gateGrad;
__m256 *frameStateValue = (__m256 *)(gateValue + frameSize * 2);
__m256 *frameStateGrad = (__m256 *)(gateGrad + frameSize * 2);
for (int i = 0; i < frameSize / 8; i++) {
rUpdateGateValue = updateGateValue[i];
rFrameStateValue = frameStateValue[i];
rOutGrad = ((__m256 *)outputGrad)[i];
if (prevOutValue) {
rPrevOutValue = ((__m256 *)prevOutValue)[i];
}
if (prevOutGrad) {
rPrevOutGrad = ((__m256 *)prevOutGrad)[i];
}
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
active_node);
updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad;
if (prevOutGrad) {
((__m256 *)prevOutGrad)[i] = rPrevOutGrad;
}
}
#endif
}
template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *resetOutputGrad,
int frameSize,
activation_mode_t active_gate) {
#ifdef __AVX__
__m256 rUpdateGateValue;
__m256 rUpdateGateGrad;
__m256 rResetGateValue;
__m256 rResetGateGrad;
__m256 rResetOutputGrad = _mm256_set1_ps(0.0f);
__m256 rPrevOutValue = _mm256_set1_ps(0.0f);
__m256 rPrevOutGrad = _mm256_set1_ps(0.0f);
__m256 *updateGateValue = (__m256 *)gateValue;
__m256 *updateGateGrad = (__m256 *)gateGrad;
__m256 *resetGateValue = (__m256 *)(gateValue + frameSize);
__m256 *resetGateGrad = (__m256 *)(gateGrad + frameSize);
for (int i = 0; i < frameSize / 8; i++) {
rUpdateGateValue = updateGateValue[i];
rUpdateGateGrad = updateGateGrad[i];
rResetGateValue = resetGateValue[i];
if (prevOutValue && prevOutGrad) {
rResetOutputGrad = ((__m256 *)resetOutputGrad)[i];
}
if (prevOutValue) {
rPrevOutValue = ((__m256 *)prevOutValue)[i];
}
if (prevOutGrad) {
rPrevOutGrad = ((__m256 *)prevOutGrad)[i];
}
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
active_gate);
updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad;
if (prevOutGrad) {
((__m256 *)prevOutGrad)[i] = rPrevOutGrad;
}
}
#endif
}
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value<T> value,
hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node) {
for (int b = 0; b < batchSize; b++) {
if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad(
opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.outputGrad, frameSize, active_node);
} else {
hl_naive_gru_backward_state_grad(
opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.outputGrad, frameSize, active_node);
}
value.gateValue += frameSize * 3;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
grad.gateGrad += frameSize * 3;
grad.outputGrad += frameSize;
if (grad.prevOutGrad) {
grad.prevOutGrad += frameSize;
}
}
}
template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value<T> value,
hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_gate) {
for (int b = 0; b < batchSize; b++) {
if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_reset_grad(
opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate);
} else {
hl_naive_gru_backward_reset_grad(
opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate);
}
value.gateValue += frameSize * 3;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
grad.gateGrad += frameSize * 3;
grad.resetOutputGrad += frameSize;
if (grad.prevOutGrad) {
grad.prevOutGrad += frameSize;
}
}
}
#endif
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpResetOutput, bool isBatch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
T *gateValue, T *resetOutputValue,
T *prevOutputValue, int frameSize,
int batchSize,
activation_mode_t active_gate) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
resetOutputValue += batchIdx * frameSize;
}
T rPrevOut = 0;
T rValueResetOutput;
T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
T rValueResetGate = gateValue[frameIdx + frameSize * 1];
if (prevOutputValue) {
if (isBatch) prevOutputValue += batchIdx * frameSize;
rPrevOut = prevOutputValue[frameIdx];
}
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput,
active_gate);
gateValue[frameIdx + frameSize * 0] = rValueUpdateGate;
gateValue[frameIdx + frameSize * 1] = rValueResetGate;
resetOutputValue[frameIdx] = rValueResetOutput;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpFinalOutput, bool isBatch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
T *gateValue, T *prevOutputValue,
T *outputValue, int frameSize,
int batchSize,
activation_mode_t active_node) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
outputValue += batchIdx * frameSize;
}
T rOutput;
T rPrevOut = 0;
T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
T rValueFrameState = gateValue[frameIdx + frameSize * 2];
if (prevOutputValue) {
if (isBatch) prevOutputValue += batchIdx * frameSize;
rPrevOut = prevOutputValue[frameIdx];
}
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
active_node);
gateValue[frameIdx + frameSize * 2] = rValueFrameState;
outputValue[frameIdx] = rOutput;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpStateGrad, bool isBatch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *outputGrad,
int frameSize, int batchSize,
activation_mode_t active_node) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
gateGrad += batchIdx * 3 * frameSize;
outputGrad += batchIdx * frameSize;
}
T rUpdateGateGrad;
T rFrameStateGrad;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
T rFrameStateValue = gateValue[frameIdx + frameSize * 2];
T rOutGrad = outputGrad[frameIdx];
if (prevOutValue && prevOutGrad) {
if (isBatch) prevOutValue += batchIdx * frameSize;
rPrevOutValue = prevOutValue[frameIdx];
if (isBatch) prevOutGrad += batchIdx * frameSize;
rPrevOutGrad = prevOutGrad[frameIdx];
}
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
active_node);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad;
if (prevOutGrad) {
prevOutGrad[frameIdx] = rPrevOutGrad;
}
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpResetGrad, bool isBatch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *resetOutputGrad,
int frameSize, int batchSize,
activation_mode_t active_gate) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
gateGrad += batchIdx * 3 * frameSize;
resetOutputGrad += batchIdx * frameSize;
}
T rResetGateGrad;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T rResetOutputGrad = 0;
T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0];
T rResetGateValue = gateValue[frameIdx + frameSize * 1];
if (prevOutValue && prevOutGrad) {
if (isBatch) prevOutValue += batchIdx * frameSize;
if (isBatch) prevOutGrad += batchIdx * frameSize;
rPrevOutValue = prevOutValue[frameIdx];
rPrevOutGrad = prevOutGrad[frameIdx];
rResetOutputGrad = resetOutputGrad[frameIdx];
}
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
active_gate);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 1] = rResetGateGrad;
if (prevOutGrad) {
prevOutGrad[frameIdx] = rPrevOutGrad;
}
}
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
// TODO(guosheng): refine code style in gru_kernel
namespace paddle {
namespace operators {
namespace math {
namespace detail {
namespace forward {
template <typename T>
class gru_resetOutput {
public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut,
T &valueResetOutput, activation_mode_t actGate) {
valueUpdateGate = activation(valueUpdateGate, actGate);
valueResetGate = activation(valueResetGate, actGate);
valueResetOutput = prevOut * valueResetGate;
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate,
__m256 &prevOut, __m256 &valueResetOutput,
activation_mode_t actGate) {
valueUpdateGate = activation(valueUpdateGate, actGate);
valueResetGate = activation(valueResetGate, actGate);
valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate);
}
#endif
#endif
};
template <typename T>
class gru_finalOutput {
public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut,
T &valueOutput, activation_mode_t actInput) {
valueFrameState = activation(valueFrameState, actInput);
valueOutput = prevOut - (valueUpdateGate * prevOut) +
(valueUpdateGate * valueFrameState);
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState,
__m256 &prevOut, __m256 &valueOutput,
activation_mode_t actInput) {
valueFrameState = activation(valueFrameState, actInput);
valueOutput = _mm256_add_ps(
_mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)),
_mm256_mul_ps(valueUpdateGate, valueFrameState));
}
#endif
#endif
};
} // namespace forward
namespace backward {
template <typename T>
class gru_stateGrad {
public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueFrameState, T &gradFrameState,
T &valuePrevOut, T &gradPrevOut, T &gradOutput,
activation_mode_t actInput) {
gradUpdateGate = (gradOutput * valueFrameState);
gradUpdateGate -= (gradOutput * valuePrevOut);
gradPrevOut -= (gradOutput * valueUpdateGate);
gradPrevOut += gradOutput;
gradFrameState =
activation(gradOutput * valueUpdateGate, valueFrameState, actInput);
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
__m256 &valueFrameState, __m256 &gradFrameState,
__m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradOutput, activation_mode_t actInput) {
gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState);
gradUpdateGate =
_mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut));
gradPrevOut = _mm256_add_ps(
_mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)),
gradOutput);
gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate),
valueFrameState, actInput);
}
#endif
#endif
};
template <typename T>
class gru_resetGrad {
public:
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueResetGate, T &gradResetGate,
T &valuePrevOut, T &gradPrevOut,
T &gradResetOutput, activation_mode_t actGate) {
gradResetGate = (gradResetOutput * valuePrevOut);
gradPrevOut += (gradResetOutput * valueResetGate);
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
gradResetGate = activation(gradResetGate, valueResetGate, actGate);
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
__m256 &valueResetGate, __m256 &gradResetGate,
__m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradResetOutput,
activation_mode_t actGate) {
gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut);
gradPrevOut = _mm256_add_ps(gradPrevOut,
_mm256_mul_ps(gradResetOutput, valueResetGate));
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
gradResetGate = activation(gradResetGate, valueResetGate, actGate);
}
#endif
#endif
};
} // namespace backward
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct GRUUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate) {
#ifndef __NVCC__
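// gateValue[:, 0:2D] += prevOutValue (N x D) * gateWeight (D x 2D):
// the hidden-to-hidden part of the update and reset gates.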
if (value.prevOutValue) {
math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
value.gateValue, frameSize * 3);
}
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frameSize, batchSize, active_gate);
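// gateValue[:, 2D:3D] += resetOutputValue (N x D) * stateWeight (D x D):
// the hidden-to-hidden part of the output candidate.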
if (value.prevOutValue) {
math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
value.gateValue + frameSize * 2, frameSize * 3);
}
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frameSize, batchSize, active_node);
#endif
}
};
template <typename T>
struct GRUUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate) {
#ifndef __NVCC__
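// The backward pass mirrors the forward pass in reverse: candidate/state gradients first,
// then reset-output gradients, with the corresponding weight gradients accumulated by GEMM.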
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frameSize, batchSize, active_node);
if (value.prevOutValue && grad.prevOutGrad) {
math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
frameSize, 0, grad.resetOutputGrad, frameSize);
if (grad.stateWeightGrad) {
math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
frameSize * 3, 1, grad.stateWeightGrad, frameSize);
}
}
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
grad, frameSize, batchSize, active_gate);
if (grad.prevOutGrad && value.prevOutValue) {
math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
grad.prevOutGrad, frameSize);
if (grad.gateWeightGrad) {
math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
grad.gateWeightGrad, frameSize * 2);
}
}
#endif
}
};
template struct GRUUnitFunctor<platform::CPUPlace, float>;
template struct GRUUnitFunctor<platform::CPUPlace, double>;
template struct GRUUnitGradFunctor<platform::CPUPlace, float>;
template struct GRUUnitGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct GRUUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate) {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads;
dim3 grid;
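// With batchSize == 1, launch a 1-D grid over the frame; otherwise tile 32 x 32 thread blocks
// over (frame, batch).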
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
if (value.prevOutValue) {
math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
value.gateValue, frameSize * 3);
}
if (batchSize == 1) {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ false,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
active_gate);
} else {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
active_gate);
}
if (value.prevOutValue) {
math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
value.gateValue + frameSize * 2, frameSize * 3);
}
if (batchSize == 1) {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ false,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue,
value.prevOutValue, value.outputValue, frameSize, batchSize,
active_node);
} else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue,
value.prevOutValue, value.outputValue, frameSize, batchSize,
active_node);
}
}
};
template <typename T>
struct GRUUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate) {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
if (batchSize == 1) {
detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
batchSize, active_node);
} else {
detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
batchSize, active_node);
}
if (value.prevOutValue && grad.prevOutGrad) {
math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
frameSize, 0, grad.resetOutputGrad, frameSize);
if (grad.stateWeightGrad) {
math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
frameSize * 3, 1, grad.stateWeightGrad, frameSize);
}
}
if (batchSize == 1) {
detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
batchSize, active_gate);
} else {
detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
batchSize, active_gate);
}
if (grad.prevOutGrad && value.prevOutValue) {
math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
grad.prevOutGrad, frameSize);
if (grad.gateWeightGrad) {
math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
grad.gateWeightGrad, frameSize * 2);
}
}
}
};
template struct GRUUnitFunctor<platform::GPUPlace, float>;
template struct GRUUnitFunctor<platform::GPUPlace, double>;
template struct GRUUnitGradFunctor<platform::GPUPlace, float>;
template struct GRUUnitGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
// TODO(guosheng): refine code style in gru_compute
template <typename T>
struct hl_gru_value {
T *gateWeight;
T *stateWeight;
T *gateValue;
T *resetOutputValue;
T *outputValue;
T *prevOutValue;
};
template <typename T>
struct hl_gru_grad {
T *gateWeightGrad;
T *stateWeightGrad;
T *gateGrad;
T *resetOutputGrad;
T *outputGrad;
T *prevOutGrad;
};
template <typename Place, typename T>
struct GRUUnitFunctor {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate);
};
template <typename Place, typename T>
struct GRUUnitGradFunctor {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate);
};
} // namespace math
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
import math
from op_test import OpTest
from test_lstm_op import identity, sigmoid, tanh, relu
class TestGRUOp(OpTest):
batch_size = 9
frame_size = 5
activate = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
}
@staticmethod
def seq_to_batch(lod, is_reverse):
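# Reorder sequence elements into time-step batches: sort sequences by length
# (descending); batch t gathers the t-th element of every sequence still active
# at step t, counting from the end of each sequence when is_reverse is True.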
idx_in_seq_list = []
seq_starts = lod[0]
seq_lens = []
for i in range(len(seq_starts) - 1):
seq_lens.append(seq_starts[i + 1] - seq_starts[i])
sorted_seqs = sorted(
range(len(seq_lens)), lambda x, y: seq_lens[y] - seq_lens[x])
num_batch = seq_lens[sorted_seqs[0]]
for batch_idx in range(num_batch):
idx_in_seq = []
for i in range(len(seq_lens)):
if seq_lens[sorted_seqs[i]] <= batch_idx:
break
idx = (seq_starts[sorted_seqs[i] + 1] - 1 - batch_idx
) if is_reverse else (
seq_starts[sorted_seqs[i]] + batch_idx)
idx_in_seq.append(idx)
idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list
def gru_step(self, x, h_p, w, b):
batch_size = x.shape[0]
frame_size = w.shape[0]
g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2))
u_r = self.activate[self.attrs['gate_activation']](np.dot(
h_p, w_u_r) + g[:, :frame_size * 2])
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
(frame_size, frame_size))
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p
return g, r_h_p, h
def gru(self):
input, lod = self.inputs['Input']
w = self.inputs['Weight']
b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros(
(1, self.frame_size * 3))
batch_gate = self.outputs['BatchGate']
batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
(len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = self.gru_step(x, h_p, w, b)
if batch_idx < (num_batch - 1):
h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
start_idx = end_idx
end_idx = start_idx + len(idx_in_seq_list[batch_idx])
batch_gate[start_idx:end_idx] = g
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_hidden[start_idx:end_idx] = h
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = [[0, 2, 6, self.batch_size]]
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
h0 = np.random.rand(len(self.idx_in_seq_list[0]),
frame_size).astype('float64')
weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
bias = np.random.rand(1, frame_size * 3).astype('float64')
self.inputs = {
'Input': (input, lod),
'H0': h0,
'Weight': weight,
'Bias': bias
}
self.outputs = {
'BatchGate': np.zeros(
(batch_size, frame_size * 3), dtype='float64'),
'BatchResetHiddenPrev': np.zeros(
(batch_size, frame_size), dtype='float64'),
'BatchHidden': np.zeros(
(batch_size, frame_size), dtype='float64'),
'Hidden': np.zeros(
(batch_size, frame_size), dtype='float64')
}
def set_confs(self):
self.is_reverse = False
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
def setUp(self):
self.op_type = "gru"
self.set_confs()
self.set_data()
self.gru()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoInitial(TestGRUOp):
def set_data(self):
super(TestGRUOpNoInitial, self).set_data()
self.inputs.pop('H0')
def test_check_grad(self):
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpReverse(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.attrs = {
'activation': 'identity',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
if __name__ == "__main__":
unittest.main()