From 334c84526b7ec858d1ea952e24b9c96ece39862c Mon Sep 17 00:00:00 2001 From: zchen0211 Date: Wed, 20 Sep 2017 11:06:21 -0700 Subject: [PATCH] lstm unit --- paddle/operators/lstm_unit_op.cc | 101 +++++++++++++++++++++ paddle/operators/lstm_unit_op.h | 147 +++++++++++++++++++++++++++++++ 2 files changed, 248 insertions(+) create mode 100644 paddle/operators/lstm_unit_op.cc create mode 100644 paddle/operators/lstm_unit_op.h diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc new file mode 100644 index 000000000..e3cac2605 --- /dev/null +++ b/paddle/operators/lstm_unit_op.cc @@ -0,0 +1,101 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/lstm_unit_op.h" + +namespace paddle { +namespace operators { + +class LstmUnitOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("C_prev"), + "Input(C_prev) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("C"), + "Output(C) of LSTM should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("H"), + "Output(H) of LSTM should not be null."); + + auto *x = ctx.Input("X"); + auto *c_prev = ctx.Input("C_prev"); + + PADDLE_ENFORCE_EQ(x->dims().size(), 2, "Input(X)'s rank must be 2."); + PADDLE_ENFORCE(x->dims()[0] == c_prev->dims()[0], + "Batch size of inputs and states must be equal"); + PADDLE_ENFORCE(x->dims()[1] == c_prev->dims()[1] * 4, + "Dimension of FC should equal to prev state * 4"); + + int b_size = c_prev->dims()[0]; // batch size + int s_dim = c_prev->dims()[1]; // state dim + ctx.Output("C")->Resize({b_size, s_dim}); + ctx.Output("H")->Resize({b_size, s_dim}); + } +}; + +template +class LstmUnitOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LstmUnitOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "FC input before the non-linear activation."); + AddInput( + "C_prev", + "The cell state tensor of last time-step in the Lstm Unit operator."); + AddOutput("C", "The cell tensor of Lstm Unit operator."); + AddOutput("H", "The hidden state tensor of Lstm Unit operator."); + + AddComment(R"DOC(Lstm-Unit Operator + +Equation: + i, j, f, o = split(X) + C = C_prev * sigm(f + forget_bias) + sigm(i) * tanh(j) + H = C * sigm(o) + +)DOC"); + AddAttr("forget_bias", "The forget bias of Lstm Unit.") + .SetDefault(0.0); + } +}; + +class LstmUnitGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("C")), + "Input(C@GRAD) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("H")), + "Input(H@GRAD) should not be null"); + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("X")->dims()); + ctx.Output(framework::GradVarName("C_prev")) + ->Resize(ctx.Input("C_prev")->dims()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(lstm_unit, ops::LstmUnitOp, ops::LstmUnitOpMaker, + lstm_unit_grad, ops::LstmUnitGradOp); +REGISTER_OP_CPU_KERNEL(lstm_unit, + ops::LstmUnitKernel); diff --git a/paddle/operators/lstm_unit_op.h b/paddle/operators/lstm_unit_op.h new file mode 100644 index 000000000..6e870f65e --- /dev/null +++ b/paddle/operators/lstm_unit_op.h @@ -0,0 +1,147 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using framework::LoDTensor; +using framework::Tensor; + +template +inline T sigmoid(T x) { + return 1. / (1. + exp(-x)); +} + +template +inline T tanh(T x) { + return 2. * sigmoid(2. * x) - 1.; +} + +template +class LstmUnitKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto* x_tensor = ctx.Input("X"); + auto* c_prev_tensor = ctx.Input("C_prev"); + auto* c_tensor = ctx.Output("C"); + auto* h_tensor = ctx.Output("H"); + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + int b_size = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + T* C = c_tensor->mutable_data(ctx.GetPlace()); + T* H = h_tensor->mutable_data(ctx.GetPlace()); + + const T* X = x_tensor->data(); + const T* C_prev = c_prev_tensor->data(); + + for (int n = 0; n < b_size; ++n) { + for (int d = 0; d < D; ++d) { + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = f * c_prev + i * g; + C[d] = c; + const T tanh_c = tanh(c); + H[d] = o * tanh_c; + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + } + } +}; + +template +class LstmUnitGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_ENFORCE(platform::is_cpu_place(ctx.GetPlace()), + "It must use CPUPlace."); + + auto x_tensor = ctx.Input("X"); + auto c_prev_tensor = ctx.Input("C_prev"); + auto c_tensor = ctx.Input("C"); + auto h_tensor = ctx.Input("H"); + + auto hdiff_tensor = ctx.Input(framework::GradVarName("H")); + auto cdiff_tensor = ctx.Input(framework::GradVarName("C")); + + auto xdiff_tensor = ctx.Output(framework::GradVarName("X")); + auto c_prev_diff_tensor = + ctx.Output(framework::GradVarName("C_prev")); + + auto* X = x_tensor->data(); + auto* C_prev = c_prev_tensor->data(); + auto* C = c_tensor->data(); + auto* H = h_tensor->data(); + + auto* H_diff = hdiff_tensor->data(); + auto* C_diff = cdiff_tensor->data(); + + auto* C_prev_diff = c_prev_diff_tensor->mutable_data(ctx.GetPlace()); + auto* X_diff = xdiff_tensor->mutable_data(ctx.GetPlace()); + + int N = c_tensor->dims()[0]; + int D = c_tensor->dims()[1]; + + auto forget_bias = static_cast(ctx.Attr("forget_bias")); + + for (int n = 0; n < N; ++n) { + for (int d = 0; d < D; ++d) { + T* c_prev_diff = C_prev_diff + d; + T* i_diff = X_diff + d; + T* f_diff = X_diff + 1 * D + d; + T* o_diff = X_diff + 2 * D + d; + T* g_diff = X_diff + 3 * D + d; + + const T i = sigmoid(X[d]); + const T f = sigmoid(X[1 * D + d] + forget_bias); + const T o = sigmoid(X[2 * D + d]); + const T g = tanh(X[3 * D + d]); + const T c_prev = C_prev[d]; + const T c = C[d]; + const T tanh_c = tanh(c); + const T c_term_diff = C_diff[d] + H_diff[d] * o * (1 - tanh_c * tanh_c); + *c_prev_diff = c_term_diff * f; + *i_diff = c_term_diff * g * i * (1 - i); + *f_diff = c_term_diff * c_prev * f * (1 - f); + *o_diff = H_diff[d] * tanh_c * o * (1 - o); + *g_diff = c_term_diff * i * (1 - g * g); + } + C_prev += D; + X += 4 * D; + C += D; + H += D; + C_diff += D; + H_diff += D; + X_diff += 4 * D; + C_prev_diff += D; + } + } +}; + +} // namespace operators +} // namespace paddle -- GitLab