diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index f97bc837dca09060c55cae6a5524c49cd69df28b..2b5fe7e350afb3e04379031b78837a245d5d7a2f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -116,7 +116,8 @@ set(DEPS_OPS sum_op pool_op pool_with_index_op - lstm_op) + lstm_op + gru_op) op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc @@ -128,6 +129,7 @@ op_library(sum_op DEPS net_op) op_library(pool_op DEPS pooling) op_library(pool_with_index_op DEPS pooling) op_library(lstm_op DEPS sequence2batch lstm_compute) +op_library(gru_op DEPS sequence2batch gru_compute) list(REMOVE_ITEM GENERAL_OPS ${DEPS_OPS}) foreach(src ${GENERAL_OPS}) diff --git a/paddle/operators/gru_op.cc b/paddle/operators/gru_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..e80e170fb94cf73697db955230689d4ee7cf4068 --- /dev/null +++ b/paddle/operators/gru_op.cc @@ -0,0 +1,213 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/gru_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class GRUOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of GRUOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of GRUOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasOutput("BatchGate"), + "Output(%s) of GRUOp should not be null.", "BatchGate"); + PADDLE_ENFORCE(ctx->HasOutput("BatchResetHiddenPrev"), + "Output(%s) of GRUOp should not be null.", + "BatchResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"), + "Output(%s) of GRUOp should not be null.", "BatchHidden"); + PADDLE_ENFORCE(ctx->HasOutput("Hidden"), + "Output(%s) of GRUOp should not be null.", "Hidden"); + auto input_dims = ctx->GetInputDim("Input"); + auto weight_dims = ctx->GetInputDim("Weight"); + int input_size = input_dims[1]; + int frame_size = weight_dims[0]; + PADDLE_ENFORCE_EQ(input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in GRUOp."); + PADDLE_ENFORCE_EQ( + weight_dims[1], frame_size * 3, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + auto h0 = Input("H0"); + if (h0 != framework::kEmptyVarName) { + auto h0_dims = ctx->GetInputDim("H0"); + PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, + "The width of H0 must be equal to frame_size."); + } + auto bias = Input("Bias"); + if (bias != framework::kEmptyVarName) { + auto bias_dims = ctx->GetInputDim("Bias"); + int bias_height = bias_dims[0]; + int bias_width = bias_dims[1]; + PADDLE_ENFORCE_EQ(bias_height, 1, + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, + "The shape of Bias must be [1, frame_size * 3]."); + } + ctx->SetOutputDim("BatchGate", input_dims); + ctx->SetOutputDim("BatchResetHiddenPrev", {input_dims[0], frame_size}); + ctx->SetOutputDim("BatchHidden", {input_dims[0], frame_size}); + ctx->SetOutputDim("Hidden", {input_dims[0], frame_size}); + // ctx->ShareLoD("Input", "Gate"); + // ctx->ShareLoD("Input", "ResetHiddenPrev"); + ctx->ShareLoD("Input", "Hidden"); + } +}; + +class GRUOpMaker : public framework::OpProtoAndCheckerMaker { + public: + GRUOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Input", + "(LoDTensor) the first input is a LodTensor, which support " + "variable-time length input sequence. The underlying tensor in " + "this LoDTenosr is a matrix with shape (T X 3D), where, T is the " + "total time steps in this mini-batch, D is the hidden size."); + AddInput("H0", + "(Tensor, optional) the initial hidden state is an optional " + "input. This is a tensor with shape (N x D), where N is the " + "batch size, D is the hidden size."); + AddInput( + "Weight", + "(Tensor) Weight matrix with shape [hidden_size, hidden_size * 3]. " + "The elements continuous in memory can be divided into two parts. " + "The first part are weights of the update gate and reset gate " + "with shape [hidden_size, hidden_size * 2], and the second part are " + "weights of output candidate with shape [hidden_size, hidden_size]"); + AddInput("Bias", + "(Tensor) Bias vector with shape [1, hidden_size * 3] concating " + "bias of the update gate, reset gate and output candidate."); + AddOutput("BatchGate", + "(LoDTensor) the update gata, reset gate and output candidate " + "lod tensor of GRU operator. " + "The shape and lod is the same with the `Input`.") + .AsIntermediate(); + AddOutput( + "BatchResetHiddenPrev", + "(LoDTensor) the reseted hidden state lod tensor of GRU operator. " + "The shape and lod is the same with the `Input`.") + .AsIntermediate(); + AddOutput( + "BatchHidden", + "(LoDTensor) the reseted hidden state lod tensor of GRU operator. " + "The shape and lod is the same with the `Input`.") + .AsIntermediate(); + AddOutput("Hidden", + "(LoDTensor) the hidden state lod tensor of GRU operator. " + "The shape and lod is the same with the `Input`."); + AddAttr("activation", + "(string, default tanh) " + "The activation type used for output candidate {h}_t.") + .SetDefault("tanh"); + AddAttr( + "gate_activation", + "(string, default sigmoid) " + "The activation type used in update gate and reset gate.") + .SetDefault("sigmoid"); + AddAttr("is_reverse", + "(bool, defalut: False) " + "whether to compute reversed GRU.") + .SetDefault(false); + AddComment(R"DOC( +GRUOp implements part calculations of the GRU unit as following: +\f[ +update \ gate: u_t = actGate(xu_t + W_u * hidden_prev + bias_u) \\ +reset \ gate: r_t = actGate(xr_t + W_r * hidden_prev + bias_r) \\ +output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, hidden_prev) + bias_c) \\ +output: h_t = dot((1-u_t), hidden_prev) + dot(u_t, {h}_t) +\f] +The rest of GRU unit can be completed by using FCOp's output as the input of GRUOp. +)DOC"); + } +}; + +class GRUGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), + "Input(%s) of GRUGradOp should not be null.", "Input"); + PADDLE_ENFORCE(ctx->HasInput("Weight"), + "Input(%s) of GRUGradOp should not be null.", "Weight"); + PADDLE_ENFORCE(ctx->HasInput("BatchGate"), + "Input(%s) of GRUGradOp should not be null.", "BatchGate"); + PADDLE_ENFORCE(ctx->HasInput("BatchResetHiddenPrev"), + "Input(%s) of GRUGradOp should not be null.", + "BatchResetHiddenPrev"); + PADDLE_ENFORCE(ctx->HasInput("BatchHidden"), + "Input(%s) of GRUOp should not be null.", "BatchHidden"); + PADDLE_ENFORCE(ctx->HasInput("Hidden"), + "Input(%s) of GRUGradOp should not be null.", "Hidden"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), + "Input(%s@GRAD) of GRUGradOp should not be null.", "Hidden"); + auto input_dims = ctx->GetInputDim("Input"); + auto weight_dims = ctx->GetInputDim("Weight"); + int input_size = input_dims[1]; + int frame_size = weight_dims[0]; + int weight_height = weight_dims[0]; + int weight_width = weight_dims[1]; + PADDLE_ENFORCE_EQ(input_size, frame_size * 3, + "The input_size must be 3 times of frame_size in GRUOp."); + PADDLE_ENFORCE_EQ( + weight_height, frame_size, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + PADDLE_ENFORCE_EQ( + weight_width, frame_size * 3, + "The shape of Weight matrix must be [frame_size, frame_size * 3]."); + auto h0 = Input("H0"); + if (h0 != framework::kEmptyVarName) { + auto h0_dims = ctx->GetInputDim("H0"); + PADDLE_ENFORCE_EQ(h0_dims[1], frame_size, + "The width of H0 must be equal to frame_size."); + auto h0_grad_name = framework::GradVarName("H0"); + if (ctx->HasOutput(h0_grad_name)) + ctx->SetOutputDim(h0_grad_name, h0_dims); + } + auto bias = Input("Bias"); + if (bias != framework::kEmptyVarName) { + auto bias_dims = ctx->GetInputDim("Bias"); + int bias_height = bias_dims[0]; + int bias_width = bias_dims[1]; + PADDLE_ENFORCE_EQ(bias_height, 1, + "The shape of Bias must be [1, frame_size * 3]."); + PADDLE_ENFORCE_EQ(bias_width, frame_size * 3, + "The shape of Bias must be [1, frame_size * 3]."); + auto bias_grad_name = framework::GradVarName("Bias"); + if (ctx->HasOutput(bias_grad_name)) + ctx->SetOutputDim(bias_grad_name, bias_dims); + } + auto input_grad_name = framework::GradVarName("Input"); + if (ctx->HasOutput(input_grad_name)) + ctx->SetOutputDim(input_grad_name, input_dims); + auto weight_grad_name = framework::GradVarName("Weight"); + if (ctx->HasOutput(weight_grad_name)) + ctx->SetOutputDim(weight_grad_name, weight_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(gru, ops::GRUOp, ops::GRUOpMaker, gru_grad, ops::GRUGradOp); +REGISTER_OP_CPU_KERNEL(gru, ops::GRUKernel, + ops::GRUKernel); +REGISTER_OP_CPU_KERNEL(gru_grad, + ops::GRUGradKernel, + ops::GRUGradKernel); diff --git a/paddle/operators/gru_op.cu b/paddle/operators/gru_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..35538c74b4bf678f8068999bfadb2589a1671be0 --- /dev/null +++ b/paddle/operators/gru_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/gru_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(gru, ops::GRUKernel, + ops::GRUKernel); +REGISTER_OP_GPU_KERNEL(gru_grad, + ops::GRUGradKernel, + ops::GRUGradKernel); diff --git a/paddle/operators/gru_op.h b/paddle/operators/gru_op.h new file mode 100644 index 0000000000000000000000000000000000000000..a04dd8d05fe90a21a86d7e7bdc429d63d60fea71 --- /dev/null +++ b/paddle/operators/gru_op.h @@ -0,0 +1,258 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/operators/math/gru_compute.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/sequence2batch.h" + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +using LoDTensor = framework::LoDTensor; + +template +using EigenMatrix = framework::EigenMatrix; + +template +class GRUKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + auto* input = context.Input("Input"); + auto* h0 = context.Input("H0"); + const T* h0_data = h0 ? h0->data() : nullptr; + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* bias = context.Input("Bias"); + auto* batch_gate = context.Output("BatchGate"); + batch_gate->mutable_data(context.GetPlace()); + auto* batch_reset_hidden_prev = + context.Output("BatchResetHiddenPrev"); + batch_reset_hidden_prev->mutable_data(context.GetPlace()); + auto* batch_hidden = context.Output("BatchHidden"); + batch_hidden->mutable_data(context.GetPlace()); + auto* hidden = context.Output("Hidden"); + hidden->mutable_data(context.GetPlace()); + + // context.ShareLoD("Input", "Gate"); + // context.ShareLoD("Input", "ResetHiddenPrev"); + context.ShareLoD("Input", "Hidden"); + + // auto gate_dims = gate->dims(); + auto hidden_dims = hidden->dims(); + + // LoDTensor batch_gate, batch_reset_hidden_prev, batch_hidden; + // batch_gate.mutable_data(gate_dims, context.GetPlace()); + // batch_reset_hidden_prev.mutable_data(hidden_dims, context.GetPlace()); + // batch_hidden.mutable_data(hidden_dims, context.GetPlace()); + + bool is_reverse = context.Attr("is_reverse"); + math::LoDTensor2BatchFunctor to_batch; + // to_batch(context.device_context(), *input, batch_gate, is_reverse); + to_batch(context.device_context(), *input, *batch_gate, is_reverse); + + int frame_size = hidden_dims[1]; + int batch_size = hidden_dims[0]; + // auto g = EigenMatrix::From(batch_gate); + auto g = EigenMatrix::From(*batch_gate); + auto place = context.GetEigenDevice(); + if (bias) { + auto b = EigenMatrix::From(*bias); + g.device(place) = g + + b.reshape(Eigen::array({{1, frame_size * 3}})) + .broadcast(Eigen::array({{batch_size, 1}})); + } + + math::hl_gru_value gru_value; + gru_value.gateWeight = const_cast(weight_data); + gru_value.stateWeight = + const_cast(weight_data + 2 * frame_size * frame_size); + gru_value.prevOutValue = const_cast(h0_data); + // auto batch_starts = batch_gate.lod()[0]; + auto batch_starts = batch_gate->lod()[0]; + // for (auto i = batch_gate->lod()[1].begin(); i != + // batch_gate->lod()[1].end(); ++i) + // std::cout << static_cast(*i) << ' '; + size_t num_batch = batch_starts.size() - 1; + for (size_t n = 0; n < num_batch; n++) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + // Tensor gate_t = batch_gate.Slice(bstart, bend); + // Tensor reset_hidden_prev_t = batch_reset_hidden_prev.Slice(bstart, + // bend); + Tensor gate_t = batch_gate->Slice(bstart, bend); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + Tensor hidden_t = batch_hidden->Slice(bstart, bend); + gru_value.outputValue = hidden_t.data(); + gru_value.gateValue = gate_t.data(); + gru_value.resetOutputValue = reset_hidden_prev_t.data(); + math::GRUUnitFunctor::compute( + context.device_context(), gru_value, frame_size, cur_batch_size, + math::ActiveType(context.Attr("activation")), + math::ActiveType(context.Attr("gate_activation"))); + gru_value.prevOutValue = gru_value.outputValue; + } + + math::Batch2LoDTensorFunctor to_seq; + // batch_gate.set_lod(batch_gate.lod()); + // to_seq(context.device_context(), batch_gate, *gate); + // batch_reset_hidden_prev.set_lod(batch_gate.lod()); + // to_seq(context.device_context(), batch_reset_hidden_prev, + // *reset_hidden_prev); + // batch_hidden.set_lod(batch_gate.lod()); + // to_seq(context.device_context(), batch_hidden, *hidden); + batch_hidden->set_lod(batch_gate->lod()); + to_seq(context.device_context(), *batch_hidden, *hidden); + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +template +class GRUGradKernel : public framework::OpKernel { + public: + void BatchCompute(const framework::ExecutionContext& context) const { + auto* h0 = context.Input("H0"); + const T* h0_data = h0 ? h0->data() : nullptr; + auto* weight = context.Input("Weight"); + const T* weight_data = weight->data(); + auto* batch_gate = context.Input("BatchGate"); + auto* batch_reset_hidden_prev = + context.Input("BatchResetHiddenPrev"); + auto* batch_hidden = context.Input("BatchHidden"); + auto* hidden = context.Input("Hidden"); + auto* hidden_grad = + context.Input(framework::GradVarName("Hidden")); + auto* input_grad = + context.Output(framework::GradVarName("Input")); + auto* h0_grad = context.Output(framework::GradVarName("H0")); + auto* weight_grad = + context.Output(framework::GradVarName("Weight")); + auto* bias_grad = context.Output(framework::GradVarName("Bias")); + + auto gate_dims = batch_gate->dims(); + auto hidden_dims = hidden->dims(); + int frame_size = hidden_dims[1]; + + math::LoDTensor2BatchFunctor to_batch; + LoDTensor batch_hidden_grad, batch_gate_grad, batch_reset_hidden_prev_grad; + batch_hidden_grad.mutable_data(hidden_dims, context.GetPlace()); + batch_gate_grad.mutable_data(gate_dims, context.GetPlace()); + batch_reset_hidden_prev_grad.mutable_data(hidden_dims, + context.GetPlace()); + math::SetConstant zero; + zero(context.device_context(), &batch_hidden_grad, static_cast(0.0)); + zero(context.device_context(), &batch_gate_grad, static_cast(0.0)); + zero(context.device_context(), &batch_reset_hidden_prev_grad, + static_cast(0.0)); + + // batch_hidden.set_lod(batch_gate->lod()); + bool is_reverse = context.Attr("is_reverse"); + batch_hidden_grad.set_lod(batch_hidden->lod()); + // context.ShareLoD(framework::GradVarName("Hidden"), + // framework::GradVarName("Input")); + to_batch(context.device_context(), *hidden_grad, batch_hidden_grad, + is_reverse, false); + + math::hl_gru_value gru_value; + gru_value.gateWeight = const_cast(weight_data); + gru_value.stateWeight = + const_cast(weight_data + 2 * frame_size * frame_size); + + math::hl_gru_grad gru_grad; + if (weight_grad) { + gru_grad.gateWeightGrad = + weight_grad->mutable_data(context.GetPlace()); + zero(context.device_context(), weight_grad, static_cast(0.0)); + gru_grad.stateWeightGrad = + weight_grad->data() + 2 * frame_size * frame_size; + } else { + gru_grad.gateWeightGrad = nullptr; + gru_grad.stateWeightGrad = nullptr; + } + + auto batch_starts = batch_hidden_grad.lod()[0]; + size_t num_batch = batch_starts.size() - 1; + for (int n = static_cast(num_batch) - 1; n >= 0; n--) { + int bstart = static_cast(batch_starts[n]); + int bend = static_cast(batch_starts[n + 1]); + int cur_batch_size = bend - bstart; + + Tensor gate_t = batch_gate->Slice(bstart, bend); + gru_value.gateValue = gate_t.data(); + Tensor reset_hidden_prev_t = batch_reset_hidden_prev->Slice(bstart, bend); + gru_value.resetOutputValue = reset_hidden_prev_t.data(); + + Tensor hidden_grad_t = batch_hidden_grad.Slice(bstart, bend); + gru_grad.outputGrad = hidden_grad_t.data(); + Tensor gate_grad_t = batch_gate_grad.Slice(bstart, bend); + gru_grad.gateGrad = gate_grad_t.data(); + Tensor reset_hidden_prev_grad_t = + batch_reset_hidden_prev_grad.Slice(bstart, bend); + gru_grad.resetOutputGrad = reset_hidden_prev_grad_t.data(); + if (n == 0) { + gru_value.prevOutValue = const_cast(h0_data); + if (h0_grad) { + T* h0_grad_data = h0_grad->mutable_data(context.GetPlace()); + zero(context.device_context(), h0_grad, static_cast(0.0)); + gru_grad.prevOutGrad = h0_grad_data; + } else { + gru_grad.prevOutGrad = nullptr; + } + } else { + int bstart_pre = static_cast(batch_starts[n - 1]); + Tensor hidden_prev_t = batch_hidden->Slice(bstart_pre, bstart); + gru_value.prevOutValue = hidden_prev_t.data(); + Tensor hidden_prev_grad_t = batch_hidden_grad.Slice(bstart_pre, bstart); + gru_grad.prevOutGrad = hidden_prev_grad_t.data(); + } + + math::GRUUnitGradFunctor::compute( + context.device_context(), gru_value, gru_grad, frame_size, + cur_batch_size, + math::ActiveType(context.Attr("activation")), + math::ActiveType(context.Attr("gate_activation"))); + } + if (input_grad) { + input_grad->mutable_data(context.GetPlace()); + math::Batch2LoDTensorFunctor to_seq; + batch_gate_grad.set_lod(batch_gate->lod()); + to_seq(context.device_context(), batch_gate_grad, *input_grad); + } + if (bias_grad) { + bias_grad->mutable_data(context.GetPlace()); + auto d_b = EigenMatrix::From(*bias_grad); + auto d_g = EigenMatrix::From(batch_gate_grad); + auto place = context.GetEigenDevice(); + d_b.device(place) = d_g.sum(Eigen::array({{0}})); + } + } + + void Compute(const framework::ExecutionContext& context) const override { + BatchCompute(context); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 5598669ef96535b7d47150052b3841771c37c60b..a29e2c59143c5a3d18b93f54f6ab66bd62ae1ebe 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -11,6 +11,7 @@ if(WITH_GPU) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) + nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) @@ -20,6 +21,7 @@ else() cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) + cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/detail/gru_cpu_kernel.h b/paddle/operators/math/detail/gru_cpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..378b87c8701f52f11dea990ba47f261d837c0e18 --- /dev/null +++ b/paddle/operators/math/detail/gru_cpu_kernel.h @@ -0,0 +1,428 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/gru_compute.h" + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +#ifndef __NVCC__ + +template +void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput, + T *gateValue, T *resetOutputValue, + T *prevOutputValue, int frameSize, + activation_mode_t active_gate) { + T rValueUpdateGate; + T rValueResetGate; + T rValueResetOutput; + T rPrevOut = 0; + T *updateGate = gateValue; + T *resetGate = gateValue + frameSize; + + for (int i = 0; i < frameSize; i++) { + rValueUpdateGate = updateGate[i]; + rValueResetGate = resetGate[i]; + if (prevOutputValue) { + rPrevOut = prevOutputValue[i]; + } + + hppl::cpu::ForwardAct act; + opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, + rValueResetOutput, act(active_gate)); + + updateGate[i] = rValueUpdateGate; + resetGate[i] = rValueResetGate; + resetOutputValue[i] = rValueResetOutput; + } +} + +template +void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput, + T *gateValue, T *prevOutputValue, + T *outputValue, int frameSize, + activation_mode_t active_node) { + T rValueUpdateGate; + T rValueFrameState; + T rPrevOut = 0; + T rOutput; + T *updateGate = gateValue; + T *frameState = gateValue + frameSize * 2; + + for (int i = 0; i < frameSize; i++) { + rValueUpdateGate = updateGate[i]; + rValueFrameState = frameState[i]; + if (prevOutputValue) { + rPrevOut = prevOutputValue[i]; + } + + hppl::cpu::ForwardAct act; + opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, + act(active_node)); + + frameState[i] = rValueFrameState; + outputValue[i] = rOutput; + } +} + +template +void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue, + T *resetOutputValue, T *prevOutputValue, + int frameSize, + activation_mode_t active_gate) { +#ifdef __AVX__ + __m256 rValueUpdateGate; + __m256 rValueResetGate; + __m256 rValueResetOutput; + __m256 rPrevOut = _mm256_set1_ps(0.0f); + __m256 *updateGate = (__m256 *)gateValue; + __m256 *resetGate = (__m256 *)(gateValue + frameSize); + + for (int i = 0; i < frameSize / 8; i++) { + rValueUpdateGate = updateGate[i]; + rValueResetGate = resetGate[i]; + if (prevOutputValue) { + rPrevOut = ((__m256 *)prevOutputValue)[i]; + } + + opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, + rValueResetOutput, hppl::avx::forward[active_gate]); + + updateGate[i] = rValueUpdateGate; + resetGate[i] = rValueResetGate; + ((__m256 *)resetOutputValue)[i] = rValueResetOutput; + } +#endif +} + +template +void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue, + T *prevOutputValue, T *outputValue, + int frameSize, + activation_mode_t active_node) { +#ifdef __AVX__ + __m256 rValueUpdateGate; + __m256 rValueFrameState; + __m256 rPrevOut = _mm256_set1_ps(0.0f); + __m256 rOutput; + __m256 *updateGate = (__m256 *)gateValue; + __m256 *frameState = (__m256 *)(gateValue + frameSize * 2); + + for (int i = 0; i < frameSize / 8; i++) { + rValueUpdateGate = updateGate[i]; + rValueFrameState = frameState[i]; + if (prevOutputValue) { + rPrevOut = ((__m256 *)prevOutputValue)[i]; + } + + opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, + hppl::avx::forward[active_node]); + + frameState[i] = rValueFrameState; + ((__m256 *)outputValue)[i] = rOutput; + } +#endif +} + +template +inline void forward_reset_output(OpResetOutput opResetOutput, + hl_gru_value value, int frameSize, + int batchSize, activation_mode_t active_gate) { + for (int b = 0; b < batchSize; b++) { + if (OpResetOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { + hl_avx_gru_forward_reset_output( + opResetOutput, value.gateValue, value.resetOutputValue, + value.prevOutValue, frameSize, active_gate); + } else { + hl_naive_gru_forward_reset_output( + opResetOutput, value.gateValue, value.resetOutputValue, + value.prevOutValue, frameSize, active_gate); + } + + value.gateValue += frameSize * 3; + value.resetOutputValue += frameSize; + if (value.prevOutValue) { + value.prevOutValue += frameSize; + } + } +} + +template +inline void forward_final_output(OpFinalOutput opFinalOutput, + hl_gru_value value, int frameSize, + int batchSize, activation_mode_t active_node) { + for (int b = 0; b < batchSize; b++) { + if (OpFinalOutput::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { + hl_avx_gru_forward_final_output(opFinalOutput, value.gateValue, + value.prevOutValue, value.outputValue, + frameSize, active_node); + } else { + hl_naive_gru_forward_final_output(opFinalOutput, value.gateValue, + value.prevOutValue, value.outputValue, + frameSize, active_node); + } + + value.gateValue += frameSize * 3; + value.outputValue += frameSize; + if (value.prevOutValue) { + value.prevOutValue += frameSize; + } + } +} + +template +void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, + T *gateGrad, T *prevOutValue, + T *prevOutGrad, T *outputGrad, + int frameSize, + activation_mode_t active_node) { + T rUpdateGateValue; + T rUpdateGateGrad; + T rFrameStateValue; + T rFrameStateGrad; + T rOutGrad; + T rPrevOutValue = 0; + T rPrevOutGrad = 0; + T *updateGateValue = gateValue; + T *updateGateGrad = gateGrad; + T *frameStateValue = gateValue + frameSize * 2; + T *frameStateGrad = gateGrad + frameSize * 2; + + for (int i = 0; i < frameSize; i++) { + rUpdateGateValue = updateGateValue[i]; + rFrameStateValue = frameStateValue[i]; + rOutGrad = outputGrad[i]; + if (prevOutValue) { + rPrevOutValue = prevOutValue[i]; + } + if (prevOutGrad) { + rPrevOutGrad = prevOutGrad[i]; + } + + hppl::cpu::BackwardAct act; + opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, + rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, + act(active_node)); + + updateGateGrad[i] = rUpdateGateGrad; + frameStateGrad[i] = rFrameStateGrad; + if (prevOutGrad) { + prevOutGrad[i] = rPrevOutGrad; + } + } +} + +template +void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, + T *gateGrad, T *prevOutValue, + T *prevOutGrad, T *resetOutputGrad, + int frameSize, + activation_mode_t active_gate) { + T rUpdateGateValue; + T rUpdateGateGrad; + T rResetGateValue; + T rResetGateGrad; + T rResetOutputGrad = 0; + T rPrevOutValue = 0; + T rPrevOutGrad = 0; + T *updateGateValue = gateValue; + T *updateGateGrad = gateGrad; + T *resetGateValue = gateValue + frameSize; + T *resetGateGrad = gateGrad + frameSize; + + for (int i = 0; i < frameSize; i++) { + rUpdateGateValue = updateGateValue[i]; + rUpdateGateGrad = updateGateGrad[i]; + rResetGateValue = resetGateValue[i]; + + if (prevOutValue && prevOutGrad) { + rResetOutputGrad = resetOutputGrad[i]; + } + if (prevOutValue) { + rPrevOutValue = prevOutValue[i]; + } + if (prevOutGrad) { + rPrevOutGrad = prevOutGrad[i]; + } + + hppl::cpu::BackwardAct act; + opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, + rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, + act(active_gate)); + + updateGateGrad[i] = rUpdateGateGrad; + resetGateGrad[i] = rResetGateGrad; + if (prevOutGrad) { + prevOutGrad[i] = rPrevOutGrad; + } + } +} + +template +void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, + T *gateGrad, T *prevOutValue, + T *prevOutGrad, T *outputGrad, + int frameSize, + activation_mode_t active_node) { +#ifdef __AVX__ + __m256 rUpdateGateValue; + __m256 rUpdateGateGrad; + __m256 rFrameStateValue; + __m256 rFrameStateGrad; + __m256 rOutGrad; + __m256 rPrevOutValue = _mm256_set1_ps(0.0f); + __m256 rPrevOutGrad = _mm256_set1_ps(0.0f); + __m256 *updateGateValue = (__m256 *)gateValue; + __m256 *updateGateGrad = (__m256 *)gateGrad; + __m256 *frameStateValue = (__m256 *)(gateValue + frameSize * 2); + __m256 *frameStateGrad = (__m256 *)(gateGrad + frameSize * 2); + + for (int i = 0; i < frameSize / 8; i++) { + rUpdateGateValue = updateGateValue[i]; + rFrameStateValue = frameStateValue[i]; + rOutGrad = ((__m256 *)outputGrad)[i]; + if (prevOutValue) { + rPrevOutValue = ((__m256 *)prevOutValue)[i]; + } + if (prevOutGrad) { + rPrevOutGrad = ((__m256 *)prevOutGrad)[i]; + } + + opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, + rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, + hppl::avx::backward[active_node]); + + updateGateGrad[i] = rUpdateGateGrad; + frameStateGrad[i] = rFrameStateGrad; + if (prevOutGrad) { + ((__m256 *)prevOutGrad)[i] = rPrevOutGrad; + } + } +#endif +} + +template +void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, + T *gateGrad, T *prevOutValue, + T *prevOutGrad, T *resetOutputGrad, + int frameSize, + activation_mode_t active_gate) { +#ifdef __AVX__ + __m256 rUpdateGateValue; + __m256 rUpdateGateGrad; + __m256 rResetGateValue; + __m256 rResetGateGrad; + __m256 rResetOutputGrad = _mm256_set1_ps(0.0f); + __m256 rPrevOutValue = _mm256_set1_ps(0.0f); + __m256 rPrevOutGrad = _mm256_set1_ps(0.0f); + __m256 *updateGateValue = (__m256 *)gateValue; + __m256 *updateGateGrad = (__m256 *)gateGrad; + __m256 *resetGateValue = (__m256 *)(gateValue + frameSize); + __m256 *resetGateGrad = (__m256 *)(gateGrad + frameSize); + + for (int i = 0; i < frameSize / 8; i++) { + rUpdateGateValue = updateGateValue[i]; + rUpdateGateGrad = updateGateGrad[i]; + rResetGateValue = resetGateValue[i]; + + if (prevOutValue && prevOutGrad) { + rResetOutputGrad = ((__m256 *)resetOutputGrad)[i]; + } + if (prevOutValue) { + rPrevOutValue = ((__m256 *)prevOutValue)[i]; + } + if (prevOutGrad) { + rPrevOutGrad = ((__m256 *)prevOutGrad)[i]; + } + + opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, + rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, + hppl::avx::backward[active_gate]); + + updateGateGrad[i] = rUpdateGateGrad; + resetGateGrad[i] = rResetGateGrad; + if (prevOutGrad) { + ((__m256 *)prevOutGrad)[i] = rPrevOutGrad; + } + } +#endif +} + +template +inline void backward_state_grad(OpStateGrad opStateGrad, hl_gru_value value, + hl_gru_grad grad, int frameSize, + int batchSize, activation_mode_t active_node) { + for (int b = 0; b < batchSize; b++) { + if (OpStateGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { + hl_avx_gru_backward_state_grad( + opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, + grad.prevOutGrad, grad.outputGrad, frameSize, active_node); + } else { + hl_naive_gru_backward_state_grad( + opStateGrad, value.gateValue, grad.gateGrad, value.prevOutValue, + grad.prevOutGrad, grad.outputGrad, frameSize, active_node); + } + + value.gateValue += frameSize * 3; + if (value.prevOutValue) { + value.prevOutValue += frameSize; + } + + grad.gateGrad += frameSize * 3; + grad.outputGrad += frameSize; + if (grad.prevOutGrad) { + grad.prevOutGrad += frameSize; + } + } +} + +template +inline void backward_reset_grad(OpResetGrad opResetGrad, hl_gru_value value, + hl_gru_grad grad, int frameSize, + int batchSize, activation_mode_t active_gate) { + for (int b = 0; b < batchSize; b++) { + if (OpResetGrad::avx && !(frameSize & (8 - 1)) && (sizeof(T) == 4)) { + hl_avx_gru_backward_reset_grad( + opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, + grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); + } else { + hl_naive_gru_backward_reset_grad( + opResetGrad, value.gateValue, grad.gateGrad, value.prevOutValue, + grad.prevOutGrad, grad.resetOutputGrad, frameSize, active_gate); + } + + value.gateValue += frameSize * 3; + if (value.prevOutValue) { + value.prevOutValue += frameSize; + } + + grad.gateGrad += frameSize * 3; + grad.resetOutputGrad += frameSize; + if (grad.prevOutGrad) { + grad.prevOutGrad += frameSize; + } + } +} + +#endif + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/gru_gpu_kernel.h b/paddle/operators/math/detail/gru_gpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..f7f8c131a0d28a09f6a015dd06a5cca91f2af253 --- /dev/null +++ b/paddle/operators/math/detail/gru_gpu_kernel.h @@ -0,0 +1,207 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/operators/math/gru_compute.h" +#include "paddle/platform/cuda_helper.h" +#include "paddle/platform/device_context.h" + +#include + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput, + T *gateValue, T *resetOutputValue, + T *prevOutputValue, int frameSize, + int batchSize, + activation_mode_t active_gate) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + gateValue += batchIdx * 3 * frameSize; + resetOutputValue += batchIdx * frameSize; + } + + T rPrevOut = 0; + T rValueResetOutput; + T rValueUpdateGate = gateValue[frameIdx + frameSize * 0]; + T rValueResetGate = gateValue[frameIdx + frameSize * 1]; + + if (prevOutputValue) { + if (isBatch) prevOutputValue += batchIdx * frameSize; + rPrevOut = prevOutputValue[frameIdx]; + } + + hppl::gpu::ForwardAct act; + opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput, + act(active_gate)); + + gateValue[frameIdx + frameSize * 0] = rValueUpdateGate; + gateValue[frameIdx + frameSize * 1] = rValueResetGate; + resetOutputValue[frameIdx] = rValueResetOutput; +} + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput, + T *gateValue, T *prevOutputValue, + T *outputValue, int frameSize, + int batchSize, + activation_mode_t active_node) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + gateValue += batchIdx * 3 * frameSize; + outputValue += batchIdx * frameSize; + } + + T rOutput; + T rPrevOut = 0; + T rValueUpdateGate = gateValue[frameIdx + frameSize * 0]; + T rValueFrameState = gateValue[frameIdx + frameSize * 2]; + + if (prevOutputValue) { + if (isBatch) prevOutputValue += batchIdx * frameSize; + rPrevOut = prevOutputValue[frameIdx]; + } + + hppl::gpu::ForwardAct act; + opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, + act(active_node)); + + gateValue[frameIdx + frameSize * 2] = rValueFrameState; + outputValue[frameIdx] = rOutput; +} + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue, + T *gateGrad, T *prevOutValue, + T *prevOutGrad, T *outputGrad, + int frameSize, int batchSize, + activation_mode_t active_node) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + gateValue += batchIdx * 3 * frameSize; + gateGrad += batchIdx * 3 * frameSize; + outputGrad += batchIdx * frameSize; + } + + T rUpdateGateGrad; + T rFrameStateGrad; + T rPrevOutValue = 0; + T rPrevOutGrad = 0; + T rUpdateGateValue = gateValue[frameIdx + frameSize * 0]; + T rFrameStateValue = gateValue[frameIdx + frameSize * 2]; + T rOutGrad = outputGrad[frameIdx]; + + if (prevOutValue && prevOutGrad) { + if (isBatch) prevOutValue += batchIdx * frameSize; + rPrevOutValue = prevOutValue[frameIdx]; + + if (isBatch) prevOutGrad += batchIdx * frameSize; + rPrevOutGrad = prevOutGrad[frameIdx]; + } + + hppl::gpu::BackwardAct act; + opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, + rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, + act(active_node)); + + gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; + gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad; + if (prevOutGrad) { + prevOutGrad[frameIdx] = rPrevOutGrad; + } +} + +/* + * threads(framePerBlock, batchPerBlock) + * grid(frameBlocks, batchBlocks) + */ +template +__global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue, + T *gateGrad, T *prevOutValue, + T *prevOutGrad, T *resetOutputGrad, + int frameSize, int batchSize, + activation_mode_t active_gate) { + const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x; + if (frameIdx >= frameSize) return; + int batchIdx = 0; + if (isBatch) { + batchIdx = blockIdx.y * blockDim.y + threadIdx.y; + if (batchIdx >= batchSize) return; + gateValue += batchIdx * 3 * frameSize; + gateGrad += batchIdx * 3 * frameSize; + resetOutputGrad += batchIdx * frameSize; + } + + T rResetGateGrad; + T rPrevOutValue = 0; + T rPrevOutGrad = 0; + T rResetOutputGrad = 0; + T rUpdateGateValue = gateValue[frameIdx + frameSize * 0]; + T rUpdateGateGrad = gateGrad[frameIdx + frameSize * 0]; + T rResetGateValue = gateValue[frameIdx + frameSize * 1]; + + if (prevOutValue && prevOutGrad) { + if (isBatch) prevOutValue += batchIdx * frameSize; + if (isBatch) prevOutGrad += batchIdx * frameSize; + rPrevOutValue = prevOutValue[frameIdx]; + rPrevOutGrad = prevOutGrad[frameIdx]; + rResetOutputGrad = resetOutputGrad[frameIdx]; + } + + hppl::gpu::BackwardAct act; + opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, + rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, + act(active_gate)); + + gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; + gateGrad[frameIdx + frameSize * 1] = rResetGateGrad; + if (prevOutGrad) { + prevOutGrad[frameIdx] = rPrevOutGrad; + } +} +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/detail/gru_kernel.h b/paddle/operators/math/detail/gru_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..a1b4dd7e62285327dd804fb01c0044d1c22516e4 --- /dev/null +++ b/paddle/operators/math/detail/gru_kernel.h @@ -0,0 +1,191 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detail/hl_activation_functions.h" +#include "paddle/platform/hostdevice.h" + +#include + +namespace paddle { +namespace operators { +namespace math { +namespace detail { + +namespace forward { + +template +class gru_resetOutput { + public: + /** + * @param[in,out] valueUpdateGate update gate + * @param[in,out] valueResetGate reset gate + * @param[in] prevOut previous output + * @param[out] valueResetOutput intermediate value for frame state + * @param[in] actGate forward function of gate + */ + HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut, + T &valueResetOutput, + typename hppl::Active::forward actGate) { + valueUpdateGate = actGate(valueUpdateGate); + valueResetGate = actGate(valueResetGate); + valueResetOutput = prevOut * valueResetGate; + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate, + __m256 &prevOut, __m256 &valueResetOutput, + typename hppl::Active<__m256>::forward actGate) { + valueUpdateGate = actGate(valueUpdateGate); + valueResetGate = actGate(valueResetGate); + valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate); + } +#endif +#endif +}; + +template +class gru_finalOutput { + public: + /** + * @param[in] valueUpdateGate update gate + * @param[in,out] valueFrameState frame state ({\tilde{h}_t}) + * @param[in] prevOut previous output + * @param[out] valueOutput output + * @param[in] actInput forward function of node + */ + HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut, + T &valueOutput, + typename hppl::Active::forward actInput) { + valueFrameState = actInput(valueFrameState); + valueOutput = prevOut - (valueUpdateGate * prevOut) + + (valueUpdateGate * valueFrameState); + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState, + __m256 &prevOut, __m256 &valueOutput, + typename hppl::Active<__m256>::forward actInput) { + valueFrameState = actInput(valueFrameState); + valueOutput = _mm256_add_ps( + _mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)), + _mm256_mul_ps(valueUpdateGate, valueFrameState)); + } +#endif +#endif +}; +} // namespace forward + +namespace backward { + +template +class gru_stateGrad { + public: + /** + * @param[in] valueUpdateGate update gate value + * @param[out] gradUpdateGate update gate grad + * @param[in] valueFrameState frame state value + * @param[out] gradFrameState frame state grad + * @param[in] valuePrevOut previous output value + * @param[in,out] gradPrevOut previous output grad + * @param[in] gradOutput output grad + * @param[in] actInput backward function of frame state + */ + HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, + T &valueFrameState, T &gradFrameState, + T &valuePrevOut, T &gradPrevOut, T &gradOutput, + typename hppl::Active::backward actInput) { + gradUpdateGate = (gradOutput * valueFrameState); + gradUpdateGate -= (gradOutput * valuePrevOut); + gradPrevOut -= (gradOutput * valueUpdateGate); + gradPrevOut += gradOutput; + gradFrameState = actInput(gradOutput * valueUpdateGate, valueFrameState); + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, + __m256 &valueFrameState, __m256 &gradFrameState, + __m256 &valuePrevOut, __m256 &gradPrevOut, + __m256 &gradOutput, + typename hppl::Active<__m256>::backward actInput) { + gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState); + gradUpdateGate = + _mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut)); + gradPrevOut = _mm256_add_ps( + _mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)), + gradOutput); + gradFrameState = + actInput(_mm256_mul_ps(gradOutput, valueUpdateGate), valueFrameState); + } +#endif +#endif +}; + +template +class gru_resetGrad { + public: + /** + * @param[in] valueUpdateGate update gate value + * @param[in,out] gradUpdateGate update gate grad + * @param[in] valueResetGate reset gate value + * @param[out] gradResetGate reset gate grad + * @param[in] valuePrevOut previous output value + * @param[in,out] gradPrevOut previous output grad + * @param[in] gradResetOutput reset output grad (temp val) + * @param[in] actGate backward function of gate + */ + HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, + T &valueResetGate, T &gradResetGate, + T &valuePrevOut, T &gradPrevOut, + T &gradResetOutput, + typename hppl::Active::backward actGate) { + gradResetGate = (gradResetOutput * valuePrevOut); + gradPrevOut += (gradResetOutput * valueResetGate); + gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate); + gradResetGate = actGate(gradResetGate, valueResetGate); + } +#ifndef __NVCC__ +#ifndef __AVX__ + static const bool avx = false; +#else + static const bool avx = true; + HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, + __m256 &valueResetGate, __m256 &gradResetGate, + __m256 &valuePrevOut, __m256 &gradPrevOut, + __m256 &gradResetOutput, + typename hppl::Active<__m256>::backward actGate) { + gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut); + gradPrevOut = _mm256_add_ps(gradPrevOut, + _mm256_mul_ps(gradResetOutput, valueResetGate)); + gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate); + gradResetGate = actGate(gradResetGate, valueResetGate); + } +#endif +#endif +}; + +} // namespace backward + +} // namespace detail +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/gru_compute.cc b/paddle/operators/math/gru_compute.cc new file mode 100644 index 0000000000000000000000000000000000000000..125af449d3f700e24be5e4b7615c3b0e03fd4e5b --- /dev/null +++ b/paddle/operators/math/gru_compute.cc @@ -0,0 +1,102 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/gru_compute.h" +#include "paddle/operators/math/detail/gru_cpu_kernel.h" +#include "paddle/operators/math/detail/gru_kernel.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct GRUUnitFunctor { + static void compute(const platform::DeviceContext &context, + hl_gru_value value, int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate) { +#ifndef __NVCC__ + if (value.prevOutValue) { + math::gemm( + context, false, false, batchSize, frameSize * 2, frameSize, 1, + value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, + value.gateValue, frameSize * 3); + } + + detail::forward_reset_output(detail::forward::gru_resetOutput(), value, + frameSize, batchSize, active_gate); + + if (value.prevOutValue) { + math::gemm( + context, false, false, batchSize, frameSize, frameSize, 1, + value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, + value.gateValue + frameSize * 2, frameSize * 3); + } + + detail::forward_final_output(detail::forward::gru_finalOutput(), value, + frameSize, batchSize, active_node); +#endif + } +}; + +template +struct GRUUnitGradFunctor { + static void compute(const platform::DeviceContext &context, + hl_gru_value value, hl_gru_grad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate) { +#ifndef __NVCC__ + detail::backward_state_grad(detail::backward::gru_stateGrad(), value, + grad, frameSize, batchSize, active_node); + + if (value.prevOutValue && grad.prevOutGrad) { + math::gemm( + context, false, true, batchSize, frameSize, frameSize, 1, + grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, + frameSize, 0, grad.resetOutputGrad, frameSize); + + if (grad.stateWeightGrad) { + math::gemm( + context, true, false, frameSize, frameSize, batchSize, 1, + value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, + frameSize * 3, 1, grad.stateWeightGrad, frameSize); + } + } + + detail::backward_reset_grad(detail::backward::gru_resetGrad(), value, + grad, frameSize, batchSize, active_gate); + + if (grad.prevOutGrad && value.prevOutValue) { + math::gemm( + context, false, true, batchSize, frameSize, frameSize * 2, 1, + grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, + grad.prevOutGrad, frameSize); + + if (grad.gateWeightGrad) { + math::gemm( + context, true, false, frameSize, frameSize * 2, batchSize, 1, + value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, + grad.gateWeightGrad, frameSize * 2); + } + } +#endif + } +}; + +template struct GRUUnitFunctor; +template struct GRUUnitFunctor; +template struct GRUUnitGradFunctor; +template struct GRUUnitGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/gru_compute.cu b/paddle/operators/math/gru_compute.cu new file mode 100644 index 0000000000000000000000000000000000000000..4eb558142b6ed5f15bca3ebe22a0f73604d06f40 --- /dev/null +++ b/paddle/operators/math/gru_compute.cu @@ -0,0 +1,178 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/detail/gru_gpu_kernel.h" +#include "paddle/operators/math/detail/gru_kernel.h" +#include "paddle/operators/math/gru_compute.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct GRUUnitFunctor { + static void compute(const platform::DeviceContext &context, + hl_gru_value value, int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate) { + auto stream = + reinterpret_cast(context).stream(); + dim3 threads; + dim3 grid; + if (batchSize == 1) { + int framePerBlock = frameSize <= 1024 ? frameSize : 1024; + int frameBlocks = (frameSize + 1024 - 1) / 1024; + threads = dim3(framePerBlock, 1); + grid = dim3(frameBlocks, 1); + } else { + threads = dim3(32, 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + } + + if (value.prevOutValue) { + math::gemm( + context, false, false, batchSize, frameSize * 2, frameSize, 1, + value.prevOutValue, frameSize, value.gateWeight, frameSize * 2, 1, + value.gateValue, frameSize * 3); + } + + if (batchSize == 1) { + detail::KeGruForwardResetOutput, + /* isBatch= */ false, + T><<>>( + detail::forward::gru_resetOutput(), value.gateValue, + value.resetOutputValue, value.prevOutValue, frameSize, batchSize, + active_gate); + } else { + detail::KeGruForwardResetOutput, + /* isBatch= */ true, + T><<>>( + detail::forward::gru_resetOutput(), value.gateValue, + value.resetOutputValue, value.prevOutValue, frameSize, batchSize, + active_gate); + } + + if (value.prevOutValue) { + math::gemm( + context, false, false, batchSize, frameSize, frameSize, 1, + value.resetOutputValue, frameSize, value.stateWeight, frameSize, 1, + value.gateValue + frameSize * 2, frameSize * 3); + } + + if (batchSize == 1) { + detail::KeGruForwardFinalOutput, + /* isBatch= */ false, + T><<>>( + detail::forward::gru_finalOutput(), value.gateValue, + value.prevOutValue, value.outputValue, frameSize, batchSize, + active_node); + } else { + detail::KeGruForwardFinalOutput, + /* isBatch= */ true, + T><<>>( + detail::forward::gru_finalOutput(), value.gateValue, + value.prevOutValue, value.outputValue, frameSize, batchSize, + active_node); + } + } +}; + +template +struct GRUUnitGradFunctor { + static void compute(const platform::DeviceContext &context, + hl_gru_value value, hl_gru_grad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate) { + auto stream = + reinterpret_cast(context).stream(); + dim3 threads; + dim3 grid; + if (batchSize == 1) { + int framePerBlock = frameSize <= 1024 ? frameSize : 1024; + int frameBlocks = (frameSize + 1024 - 1) / 1024; + threads = dim3(framePerBlock, 1); + grid = dim3(frameBlocks, 1); + } else { + threads = dim3(32, 32); + grid = dim3((frameSize + 32 - 1) / 32, (batchSize + 32 - 1) / 32); + } + + if (batchSize == 1) { + detail::KeGruBackwardStateGrad< + detail::backward::gru_stateGrad, + /* isBatch= */ false><<>>( + detail::backward::gru_stateGrad(), value.gateValue, grad.gateGrad, + value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, + batchSize, active_node); + } else { + detail::KeGruBackwardStateGrad< + detail::backward::gru_stateGrad, + /* isBatch= */ true><<>>( + detail::backward::gru_stateGrad(), value.gateValue, grad.gateGrad, + value.prevOutValue, grad.prevOutGrad, grad.outputGrad, frameSize, + batchSize, active_node); + } + + if (value.prevOutValue && grad.prevOutGrad) { + math::gemm( + context, false, true, batchSize, frameSize, frameSize, 1, + grad.gateGrad + frameSize * 2, frameSize * 3, value.stateWeight, + frameSize, 0, grad.resetOutputGrad, frameSize); + + if (grad.stateWeightGrad) { + math::gemm( + context, true, false, frameSize, frameSize, batchSize, 1, + value.resetOutputValue, frameSize, grad.gateGrad + frameSize * 2, + frameSize * 3, 1, grad.stateWeightGrad, frameSize); + } + } + + if (batchSize == 1) { + detail::KeGruBackwardResetGrad< + detail::backward::gru_resetGrad, + /* isBatch= */ false><<>>( + detail::backward::gru_resetGrad(), value.gateValue, grad.gateGrad, + value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, + batchSize, active_gate); + } else { + detail::KeGruBackwardResetGrad< + detail::backward::gru_resetGrad, + /* isBatch= */ true><<>>( + detail::backward::gru_resetGrad(), value.gateValue, grad.gateGrad, + value.prevOutValue, grad.prevOutGrad, grad.resetOutputGrad, frameSize, + batchSize, active_gate); + } + + if (grad.prevOutGrad && value.prevOutValue) { + math::gemm( + context, false, true, batchSize, frameSize, frameSize * 2, 1, + grad.gateGrad, frameSize * 3, value.gateWeight, frameSize * 2, 1, + grad.prevOutGrad, frameSize); + + if (grad.gateWeightGrad) { + math::gemm( + context, true, false, frameSize, frameSize * 2, batchSize, 1, + value.prevOutValue, frameSize, grad.gateGrad, frameSize * 3, 1, + grad.gateWeightGrad, frameSize * 2); + } + } + } +}; + +template struct GRUUnitFunctor; +template struct GRUUnitFunctor; +template struct GRUUnitGradFunctor; +template struct GRUUnitGradFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle \ No newline at end of file diff --git a/paddle/operators/math/gru_compute.h b/paddle/operators/math/gru_compute.h new file mode 100644 index 0000000000000000000000000000000000000000..45ce48658aa363ac4bfb02cc3d10f571c5d00ff7 --- /dev/null +++ b/paddle/operators/math/gru_compute.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/operators/math/lstm_compute.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace math { + +// typedef enum { +// HL_ACTIVATION_SIGMOID = 0, +// HL_ACTIVATION_RELU = 1, +// HL_ACTIVATION_TANH = 2, +// HL_ACTIVATION_LINEAR = 3, +// HL_ACTIVATION_END +// } activation_mode_t; + +// inline activation_mode_t ActiveType(const std::string &type) { +// if (type == "sigmoid") { +// return HL_ACTIVATION_SIGMOID; +// } else if (type == "relu") { +// return HL_ACTIVATION_RELU; +// } else if (type == "tanh") { +// return HL_ACTIVATION_TANH; +// } else if (type == "linear" || type == "") { +// return HL_ACTIVATION_LINEAR; +// } else { +// PADDLE_THROW("Do not support activation type."); +// } +// } + +template +struct hl_gru_value { + T *gateWeight; + T *stateWeight; + T *gateValue; + T *resetOutputValue; + T *outputValue; + T *prevOutValue; +}; + +template +struct hl_gru_grad { + T *gateWeightGrad; + T *stateWeightGrad; + T *gateGrad; + T *resetOutputGrad; + T *outputGrad; + T *prevOutGrad; +}; + +template +struct GRUUnitFunctor { + static void compute(const platform::DeviceContext &context, + hl_gru_value value, int frameSize, int batchSize, + activation_mode_t active_node, + activation_mode_t active_gate); +}; + +template +struct GRUUnitGradFunctor { + static void compute(const platform::DeviceContext &context, + hl_gru_value value, hl_gru_grad grad, int frameSize, + int batchSize, activation_mode_t active_node, + activation_mode_t active_gate); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/sequence2batch.h b/paddle/operators/math/sequence2batch.h index 03cd018e46e90c9bbe689c9686377e0e998ee513..577496928c0257a58263f475eb0ed04cc5fe5efb 100644 --- a/paddle/operators/math/sequence2batch.h +++ b/paddle/operators/math/sequence2batch.h @@ -21,6 +21,128 @@ namespace paddle { namespace operators { namespace math { +// template +// class CopyMatrixRowsFunctor { +// public: +// // If is_src_index is true, +// // copy the indexed rows of input src to the output dst. +// // If is_src_index is false, +// // copy the input src to the indexed rows of output dst. +// // The indexed rows are based on the input index. +// void operator()(const platform::DeviceContext& context, +// const framework::LoDTensor& src, const size_t* index, +// framework::LoDTensor& dst, bool is_src_index); +// }; + +// template +// class LoDTensor2BatchFunctor { +// // Calculate the length of each sequence and +// // sort sequence index by the length. +// // example: sequences = {s0, s1, s2} +// // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 +// // seq_info[3] = {(4, 5, 1), (0, 4, 0), (9, 3, 2)} +// // +// struct SeqInfo { +// SeqInfo(int start, int length, int seq_idx) +// : start(start), length(length), seq_idx(seq_idx) {} +// int start; +// int length; +// int seq_idx; +// }; + +// public: +// void operator()(const platform::DeviceContext& context, +// const framework::LoDTensor& lod_tensor, +// framework::LoDTensor& batch, bool is_reverse) const { +// auto lods = lod_tensor.lod(); +// PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence +// now."); +// auto lod = lods[0]; + +// std::vector seq_info; +// for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { +// int length = lod[seq_id + 1] - lod[seq_id]; +// seq_info.emplace_back(lod[seq_id], length, seq_id); +// } + +// std::sort(seq_info.begin(), seq_info.end(), +// [](SeqInfo a, SeqInfo b) { return a.length > b.length; }); + +// // calculate the start position of each batch +// // (numBatch equal the maxLength of sequences) +// // example: sequences = {s0, s1, s2} +// // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 +// // num_batch = 5, +// // batchIndex = {b0, b1, b2, b3, b4} +// // b0: 1 0 2, b1: 1 0 2, b2: 1 0 2, b3: 1 0, b4: 1 +// // batch_start_positions[6] = {0, 3, 6, 9, 11, 12} +// // batch_start_positions[0] = len(b0) +// // batch_start_positions[1] = len(b0) + len(b1) +// // batch_start_positions[2] = len(b0) + len(b1) + len(b2) +// // ... +// // seq2batch_idx[12] = {4, 0, 9, +// // 5, 1, 10, +// // 6, 2, 11, +// // 7, 3, +// // 8} +// // The batch number represents batch size after rearranging the +// // input LodTensor. It is also the maximum length of input sequence. + +// paddle::framework::LoD batch_lods; +// batch_lods.emplace_back(std::vector{0}); +// batch_lods.emplace_back(std::vector{0}); + +// // batch_lods[0] is the start positions for batch LoDTensor +// int num_batch = seq_info[0].length; +// batch_lods[0].resize(static_cast(num_batch + 1)); +// // batch_lods[1] is the raw index in the input LoDTensor +// auto dims = lod_tensor.dims(); +// batch_lods[1].resize(static_cast(dims[0])); + +// size_t* batch_starts = batch_lods[0].data(); +// size_t* seq2batch_idx = batch_lods[1].data(); +// batch_starts[0] = 0; +// for (size_t n = 0; n < num_batch; n++) { +// auto batch_id = static_cast(batch_starts[n]); +// for (size_t i = 0; i < seq_info.size(); ++i) { +// size_t seq_len = seq_info[i].length; +// int start = seq_info[i].start; +// if (n < seq_len) { +// seq2batch_idx[batch_id] = +// is_reverse ? start + seq_len - 1 - n : start + n; +// batch_id++; +// } else { +// break; +// } +// } +// batch_starts[n + 1] = static_cast(batch_id); +// } +// batch.set_lod(batch_lods); + +// CopyMatrixRowsFunctor to_batch; +// to_batch(context, lod_tensor, seq2batch_idx, batch, true); +// } +// }; + +// template +// class Batch2LoDTensorFunctor { +// public: +// void operator()(const platform::DeviceContext& context, +// const framework::LoDTensor& batch, +// framework::LoDTensor& lod_tensor) const { +// auto in_lod = batch.lod(); +// PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, +// "The LoD size of input `batch` should be 2."); +// auto out_lod = lod_tensor.lod()[0]; +// auto num = out_lod[out_lod.size() - 1]; +// PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); +// PADDLE_ENFORCE_EQ(num, in_lod[1].size()); +// PADDLE_ENFORCE_EQ(num, batch.dims()[0]); +// CopyMatrixRowsFunctor to_seq; +// size_t* index = in_lod[1].data(); +// to_seq(context, batch, index, lod_tensor, false); +// } +// }; template class CopyMatrixRowsFunctor { public: @@ -53,7 +175,18 @@ class LoDTensor2BatchFunctor { public: void operator()(const platform::DeviceContext& context, const framework::LoDTensor& lod_tensor, - framework::LoDTensor& batch, bool is_reverse) const { + framework::LoDTensor& batch, bool is_reverse = false, + bool is_cal_batch_lod = true) const { + if (!is_cal_batch_lod) { + auto lods = batch.lod(); + PADDLE_ENFORCE_EQ(lods.size(), 2UL); + PADDLE_ENFORCE_EQ(lods[1].size(), + static_cast(lod_tensor.dims()[0])); + CopyMatrixRowsFunctor to_batch; + to_batch(context, lod_tensor, lods[1].data(), batch, true); + return; + } + auto lods = lod_tensor.lod(); PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); auto lod = lods[0]; @@ -101,10 +234,10 @@ class LoDTensor2BatchFunctor { size_t* batch_starts = batch_lods[0].data(); size_t* seq2batch_idx = batch_lods[1].data(); batch_starts[0] = 0; - for (size_t n = 0; n < num_batch; n++) { + for (int n = 0; n < num_batch; n++) { auto batch_id = static_cast(batch_starts[n]); for (size_t i = 0; i < seq_info.size(); ++i) { - size_t seq_len = seq_info[i].length; + int seq_len = seq_info[i].length; int start = seq_info[i].start; if (n < seq_len) { seq2batch_idx[batch_id] = @@ -132,11 +265,8 @@ class Batch2LoDTensorFunctor { auto in_lod = batch.lod(); PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, "The LoD size of input `batch` should be 2."); - auto out_lod = lod_tensor.lod()[0]; - auto num = out_lod[out_lod.size() - 1]; - PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]); - PADDLE_ENFORCE_EQ(num, in_lod[1].size()); - PADDLE_ENFORCE_EQ(num, batch.dims()[0]); + PADDLE_ENFORCE_EQ(in_lod[1].size(), + static_cast(lod_tensor.dims()[0])); CopyMatrixRowsFunctor to_seq; size_t* index = in_lod[1].data(); to_seq(context, batch, index, lod_tensor, false); diff --git a/python/paddle/v2/framework/tests/test_gru_op.py b/python/paddle/v2/framework/tests/test_gru_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cd12642779efa688dac9a478f2df71aec248ab --- /dev/null +++ b/python/paddle/v2/framework/tests/test_gru_op.py @@ -0,0 +1,183 @@ +import unittest +import numpy as np +import math +from op_test import OpTest + +SIGMOID_THRESHOLD_MIN = -40.0 +SIGMOID_THRESHOLD_MAX = 13.0 +EXP_MAX_INPUT = 40.0 + + +def identity(x): + return x + + +def sigmoid(x): + y = np.copy(x) + y[x < SIGMOID_THRESHOLD_MIN] = SIGMOID_THRESHOLD_MIN + y[x > SIGMOID_THRESHOLD_MAX] = SIGMOID_THRESHOLD_MAX + return 1. / (1. + np.exp(-y)) + + +def tanh(x): + y = -2. * x + y[y > EXP_MAX_INPUT] = EXP_MAX_INPUT + return (2. / (1. + np.exp(y))) - 1. + + +def relu(x): + return np.maximum(x, 0) + + +class TestGRUOp(OpTest): + batch_size = 9 + frame_size = 5 + activate = { + 'identity': identity, + 'sigmoid': sigmoid, + 'tanh': tanh, + 'relu': relu + } + + @staticmethod + def seq_to_batch(lod, is_reverse): + idx_in_seq_list = [] + seq_starts = lod[0] + seq_lens = [] + for i in range(len(seq_starts) - 1): + seq_lens.append(seq_starts[i + 1] - seq_starts[i]) + sorted_seqs = sorted( + range(len(seq_lens)), lambda x, y: seq_lens[y] - seq_lens[x]) + num_batch = seq_lens[sorted_seqs[0]] + for batch_idx in range(num_batch): + idx_in_seq = [] + for i in range(len(seq_lens)): + if seq_lens[sorted_seqs[i]] <= batch_idx: + break + idx = (seq_starts[sorted_seqs[i] + 1] - 1 - batch_idx + ) if is_reverse else ( + seq_starts[sorted_seqs[i]] + batch_idx) + idx_in_seq.append(idx) + idx_in_seq_list.append(idx_in_seq) + return idx_in_seq_list + + def gru_step(self, x, h_p, w, b): + print x.shape, h_p.shape, w.shape, b.shape + batch_size = x.shape[0] + frame_size = w.shape[0] + g = x + np.tile(b, (batch_size, 1)) + w_u_r = w.flatten()[:frame_size * frame_size * 2].reshape( + (frame_size, frame_size * 2)) + u_r = self.activate[self.attrs['gate_activation']](np.dot( + h_p, w_u_r) + g[:, :frame_size * 2]) + u = u_r[:, :frame_size] + r = u_r[:, frame_size:frame_size * 2] + r_h_p = r * h_p + w_c = w.flatten()[frame_size * frame_size * 2:].reshape( + (frame_size, frame_size)) + c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + + g[:, frame_size * 2:]) + g = np.hstack((u_r, c)) + h = u * c + (1 - u) * h_p + return g, r_h_p, h + + def gru(self): + input, lod = self.inputs['Input'] + w = self.inputs['Weight'] + b = self.inputs['Bias'] if self.inputs.has_key('Bias') else np.zeros( + (1, self.frame_size * 3)) + batch_gate = self.outputs['BatchGate'] + batch_reset_hidden_prev = self.outputs['BatchResetHiddenPrev'] + batch_hidden = self.outputs['BatchHidden'] + hidden = self.outputs['Hidden'] + idx_in_seq_list = self.idx_in_seq_list + h_p = self.inputs['H0'] if self.inputs.has_key('H0') else np.zeros( + (len(idx_in_seq_list[0]), self.frame_size)) + num_batch = len(idx_in_seq_list) + end_idx = 0 + for batch_idx in range(num_batch): + print idx_in_seq_list[batch_idx] + x = input[idx_in_seq_list[batch_idx]] + g, r_h_p, h = self.gru_step(x, h_p, w, b) + if batch_idx < (num_batch - 1): + h_p = h[:len(idx_in_seq_list[batch_idx + 1])] + start_idx = end_idx + end_idx = start_idx + len(idx_in_seq_list[batch_idx]) + batch_gate[start_idx:end_idx] = g + batch_reset_hidden_prev[start_idx:end_idx] = r_h_p + batch_hidden[start_idx:end_idx] = h + hidden[idx_in_seq_list[batch_idx]] = h + return batch_gate, batch_reset_hidden_prev, hidden + + def set_data(self): + lod = [[0, 2, 6, 9]] #[[0, 1, 2, 3]] + self.idx_in_seq_list = self.seq_to_batch(lod, self.is_reverse) + print self.idx_in_seq_list + batch_size = self.batch_size + frame_size = self.frame_size + input = np.random.rand(batch_size, frame_size * 3).astype('float64') + h0 = np.random.rand(len(self.idx_in_seq_list[0]), + frame_size).astype('float64') + weight = np.random.rand(frame_size, frame_size * 3).astype('float64') + bias = np.random.rand(1, frame_size * 3).astype('float64') + + self.inputs = { + 'Input': (input, lod), + 'H0': h0, + 'Weight': weight, + 'Bias': bias + } + + self.outputs = { + 'BatchGate': np.zeros( + (batch_size, frame_size * 3), dtype='float64'), + 'BatchResetHiddenPrev': np.zeros( + (batch_size, frame_size), dtype='float64'), + 'BatchHidden': np.zeros( + (batch_size, frame_size), dtype='float64'), + 'Hidden': np.zeros( + (batch_size, frame_size), dtype='float64') + } + + def set_confs(self): + self.is_reverse = False + self.attrs = { + 'activation': 'tanh', + 'gate_activation': 'sigmoid', + 'is_reverse': self.is_reverse + } + + def setUp(self): + self.op_type = "gru" + self.set_confs() + self.set_data() + self.gru() + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden']) + + +class TestGRUOpNoInitial(TestGRUOp): + def set_data(self): + super(TestGRUOpNoInitial, self).set_data() + self.inputs.pop('H0') + + def test_check_grad(self): + self.check_grad(['Input', 'Weight', 'Bias'], ['Hidden']) + + +class TestGRUOpReverse(TestGRUOp): + def set_confs(self): + self.is_reverse = True + self.attrs = { + 'activation': 'identity', + 'gate_activation': 'sigmoid', + 'is_reverse': self.is_reverse + } + + +if __name__ == "__main__": + unittest.main()