提交 b87eabae 编写于 作者: G guosheng

Add GRU Operator

上级 7d653c41
......@@ -116,7 +116,8 @@ set(DEPS_OPS
sum_op
pool_op
pool_with_index_op
lstm_op)
lstm_op
gru_op)
op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
......@@ -128,6 +129,7 @@ op_library(sum_op DEPS net_op)
op_library(pool_op DEPS pooling)
op_library(pool_with_index_op DEPS pooling)
op_library(lstm_op DEPS sequence2batch lstm_compute)
op_library(gru_op DEPS sequence2batch gru_compute)
list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS})
foreach(src ${GENERAL_OPS})
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/gru_op.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class GRUOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(%s) of GRUOp should not be null.", "BatchGate");
PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"),
"Output(%s) of GRUOp should not be null.",
"BatchResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
"Output(%s) of GRUOp should not be null.", "BatchHidden");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(%s) of GRUOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
int frame_size = weight_dims[0];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
PADDLE_ENFORCE_EQ(
weight_dims[1], frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto h0 = Input("H0");
if (h0 != framework::kEmptyVarName) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
}
auto bias = Input("Bias");
if (bias != framework::kEmptyVarName) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
}
ctx->SetOutputDim("BatchGate", input_dims);
ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size});
ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size});
ctx->SetOutputDim("Hidden", {input_dims[0], frame_size});
// ctx->ShareLoD("Input", "Gate");
// ctx->ShareLoD("Input", "ResetHiddenPrev");
ctx->ShareLoD("Input", "Hidden");
}
};
class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTenosr is a matrix with shape (T X 3D), where, T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.");
AddInput(
"Weight",
"(Tensor) Weight matrix with shape [hidden_size, hidden_size * 3]. "
"The elements continuous in memory can be divided into two parts. "
"The first part are weights of the update gate and reset gate "
"with shape [hidden_size, hidden_size * 2], and the second part are "
"weights of output candidate with shape [hidden_size, hidden_size]");
AddInput("Bias",
"(Tensor) Bias vector with shape [1, hidden_size * 3] concating "
"bias of the update gate, reset gate and output candidate.");
AddOutput("BatchGate",
"(LoDTensor) the update gata, reset gate and output candidate "
"lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`.")
.AsIntermediate();
AddOutput(
"BatchResetHiddenPrev",
"(LoDTensor) the reseted hidden state lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`.")
.AsIntermediate();
AddOutput(
"BatchHidden",
"(LoDTensor) the reseted hidden state lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`.")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of GRU operator. "
"The shape and lod is the same with the `Input`.");
AddAttr<std::string>("activation",
"(string, default tanh) "
"The activation type used for output candidate {h}_t.")
.SetDefault("tanh");
AddAttr<std::string>(
"gate_activation",
"(string, default sigmoid) "
"The activation type used in update gate and reset gate.")
.SetDefault("sigmoid");
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddComment(R"DOC(
GRUOp implements part calculations of the GRU unit as following:
\f[
update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\
output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\
output: h_t = dot((1-u_t), hidden_prev) + dot(u_t, {h}_t)
\f]
The rest of GRU unit can be completed by using FCOp's output as the input of GRUOp.
)DOC");
}
};
class GRUGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(%s) of GRUGradOp should not be null.", "Input");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(%s) of GRUGradOp should not be null.", "Weight");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(%s) of GRUGradOp should not be null.", "BatchGate");
PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"),
"Input(%s) of GRUGradOp should not be null.",
"BatchResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("BatchHidden"),
"Input(%s) of GRUOp should not be null.", "BatchHidden");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUGradOp should not be null.", "Hidden");
auto input_dims = ctx->GetInputDim("Input");
auto weight_dims = ctx->GetInputDim("Weight");
int input_size = input_dims[1];
int frame_size = weight_dims[0];
int weight_height = weight_dims[0];
int weight_width = weight_dims[1];
PADDLE_ENFORCE_EQ(input_size, frame_size * 3,
"The input_size must be 3 times of frame_size in GRUOp.");
PADDLE_ENFORCE_EQ(
weight_height, frame_size,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
PADDLE_ENFORCE_EQ(
weight_width, frame_size * 3,
"The shape of Weight matrix must be [frame_size, frame_size * 3].");
auto h0 = Input("H0");
if (h0 != framework::kEmptyVarName) {
auto h0_dims = ctx->GetInputDim("H0");
PADDLE_ENFORCE_EQ(h0_dims[1], frame_size,
"The width of H0 must be equal to frame_size.");
auto h0_grad_name = framework::GradVarName("H0");
if (ctx->HasOutput(h0_grad_name))
ctx->SetOutputDim(h0_grad_name, h0_dims);
}
auto bias = Input("Bias");
if (bias != framework::kEmptyVarName) {
auto bias_dims = ctx->GetInputDim("Bias");
int bias_height = bias_dims[0];
int bias_width = bias_dims[1];
PADDLE_ENFORCE_EQ(bias_height, 1,
"The shape of Bias must be [1, frame_size * 3].");
PADDLE_ENFORCE_EQ(bias_width, frame_size * 3,
"The shape of Bias must be [1, frame_size * 3].");
auto bias_grad_name = framework::GradVarName("Bias");
if (ctx->HasOutput(bias_grad_name))
ctx->SetOutputDim(bias_grad_name, bias_dims);
}
auto input_grad_name = framework::GradVarName("Input");
if (ctx->HasOutput(input_grad_name))
ctx->SetOutputDim(input_grad_name, input_dims);
auto weight_grad_name = framework::GradVarName("Weight");
if (ctx->HasOutput(weight_grad_name))
ctx->SetOutputDim(weight_grad_name, weight_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(gru, ops::GRUOp, ops::GRUOpMaker, gru_grad, ops::GRUGradOp);
REGISTER_OP_CPU_KERNEL(gru, ops::GRUKernel<paddle::platform::CPUPlace, float>,
ops::GRUKernel<paddle::platform::CPUPlace, double>);
REGISTER_OP_CPU_KERNEL(gru_grad,
ops::GRUGradKernel<paddle::platform::CPUPlace, float>,
ops::GRUGradKernel<paddle::platform::CPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/gru_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(gru, ops::GRUKernel<paddle::platform::GPUPlace, float>,
ops::GRUKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(gru_grad,
ops::GRUGradKernel<paddle::platform::GPUPlace, float>,
ops::GRUGradKernel<paddle::platform::GPUPlace, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* bias = context.Input<Tensor>("Bias");
auto* batch_gate = context.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(context.GetPlace());
auto* batch_reset_hidden_prev =
context.Output<LoDTensor>("BatchResetHiddenPrev");
batch_reset_hidden_prev->mutable_data<T>(context.GetPlace());
auto* batch_hidden = context.Output<LoDTensor>("BatchHidden");
batch_hidden->mutable_data<T>(context.GetPlace());
auto* hidden = context.Output<LoDTensor>("Hidden");
hidden->mutable_data<T>(context.GetPlace());
// context.ShareLoD("Input", "Gate");
// context.ShareLoD("Input", "ResetHiddenPrev");
context.ShareLoD("Input", "Hidden");
// auto gate_dims = gate->dims();
auto hidden_dims = hidden->dims();
// LoDTensor batch_gate, batch_reset_hidden_prev, batch_hidden;
// batch_gate.mutable_data<T>(gate_dims, context.GetPlace());
// batch_reset_hidden_prev.mutable_data<T>(hidden_dims, context.GetPlace());
// batch_hidden.mutable_data<T>(hidden_dims, context.GetPlace());
bool is_reverse = context.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// to_batch(context.device_context(), *input, batch_gate, is_reverse);
to_batch(context.device_context(), *input, *batch_gate, is_reverse);
int frame_size = hidden_dims[1];
int batch_size = hidden_dims[0];
// auto g = EigenMatrix<T>::From(batch_gate);
auto g = EigenMatrix<T>::From(*batch_gate);
auto place = context.GetEigenDevice<Place>();
if (bias) {
auto b = EigenMatrix<T>::From(*bias);
g.device(place) = g +
b.reshape(Eigen::array<int, 2>({{1, frame_size * 3}}))
.broadcast(Eigen::array<int, 2>({{batch_size, 1}}));
}
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
gru_value.prevOutValue = const_cast<T*>(h0_data);
// auto batch_starts = batch_gate.lod()[0];
auto batch_starts = batch_gate->lod()[0];
// for (auto i = batch_gate->lod()[1].begin(); i !=
// batch_gate->lod()[1].end(); ++i)
// std::cout << static_cast<int>(*i) << ' ';
size_t num_batch = batch_starts.size() - 1;
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
// Tensor gate_t = batch_gate.Slice(bstart, bend);
// Tensor reset_hidden_prev_t = batch_reset_hidden_prev.Slice(bstart,
// bend);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
Tensor hidden_t = batch_hidden->Slice(bstart, bend);
gru_value.outputValue = hidden_t.data<T>();
gru_value.gateValue = gate_t.data<T>();
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<Place, T>::compute(
context.device_context(), gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
gru_value.prevOutValue = gru_value.outputValue;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
// batch_gate.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_gate, *gate);
// batch_reset_hidden_prev.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_reset_hidden_prev,
// *reset_hidden_prev);
// batch_hidden.set_lod(batch_gate.lod());
// to_seq(context.device_context(), batch_hidden, *hidden);
batch_hidden->set_lod(batch_gate->lod());
to_seq(context.device_context(), *batch_hidden, *hidden);
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
template <typename Place, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
auto* h0 = context.Input<Tensor>("H0");
const T* h0_data = h0 ? h0->data<T>() : nullptr;
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
auto* batch_gate = context.Input<LoDTensor>("BatchGate");
auto* batch_reset_hidden_prev =
context.Input<LoDTensor>("BatchResetHiddenPrev");
auto* batch_hidden = context.Input<LoDTensor>("BatchHidden");
auto* hidden = context.Input<LoDTensor>("Hidden");
auto* hidden_grad =
context.Input<LoDTensor>(framework::GradVarName("Hidden"));
auto* input_grad =
context.Output<LoDTensor>(framework::GradVarName("Input"));
auto* h0_grad = context.Output<Tensor>(framework::GradVarName("H0"));
auto* weight_grad =
context.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
auto gate_dims = batch_gate->dims();
auto hidden_dims = hidden->dims();
int frame_size = hidden_dims[1];
math::LoDTensor2BatchFunctor<Place, T> to_batch;
LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad;
batch_hidden_grad.mutable_data<T>(hidden_dims, context.GetPlace());
batch_gate_grad.mutable_data<T>(gate_dims, context.GetPlace());
batch_reset_hidden_prev_grad.mutable_data<T>(hidden_dims,
context.GetPlace());
math::SetConstant<Place, T> zero;
zero(context.device_context(), &batch_hidden_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_gate_grad, static_cast<T>(0.0));
zero(context.device_context(), &batch_reset_hidden_prev_grad,
static_cast<T>(0.0));
// batch_hidden.set_lod(batch_gate->lod());
bool is_reverse = context.Attr<bool>("is_reverse");
batch_hidden_grad.set_lod(batch_hidden->lod());
// context.ShareLoD(framework::GradVarName("Hidden"),
// framework::GradVarName("Input"));
to_batch(context.device_context(), *hidden_grad, batch_hidden_grad,
is_reverse, false);
math::hl_gru_value<T> gru_value;
gru_value.gateWeight = const_cast<T*>(weight_data);
gru_value.stateWeight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::hl_gru_grad<T> gru_grad;
if (weight_grad) {
gru_grad.gateWeightGrad =
weight_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), weight_grad, static_cast<T>(0.0));
gru_grad.stateWeightGrad =
weight_grad->data<T>() + 2 * frame_size * frame_size;
} else {
gru_grad.gateWeightGrad = nullptr;
gru_grad.stateWeightGrad = nullptr;
}
auto batch_starts = batch_hidden_grad.lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
int cur_batch_size = bend - bstart;
Tensor gate_t = batch_gate->Slice(bstart, bend);
gru_value.gateValue = gate_t.data<T>();
Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend);
gru_value.resetOutputValue = reset_hidden_prev_t.data<T>();
Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend);
gru_grad.outputGrad = hidden_grad_t.data<T>();
Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend);
gru_grad.gateGrad = gate_grad_t.data<T>();
Tensor reset_hidden_prev_grad_t =
batch_reset_hidden_prev_grad.Slice(bstart, bend);
gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data<T>();
if (n == 0) {
gru_value.prevOutValue = const_cast<T*>(h0_data);
if (h0_grad) {
T* h0_grad_data = h0_grad->mutable_data<T>(context.GetPlace());
zero(context.device_context(), h0_grad, static_cast<T>(0.0));
gru_grad.prevOutGrad = h0_grad_data;
} else {
gru_grad.prevOutGrad = nullptr;
}
} else {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart);
gru_value.prevOutValue = hidden_prev_t.data<T>();
Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart);
gru_grad.prevOutGrad = hidden_prev_grad_t.data<T>();
}
math::GRUUnitGradFunctor<Place, T>::compute(
context.device_context(), gru_value, gru_grad, frame_size,
cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_gate_grad.set_lod(batch_gate->lod());
to_seq(context.device_context(), batch_gate_grad, *input_grad);
}
if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace());
auto d_b = EigenMatrix<T>::From(*bias_grad);
auto d_g = EigenMatrix<T>::From(batch_gate_grad);
auto place = context.GetEigenDevice<Place>();
d_b.device(place) = d_g.sum(Eigen::array<int, 1>({{0}}));
}
}
void Compute(const framework::ExecutionContext& context) const override {
BatchCompute(context);
}
};
} // namespace operators
} // namespace paddle
......@@ -11,6 +11,7 @@ if(WITH_GPU)
nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context)
nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context)
nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions)
nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions)
else()
cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator)
cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function)
......@@ -20,6 +21,7 @@ else()
cc_library(vol2col SRCS vol2col.cc DEPS device_context)
cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context)
cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions)
cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions)
endif()
cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
namespace paddle {
namespace operators {
namespace math {
namespace detail {
#ifndef __NVCC__
template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
T *gateValue, T *resetOutputValue,
T *prevOutputValue, int frameSize,
activation_mode_t active_gate) {
T rValueUpdateGate;
T rValueResetGate;
T rValueResetOutput;
T rPrevOut = 0;
T *updateGate = gateValue;
T *resetGate = gateValue + frameSize;
for (int i = 0; i < frameSize; i++) {
rValueUpdateGate = updateGate[i];
rValueResetGate = resetGate[i];
if (prevOutputValue) {
rPrevOut = prevOutputValue[i];
}
hppl::cpu::ForwardAct<T> act;
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, act(active_gate));
updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate;
resetOutputValue[i] = rValueResetOutput;
}
}
template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
T *gateValue, T *prevOutputValue,
T *outputValue, int frameSize,
activation_mode_t active_node) {
T rValueUpdateGate;
T rValueFrameState;
T rPrevOut = 0;
T rOutput;
T *updateGate = gateValue;
T *frameState = gateValue + frameSize * 2;
for (int i = 0; i < frameSize; i++) {
rValueUpdateGate = updateGate[i];
rValueFrameState = frameState[i];
if (prevOutputValue) {
rPrevOut = prevOutputValue[i];
}
hppl::cpu::ForwardAct<T> act;
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
act(active_node));
frameState[i] = rValueFrameState;
outputValue[i] = rOutput;
}
}
template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
T *resetOutputValue, T *prevOutputValue,
int frameSize,
activation_mode_t active_gate) {
#ifdef __AVX__
__m256 rValueUpdateGate;
__m256 rValueResetGate;
__m256 rValueResetOutput;
__m256 rPrevOut = _mm256_set1_ps(0.0f);
__m256 *updateGate = (__m256 *)gateValue;
__m256 *resetGate = (__m256 *)(gateValue + frameSize);
for (int i = 0; i < frameSize / 8; i++) {
rValueUpdateGate = updateGate[i];
rValueResetGate = resetGate[i];
if (prevOutputValue) {
rPrevOut = ((__m256 *)prevOutputValue)[i];
}
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, hppl::avx::forward[active_gate]);
updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate;
((__m256 *)resetOutputValue)[i] = rValueResetOutput;
}
#endif
}
template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
T *prevOutputValue, T *outputValue,
int frameSize,
activation_mode_t active_node) {
#ifdef __AVX__
__m256 rValueUpdateGate;
__m256 rValueFrameState;
__m256 rPrevOut = _mm256_set1_ps(0.0f);
__m256 rOutput;
__m256 *updateGate = (__m256 *)gateValue;
__m256 *frameState = (__m256 *)(gateValue + frameSize * 2);
for (int i = 0; i < frameSize / 8; i++) {
rValueUpdateGate = updateGate[i];
rValueFrameState = frameState[i];
if (prevOutputValue) {
rPrevOut = ((__m256 *)prevOutputValue)[i];
}
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
hppl::avx::forward[active_node]);
frameState[i] = rValueFrameState;
((__m256 *)outputValue)[i] = rOutput;
}
#endif
}
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput opResetOutput,
hl_gru_value<T> value, int frameSize,
int batchSize, activation_mode_t active_gate) {
for (int b = 0; b < batchSize; b++) {
if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_reset_output(
opResetOutput, value.gateValue, value.resetOutputValue,
value.prevOutValue, frameSize, active_gate);
} else {
hl_naive_gru_forward_reset_output(
opResetOutput, value.gateValue, value.resetOutputValue,
value.prevOutValue, frameSize, active_gate);
}
value.gateValue += frameSize * 3;
value.resetOutputValue += frameSize;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
}
}
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput opFinalOutput,
hl_gru_value<T> value, int frameSize,
int batchSize, activation_mode_t active_node) {
for (int b = 0; b < batchSize; b++) {
if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue,
value.prevOutValue, value.outputValue,
frameSize, active_node);
} else {
hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue,
value.prevOutValue, value.outputValue,
frameSize, active_node);
}
value.gateValue += frameSize * 3;
value.outputValue += frameSize;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
}
}
template <class OpStateGrad, typename T>
void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *outputGrad,
int frameSize,
activation_mode_t active_node) {
T rUpdateGateValue;
T rUpdateGateGrad;
T rFrameStateValue;
T rFrameStateGrad;
T rOutGrad;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T *updateGateValue = gateValue;
T *updateGateGrad = gateGrad;
T *frameStateValue = gateValue + frameSize * 2;
T *frameStateGrad = gateGrad + frameSize * 2;
for (int i = 0; i < frameSize; i++) {
rUpdateGateValue = updateGateValue[i];
rFrameStateValue = frameStateValue[i];
rOutGrad = outputGrad[i];
if (prevOutValue) {
rPrevOutValue = prevOutValue[i];
}
if (prevOutGrad) {
rPrevOutGrad = prevOutGrad[i];
}
hppl::cpu::BackwardAct<T> act;
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
act(active_node));
updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad;
if (prevOutGrad) {
prevOutGrad[i] = rPrevOutGrad;
}
}
}
template <class OpResetGrad, typename T>
void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *resetOutputGrad,
int frameSize,
activation_mode_t active_gate) {
T rUpdateGateValue;
T rUpdateGateGrad;
T rResetGateValue;
T rResetGateGrad;
T rResetOutputGrad = 0;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T *updateGateValue = gateValue;
T *updateGateGrad = gateGrad;
T *resetGateValue = gateValue + frameSize;
T *resetGateGrad = gateGrad + frameSize;
for (int i = 0; i < frameSize; i++) {
rUpdateGateValue = updateGateValue[i];
rUpdateGateGrad = updateGateGrad[i];
rResetGateValue = resetGateValue[i];
if (prevOutValue && prevOutGrad) {
rResetOutputGrad = resetOutputGrad[i];
}
if (prevOutValue) {
rPrevOutValue = prevOutValue[i];
}
if (prevOutGrad) {
rPrevOutGrad = prevOutGrad[i];
}
hppl::cpu::BackwardAct<T> act;
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
act(active_gate));
updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad;
if (prevOutGrad) {
prevOutGrad[i] = rPrevOutGrad;
}
}
}
template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *outputGrad,
int frameSize,
activation_mode_t active_node) {
#ifdef __AVX__
__m256 rUpdateGateValue;
__m256 rUpdateGateGrad;
__m256 rFrameStateValue;
__m256 rFrameStateGrad;
__m256 rOutGrad;
__m256 rPrevOutValue = _mm256_set1_ps(0.0f);
__m256 rPrevOutGrad = _mm256_set1_ps(0.0f);
__m256 *updateGateValue = (__m256 *)gateValue;
__m256 *updateGateGrad = (__m256 *)gateGrad;
__m256 *frameStateValue = (__m256 *)(gateValue + frameSize * 2);
__m256 *frameStateGrad = (__m256 *)(gateGrad + frameSize * 2);
for (int i = 0; i < frameSize / 8; i++) {
rUpdateGateValue = updateGateValue[i];
rFrameStateValue = frameStateValue[i];
rOutGrad = ((__m256 *)outputGrad)[i];
if (prevOutValue) {
rPrevOutValue = ((__m256 *)prevOutValue)[i];
}
if (prevOutGrad) {
rPrevOutGrad = ((__m256 *)prevOutGrad)[i];
}
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
hppl::avx::backward[active_node]);
updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad;
if (prevOutGrad) {
((__m256 *)prevOutGrad)[i] = rPrevOutGrad;
}
}
#endif
}
template <class OpResetGrad, typename T>
void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *resetOutputGrad,
int frameSize,
activation_mode_t active_gate) {
#ifdef __AVX__
__m256 rUpdateGateValue;
__m256 rUpdateGateGrad;
__m256 rResetGateValue;
__m256 rResetGateGrad;
__m256 rResetOutputGrad = _mm256_set1_ps(0.0f);
__m256 rPrevOutValue = _mm256_set1_ps(0.0f);
__m256 rPrevOutGrad = _mm256_set1_ps(0.0f);
__m256 *updateGateValue = (__m256 *)gateValue;
__m256 *updateGateGrad = (__m256 *)gateGrad;
__m256 *resetGateValue = (__m256 *)(gateValue + frameSize);
__m256 *resetGateGrad = (__m256 *)(gateGrad + frameSize);
for (int i = 0; i < frameSize / 8; i++) {
rUpdateGateValue = updateGateValue[i];
rUpdateGateGrad = updateGateGrad[i];
rResetGateValue = resetGateValue[i];
if (prevOutValue && prevOutGrad) {
rResetOutputGrad = ((__m256 *)resetOutputGrad)[i];
}
if (prevOutValue) {
rPrevOutValue = ((__m256 *)prevOutValue)[i];
}
if (prevOutGrad) {
rPrevOutGrad = ((__m256 *)prevOutGrad)[i];
}
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
hppl::avx::backward[active_gate]);
updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad;
if (prevOutGrad) {
((__m256 *)prevOutGrad)[i] = rPrevOutGrad;
}
}
#endif
}
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value<T> value,
hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node) {
for (int b = 0; b < batchSize; b++) {
if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad(
opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.outputGrad, frameSize, active_node);
} else {
hl_naive_gru_backward_state_grad(
opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.outputGrad, frameSize, active_node);
}
value.gateValue += frameSize * 3;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
grad.gateGrad += frameSize * 3;
grad.outputGrad += frameSize;
if (grad.prevOutGrad) {
grad.prevOutGrad += frameSize;
}
}
}
template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value<T> value,
hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_gate) {
for (int b = 0; b < batchSize; b++) {
if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_reset_grad(
opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate);
} else {
hl_naive_gru_backward_reset_grad(
opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue,
grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate);
}
value.gateValue += frameSize * 3;
if (value.prevOutValue) {
value.prevOutValue += frameSize;
}
grad.gateGrad += frameSize * 3;
grad.resetOutputGrad += frameSize;
if (grad.prevOutGrad) {
grad.prevOutGrad += frameSize;
}
}
}
#endif
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpResetOutput, bool isBatch, typename T>
__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
T *gateValue, T *resetOutputValue,
T *prevOutputValue, int frameSize,
int batchSize,
activation_mode_t active_gate) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
resetOutputValue += batchIdx * frameSize;
}
T rPrevOut = 0;
T rValueResetOutput;
T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
T rValueResetGate = gateValue[frameIdx + frameSize * 1];
if (prevOutputValue) {
if (isBatch) prevOutputValue += batchIdx * frameSize;
rPrevOut = prevOutputValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput,
act(active_gate));
gateValue[frameIdx + frameSize * 0] = rValueUpdateGate;
gateValue[frameIdx + frameSize * 1] = rValueResetGate;
resetOutputValue[frameIdx] = rValueResetOutput;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpFinalOutput, bool isBatch, typename T>
__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
T *gateValue, T *prevOutputValue,
T *outputValue, int frameSize,
int batchSize,
activation_mode_t active_node) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
outputValue += batchIdx * frameSize;
}
T rOutput;
T rPrevOut = 0;
T rValueUpdateGate = gateValue[frameIdx + frameSize * 0];
T rValueFrameState = gateValue[frameIdx + frameSize * 2];
if (prevOutputValue) {
if (isBatch) prevOutputValue += batchIdx * frameSize;
rPrevOut = prevOutputValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
act(active_node));
gateValue[frameIdx + frameSize * 2] = rValueFrameState;
outputValue[frameIdx] = rOutput;
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpStateGrad, bool isBatch, typename T>
__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *outputGrad,
int frameSize, int batchSize,
activation_mode_t active_node) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
gateGrad += batchIdx * 3 * frameSize;
outputGrad += batchIdx * frameSize;
}
T rUpdateGateGrad;
T rFrameStateGrad;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
T rFrameStateValue = gateValue[frameIdx + frameSize * 2];
T rOutGrad = outputGrad[frameIdx];
if (prevOutValue && prevOutGrad) {
if (isBatch) prevOutValue += batchIdx * frameSize;
rPrevOutValue = prevOutValue[frameIdx];
if (isBatch) prevOutGrad += batchIdx * frameSize;
rPrevOutGrad = prevOutGrad[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
act(active_node));
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad;
if (prevOutGrad) {
prevOutGrad[frameIdx] = rPrevOutGrad;
}
}
/*
* threads(framePerBlock, batchPerBlock)
* grid(frameBlocks, batchBlocks)
*/
template <class OpResetGrad, bool isBatch, typename T>
__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
T *gateGrad, T *prevOutValue,
T *prevOutGrad, T *resetOutputGrad,
int frameSize, int batchSize,
activation_mode_t active_gate) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
int batchIdx = 0;
if (isBatch) {
batchIdx = blockIdx.y * blockDim.y + threadIdx.y;
if (batchIdx >= batchSize) return;
gateValue += batchIdx * 3 * frameSize;
gateGrad += batchIdx * 3 * frameSize;
resetOutputGrad += batchIdx * frameSize;
}
T rResetGateGrad;
T rPrevOutValue = 0;
T rPrevOutGrad = 0;
T rResetOutputGrad = 0;
T rUpdateGateValue = gateValue[frameIdx + frameSize * 0];
T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0];
T rResetGateValue = gateValue[frameIdx + frameSize * 1];
if (prevOutValue && prevOutGrad) {
if (isBatch) prevOutValue += batchIdx * frameSize;
if (isBatch) prevOutGrad += batchIdx * frameSize;
rPrevOutValue = prevOutValue[frameIdx];
rPrevOutGrad = prevOutGrad[frameIdx];
rResetOutputGrad = resetOutputGrad[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
act(active_gate));
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 1] = rResetGateGrad;
if (prevOutGrad) {
prevOutGrad[frameIdx] = rPrevOutGrad;
}
}
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
namespace paddle {
namespace operators {
namespace math {
namespace detail {
namespace forward {
template <typename T>
class gru_resetOutput {
public:
/**
* @param[in,out] valueUpdateGate update gate
* @param[in,out] valueResetGate reset gate
* @param[in] prevOut previous output
* @param[out] valueResetOutput intermediate value for frame state
* @param[in] actGate forward function of gate
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut,
T &valueResetOutput,
typename hppl::Active<T>::forward actGate) {
valueUpdateGate = actGate(valueUpdateGate);
valueResetGate = actGate(valueResetGate);
valueResetOutput = prevOut * valueResetGate;
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate,
__m256 &prevOut, __m256 &valueResetOutput,
typename hppl::Active<__m256>::forward actGate) {
valueUpdateGate = actGate(valueUpdateGate);
valueResetGate = actGate(valueResetGate);
valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate);
}
#endif
#endif
};
template <typename T>
class gru_finalOutput {
public:
/**
* @param[in] valueUpdateGate update gate
* @param[in,out] valueFrameState frame state ({\tilde{h}_t})
* @param[in] prevOut previous output
* @param[out] valueOutput output
* @param[in] actInput forward function of node
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut,
T &valueOutput,
typename hppl::Active<T>::forward actInput) {
valueFrameState = actInput(valueFrameState);
valueOutput = prevOut - (valueUpdateGate * prevOut) +
(valueUpdateGate * valueFrameState);
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState,
__m256 &prevOut, __m256 &valueOutput,
typename hppl::Active<__m256>::forward actInput) {
valueFrameState = actInput(valueFrameState);
valueOutput = _mm256_add_ps(
_mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)),
_mm256_mul_ps(valueUpdateGate, valueFrameState));
}
#endif
#endif
};
} // namespace forward
namespace backward {
template <typename T>
class gru_stateGrad {
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[out] gradUpdateGate update gate grad
* @param[in] valueFrameState frame state value
* @param[out] gradFrameState frame state grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradOutput output grad
* @param[in] actInput backward function of frame state
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueFrameState, T &gradFrameState,
T &valuePrevOut, T &gradPrevOut, T &gradOutput,
typename hppl::Active<T>::backward actInput) {
gradUpdateGate = (gradOutput * valueFrameState);
gradUpdateGate -= (gradOutput * valuePrevOut);
gradPrevOut -= (gradOutput * valueUpdateGate);
gradPrevOut += gradOutput;
gradFrameState = actInput(gradOutput * valueUpdateGate, valueFrameState);
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
__m256 &valueFrameState, __m256 &gradFrameState,
__m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradOutput,
typename hppl::Active<__m256>::backward actInput) {
gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState);
gradUpdateGate =
_mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut));
gradPrevOut = _mm256_add_ps(
_mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)),
gradOutput);
gradFrameState =
actInput(_mm256_mul_ps(gradOutput, valueUpdateGate), valueFrameState);
}
#endif
#endif
};
template <typename T>
class gru_resetGrad {
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[in,out] gradUpdateGate update gate grad
* @param[in] valueResetGate reset gate value
* @param[out] gradResetGate reset gate grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradResetOutput reset output grad (temp val)
* @param[in] actGate backward function of gate
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueResetGate, T &gradResetGate,
T &valuePrevOut, T &gradPrevOut,
T &gradResetOutput,
typename hppl::Active<T>::backward actGate) {
gradResetGate = (gradResetOutput * valuePrevOut);
gradPrevOut += (gradResetOutput * valueResetGate);
gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate);
gradResetGate = actGate(gradResetGate, valueResetGate);
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
#else
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
__m256 &valueResetGate, __m256 &gradResetGate,
__m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradResetOutput,
typename hppl::Active<__m256>::backward actGate) {
gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut);
gradPrevOut = _mm256_add_ps(gradPrevOut,
_mm256_mul_ps(gradResetOutput, valueResetGate));
gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate);
gradResetGate = actGate(gradResetGate, valueResetGate);
}
#endif
#endif
};
} // namespace backward
} // namespace detail
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/detail/gru_cpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct GRUUnitFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate) {
#ifndef __NVCC__
if (value.prevOutValue) {
math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
value.gateValue, frameSize * 3);
}
detail::forward_reset_output(detail::forward::gru_resetOutput<T>(), value,
frameSize, batchSize, active_gate);
if (value.prevOutValue) {
math::gemm<platform::CPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
value.gateValue + frameSize * 2, frameSize * 3);
}
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frameSize, batchSize, active_node);
#endif
}
};
template <typename T>
struct GRUUnitGradFunctor<platform::CPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate) {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frameSize, batchSize, active_node);
if (value.prevOutValue && grad.prevOutGrad) {
math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
frameSize, 0, grad.resetOutputGrad, frameSize);
if (grad.stateWeightGrad) {
math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
frameSize * 3, 1, grad.stateWeightGrad, frameSize);
}
}
detail::backward_reset_grad(detail::backward::gru_resetGrad<T>(), value,
grad, frameSize, batchSize, active_gate);
if (grad.prevOutGrad && value.prevOutValue) {
math::gemm<platform::CPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
grad.prevOutGrad, frameSize);
if (grad.gateWeightGrad) {
math::gemm<platform::CPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
grad.gateWeightGrad, frameSize * 2);
}
}
#endif
}
};
template struct GRUUnitFunctor<platform::CPUPlace, float>;
template struct GRUUnitFunctor<platform::CPUPlace, double>;
template struct GRUUnitGradFunctor<platform::CPUPlace, float>;
template struct GRUUnitGradFunctor<platform::CPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/gru_gpu_kernel.h"
#include "paddle/operators/math/detail/gru_kernel.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <typename T>
struct GRUUnitFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate) {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
if (value.prevOutValue) {
math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize * 2, frameSize, 1,
value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1,
value.gateValue, frameSize * 3);
}
if (batchSize == 1) {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ false,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
active_gate);
} else {
detail::KeGruForwardResetOutput<detail::forward::gru_resetOutput<T>,
/* isBatch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_resetOutput<T>(), value.gateValue,
value.resetOutputValue, value.prevOutValue, frameSize, batchSize,
active_gate);
}
if (value.prevOutValue) {
math::gemm<platform::GPUPlace, T>(
context, false, false, batchSize, frameSize, frameSize, 1,
value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1,
value.gateValue + frameSize * 2, frameSize * 3);
}
if (batchSize == 1) {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ false,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue,
value.prevOutValue, value.outputValue, frameSize, batchSize,
active_node);
} else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* isBatch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gateValue,
value.prevOutValue, value.outputValue, frameSize, batchSize,
active_node);
}
}
};
template <typename T>
struct GRUUnitGradFunctor<platform::GPUPlace, T> {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate) {
auto stream =
reinterpret_cast<const platform::CUDADeviceContext &>(context).stream();
dim3 threads;
dim3 grid;
if (batchSize == 1) {
int framePerBlock = frameSize <= 1024 ? frameSize : 1024;
int frameBlocks = (frameSize + 1024 - 1) / 1024;
threads = dim3(framePerBlock, 1);
grid = dim3(frameBlocks, 1);
} else {
threads = dim3(32, 32);
grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32);
}
if (batchSize == 1) {
detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
batchSize, active_node);
} else {
detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize,
batchSize, active_node);
}
if (value.prevOutValue && grad.prevOutGrad) {
math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize, 1,
grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight,
frameSize, 0, grad.resetOutputGrad, frameSize);
if (grad.stateWeightGrad) {
math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize, batchSize, 1,
value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2,
frameSize * 3, 1, grad.stateWeightGrad, frameSize);
}
}
if (batchSize == 1) {
detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
batchSize, active_gate);
} else {
detail::KeGruBackwardResetGrad<
detail::backward::gru_resetGrad<T>,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_resetGrad<T>(), value.gateValue, grad.gateGrad,
value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize,
batchSize, active_gate);
}
if (grad.prevOutGrad && value.prevOutValue) {
math::gemm<platform::GPUPlace, T>(
context, false, true, batchSize, frameSize, frameSize * 2, 1,
grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1,
grad.prevOutGrad, frameSize);
if (grad.gateWeightGrad) {
math::gemm<platform::GPUPlace, T>(
context, true, false, frameSize, frameSize * 2, batchSize, 1,
value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1,
grad.gateWeightGrad, frameSize * 2);
}
}
}
};
template struct GRUUnitFunctor<platform::GPUPlace, float>;
template struct GRUUnitFunctor<platform::GPUPlace, double>;
template struct GRUUnitGradFunctor<platform::GPUPlace, float>;
template struct GRUUnitGradFunctor<platform::GPUPlace, double>;
} // namespace math
} // namespace operators
} // namespace paddle
\ No newline at end of file
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace operators {
namespace math {
// typedef enum {
// HL_ACTIVATION_SIGMOID = 0,
// HL_ACTIVATION_RELU = 1,
// HL_ACTIVATION_TANH = 2,
// HL_ACTIVATION_LINEAR = 3,
// HL_ACTIVATION_END
// } activation_mode_t;
// inline activation_mode_t ActiveType(const std::string &type) {
// if (type == "sigmoid") {
// return HL_ACTIVATION_SIGMOID;
// } else if (type == "relu") {
// return HL_ACTIVATION_RELU;
// } else if (type == "tanh") {
// return HL_ACTIVATION_TANH;
// } else if (type == "linear" || type == "") {
// return HL_ACTIVATION_LINEAR;
// } else {
// PADDLE_THROW("Do not support activation type.");
// }
// }
template <typename T>
struct hl_gru_value {
T *gateWeight;
T *stateWeight;
T *gateValue;
T *resetOutputValue;
T *outputValue;
T *prevOutValue;
};
template <typename T>
struct hl_gru_grad {
T *gateWeightGrad;
T *stateWeightGrad;
T *gateGrad;
T *resetOutputGrad;
T *outputGrad;
T *prevOutGrad;
};
template <typename Place, typename T>
struct GRUUnitFunctor {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, int frameSize, int batchSize,
activation_mode_t active_node,
activation_mode_t active_gate);
};
template <typename Place, typename T>
struct GRUUnitGradFunctor {
static void compute(const platform::DeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate);
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -21,6 +21,128 @@ namespace paddle {
namespace operators {
namespace math {
// template <typename Place, typename T>
// class CopyMatrixRowsFunctor {
// public:
// // If is_src_index is true,
// // copy the indexed rows of input src to the output dst.
// // If is_src_index is false,
// // copy the input src to the indexed rows of output dst.
// // The indexed rows are based on the input index.
// void operator()(const platform::DeviceContext& context,
// const framework::LoDTensor& src, const size_t* index,
// framework::LoDTensor& dst, bool is_src_index);
// };
// template <typename Place, typename T>
// class LoDTensor2BatchFunctor {
// // Calculate the length of each sequence and
// // sort sequence index by the length.
// // example: sequences = {s0, s1, s2}
// // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)}
// //
// struct SeqInfo {
// SeqInfo(int start, int length, int seq_idx)
// : start(start), length(length), seq_idx(seq_idx) {}
// int start;
// int length;
// int seq_idx;
// };
// public:
// void operator()(const platform::DeviceContext& context,
// const framework::LoDTensor& lod_tensor,
// framework::LoDTensor& batch, bool is_reverse) const {
// auto lods = lod_tensor.lod();
// PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence
// now.");
// auto lod = lods[0];
// std::vector<SeqInfo> seq_info;
// for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
// int length = lod[seq_id + 1] - lod[seq_id];
// seq_info.emplace_back(lod[seq_id], length, seq_id);
// }
// std::sort(seq_info.begin(), seq_info.end(),
// [](SeqInfo a, SeqInfo b) { return a.length > b.length; });
// // calculate the start position of each batch
// // (numBatch equal the maxLength of sequences)
// // example: sequences = {s0, s1, s2}
// // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// // num_batch = 5,
// // batchIndex = {b0, b1, b2, b3, b4}
// // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1
// // batch_start_positions[6] = {0, 3, 6, 9, 11, 12}
// // batch_start_positions[0] = len(b0)
// // batch_start_positions[1] = len(b0) + len(b1)
// // batch_start_positions[2] = len(b0) + len(b1) + len(b2)
// // ...
// // seq2batch_idx[12] = {4, 0, 9,
// // 5, 1, 10,
// // 6, 2, 11,
// // 7, 3,
// // 8}
// // The batch number represents batch size after rearranging the
// // input LodTensor. It is also the maximum length of input sequence.
// paddle::framework::LoD batch_lods;
// batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods.emplace_back(std::vector<size_t>{0});
// // batch_lods[0] is the start positions for batch LoDTensor
// int num_batch = seq_info[0].length;
// batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// // batch_lods[1] is the raw index in the input LoDTensor
// auto dims = lod_tensor.dims();
// batch_lods[1].resize(static_cast<size_t>(dims[0]));
// size_t* batch_starts = batch_lods[0].data();
// size_t* seq2batch_idx = batch_lods[1].data();
// batch_starts[0] = 0;
// for (size_t n = 0; n < num_batch; n++) {
// auto batch_id = static_cast<int>(batch_starts[n]);
// for (size_t i = 0; i < seq_info.size(); ++i) {
// size_t seq_len = seq_info[i].length;
// int start = seq_info[i].start;
// if (n < seq_len) {
// seq2batch_idx[batch_id] =
// is_reverse ? start + seq_len - 1 - n : start + n;
// batch_id++;
// } else {
// break;
// }
// }
// batch_starts[n + 1] = static_cast<size_t>(batch_id);
// }
// batch.set_lod(batch_lods);
// CopyMatrixRowsFunctor<Place, T> to_batch;
// to_batch(context, lod_tensor, seq2batch_idx, batch, true);
// }
// };
// template <typename Place, typename T>
// class Batch2LoDTensorFunctor {
// public:
// void operator()(const platform::DeviceContext& context,
// const framework::LoDTensor& batch,
// framework::LoDTensor& lod_tensor) const {
// auto in_lod = batch.lod();
// PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
// "The LoD size of input `batch` should be 2.");
// auto out_lod = lod_tensor.lod()[0];
// auto num = out_lod[out_lod.size() - 1];
// PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
// PADDLE_ENFORCE_EQ(num, in_lod[1].size());
// PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
// CopyMatrixRowsFunctor<Place, T> to_seq;
// size_t* index = in_lod[1].data();
// to_seq(context, batch, index, lod_tensor, false);
// }
// };
template <typename Place, typename T>
class CopyMatrixRowsFunctor {
public:
......@@ -53,7 +175,18 @@ class LoDTensor2BatchFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const {
framework::LoDTensor& batch, bool is_reverse = false,
bool is_cal_batch_lod = true) const {
if (!is_cal_batch_lod) {
auto lods = batch.lod();
PADDLE_ENFORCE_EQ(lods.size(), 2UL);
PADDLE_ENFORCE_EQ(lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, lods[1].data(), batch, true);
return;
}
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
......@@ -101,10 +234,10 @@ class LoDTensor2BatchFunctor {
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < num_batch; n++) {
for (int n = 0; n < num_batch; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length;
int seq_len = seq_info[i].length;
int start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
......@@ -132,11 +265,8 @@ class Batch2LoDTensorFunctor {
auto in_lod = batch.lod();
PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
"The LoD size of input `batch` should be 2.");
auto out_lod = lod_tensor.lod()[0];
auto num = out_lod[out_lod.size() - 1];
PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
PADDLE_ENFORCE_EQ(num, in_lod[1].size());
PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
PADDLE_ENFORCE_EQ(in_lod[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_seq;
size_t* index = in_lod[1].data();
to_seq(context, batch, index, lod_tensor, false);
......
import unittest
import numpy as np
import math
from op_test import OpTest
SIGMOID_THRESHOLD_MIN = -40.0
SIGMOID_THRESHOLD_MAX = 13.0
EXP_MAX_INPUT = 40.0
def identity(x):
return x
def sigmoid(x):
y = np.copy(x)
y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN
y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX
return 1. / (1. + np.exp(-y))
def tanh(x):
y = -2. * x
y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT
return (2. / (1. + np.exp(y))) - 1.
def relu(x):
return np.maximum(x, 0)
class TestGRUOp(OpTest):
batch_size = 9
frame_size = 5
activate = {
'identity': identity,
'sigmoid': sigmoid,
'tanh': tanh,
'relu': relu
}
@staticmethod
def seq_to_batch(lod, is_reverse):
idx_in_seq_list = []
seq_starts = lod[0]
seq_lens = []
for i in range(len(seq_starts) - 1):
seq_lens.append(seq_starts[i + 1] - seq_starts[i])
sorted_seqs = sorted(
range(len(seq_lens)), lambda x, y: seq_lens[y] - seq_lens[x])
num_batch = seq_lens[sorted_seqs[0]]
for batch_idx in range(num_batch):
idx_in_seq = []
for i in range(len(seq_lens)):
if seq_lens[sorted_seqs[i]] <= batch_idx:
break
idx = (seq_starts[sorted_seqs[i] + 1] - 1 - batch_idx
) if is_reverse else (
seq_starts[sorted_seqs[i]] + batch_idx)
idx_in_seq.append(idx)
idx_in_seq_list.append(idx_in_seq)
return idx_in_seq_list
def gru_step(self, x, h_p, w, b):
print x.shape, h_p.shape, w.shape, b.shape
batch_size = x.shape[0]
frame_size = w.shape[0]
g = x + np.tile(b, (batch_size, 1))
w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape(
(frame_size, frame_size * 2))
u_r = self.activate[self.attrs['gate_activation']](np.dot(
h_p, w_u_r) + g[:, :frame_size * 2])
u = u_r[:, :frame_size]
r = u_r[:, frame_size:frame_size * 2]
r_h_p = r * h_p
w_c = w.flatten()[frame_size * frame_size * 2:].reshape(
(frame_size, frame_size))
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p
return g, r_h_p, h
def gru(self):
input, lod = self.inputs['Input']
w = self.inputs['Weight']
b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros(
(1, self.frame_size * 3))
batch_gate = self.outputs['BatchGate']
batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev']
batch_hidden = self.outputs['BatchHidden']
hidden = self.outputs['Hidden']
idx_in_seq_list = self.idx_in_seq_list
h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros(
(len(idx_in_seq_list[0]), self.frame_size))
num_batch = len(idx_in_seq_list)
end_idx = 0
for batch_idx in range(num_batch):
print idx_in_seq_list[batch_idx]
x = input[idx_in_seq_list[batch_idx]]
g, r_h_p, h = self.gru_step(x, h_p, w, b)
if batch_idx < (num_batch - 1):
h_p = h[:len(idx_in_seq_list[batch_idx + 1])]
start_idx = end_idx
end_idx = start_idx + len(idx_in_seq_list[batch_idx])
batch_gate[start_idx:end_idx] = g
batch_reset_hidden_prev[start_idx:end_idx] = r_h_p
batch_hidden[start_idx:end_idx] = h
hidden[idx_in_seq_list[batch_idx]] = h
return batch_gate, batch_reset_hidden_prev, hidden
def set_data(self):
lod = [[0, 2, 6, 9]] #[[0, 1, 2, 3]]
self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse)
print self.idx_in_seq_list
batch_size = self.batch_size
frame_size = self.frame_size
input = np.random.rand(batch_size, frame_size * 3).astype('float64')
h0 = np.random.rand(len(self.idx_in_seq_list[0]),
frame_size).astype('float64')
weight = np.random.rand(frame_size, frame_size * 3).astype('float64')
bias = np.random.rand(1, frame_size * 3).astype('float64')
self.inputs = {
'Input': (input, lod),
'H0': h0,
'Weight': weight,
'Bias': bias
}
self.outputs = {
'BatchGate': np.zeros(
(batch_size, frame_size * 3), dtype='float64'),
'BatchResetHiddenPrev': np.zeros(
(batch_size, frame_size), dtype='float64'),
'BatchHidden': np.zeros(
(batch_size, frame_size), dtype='float64'),
'Hidden': np.zeros(
(batch_size, frame_size), dtype='float64')
}
def set_confs(self):
self.is_reverse = False
self.attrs = {
'activation': 'tanh',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
def setUp(self):
self.op_type = "gru"
self.set_confs()
self.set_data()
self.gru()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpNoInitial(TestGRUOp):
def set_data(self):
super(TestGRUOpNoInitial, self).set_data()
self.inputs.pop('H0')
def test_check_grad(self):
self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOpReverse(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.attrs = {
'activation': 'identity',
'gate_activation': 'sigmoid',
'is_reverse': self.is_reverse
}
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册