Commit 552c9012 authored by Yibing Liu

Enable backward computation in lstmp_op

Parent f2c4bb67
@@ -39,21 +39,12 @@ class LSTMPOp : public framework::OperatorWithKernel {
                    "Output(BatchGate) of LSTMP should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
                    "Output(BatchGate) of LSTMP should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("BatchHidden"),
+                   "Output(BatchHidden) of LSTMP should not be null.");

     auto in_dims = ctx->GetInputDim("Input");
     PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");

-    if (ctx->HasInput("H0")) {
-      PADDLE_ENFORCE(ctx->HasInput("C0"),
-                     "Input(C0) and Input(H0) of LSTMP should not "
-                     "be null at the same time.");
-      auto h_dims = ctx->GetInputDim("H0");
-      auto c_dims = ctx->GetInputDim("C0");
-      PADDLE_ENFORCE(h_dims == c_dims,
-                     "The dimension of Input(H0) and Input(C0) "
-                     "should be the same.");
-    }
-
     int frame_size = in_dims[1] / 4;
     auto w_dims = ctx->GetInputDim("Weight");
     auto proj_dims = ctx->GetInputDim("ProjWeight");
@@ -75,6 +66,18 @@ class LSTMPOp : public framework::OperatorWithKernel {
                       "should be %d.",
                       frame_size);

+    if (ctx->HasInput("H0")) {
+      PADDLE_ENFORCE(ctx->HasInput("C0"),
+                     "Input(C0) and Input(H0) of LSTMP should not "
+                     "be null at the same time.");
+      auto h_dims = ctx->GetInputDim("H0");
+      auto c_dims = ctx->GetInputDim("C0");
+      PADDLE_ENFORCE(h_dims == c_dims,
+                     "The dimension of Input(H0) and Input(C0) "
+                     "should be the same.");
+      ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]});
+    }
+
     auto b_dims = ctx->GetInputDim("Bias");
     PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
     PADDLE_ENFORCE_EQ(b_dims[0], 1,
@@ -98,6 +101,7 @@ class LSTMPOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("Cell", out_dims);
     ctx->SetOutputDim("BatchGate", in_dims);
     ctx->SetOutputDim("BatchCellPreAct", out_dims);
+    ctx->SetOutputDim("BatchHidden", out_dims);
     ctx->ShareLoD("Input", "Projection");
     ctx->ShareLoD("Input", "Cell");
   }
@@ -169,6 +173,15 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
               "(LoDTensor) This LoDTensor is obtained in the forward and used "
               "in the backward.")
         .AsIntermediate();
+    AddOutput("BatchHidden",
+              "(LoDTensor) This LoDTensor is obtained in the forward and used "
+              "in the backward.")
+        .AsIntermediate();
+    AddOutput("OrderedP0",
+              "(Tensor) The projection of the initial hidden state "
+              "H0. This is a tensor with shape (N x P), where N is the "
+              "batch size and P is the projection size.")
+        .AsIntermediate();
     AddAttr<bool>("use_peepholes",
                   "(bool, default: True) "
                   "whether to enable diagonal/peephole connections.")
@@ -177,6 +190,12 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(bool, default: False) "
                   "whether to compute reversed LSTMP.")
         .SetDefault(false);
+    AddAttr<bool>("share_cell_act",
+                  "(bool, default: True) "
+                  "whether to share the activation with the cell output. "
+                  "If false, the projection is linear; otherwise it goes "
+                  "through the same activation as the cell output.")
+        .SetDefault(true);
     AddAttr<std::string>(
         "gate_activation",
         "(string, default: sigmoid)"
@@ -213,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
 h_t = o_t \odot act_h(c_t)
-r_t = W_{rh}h_t
+r_t = act_h'(W_{rh}h_t)
 $$

 where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix
@@ -229,7 +248,8 @@ layer.
 The $\odot$ is the element-wise product of the vectors. $act_g$ and $act_h$
 are the cell input and cell output activation functions and `tanh` is usually
-used for them.
+used for them. If `share_cell_act` is set to `False`, $act_h'$ will be linear;
+otherwise it will be the same as $act_h$.

 Note that these $W_{xi}x_{t}, W_{xf}x_{t}, W_{xc}x_{t}, W_{xo}x_{t}$
 operations on the input $x_{t}$ are NOT included in this operator.
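To make the effect of the new `share_cell_act` attribute concrete, here is a minimal numpy sketch of the projection step. It mirrors the reference implementation in the Python test further below; the function name and shapes are illustrative, not part of the operator's API:

```python
import numpy as np

def project(h_t, w_rh, act_h=np.tanh, share_cell_act=True):
    """Recurrent projection of LSTMP: r_t = act_h'(W_rh h_t).

    When share_cell_act is True, act_h' is the cell output activation;
    otherwise the projection stays linear (act_h' is the identity).
    """
    r_t = np.dot(h_t, w_rh)  # (1 x D) . (D x P) -> (1 x P)
    return act_h(r_t) if share_cell_act else r_t

# Example: D = 4 hidden units projected down to P = 2.
h_t = np.random.randn(1, 4)
w_rh = np.random.randn(4, 2)
r_t = project(h_t, w_rh)  # shares tanh with the cell output
```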
@@ -246,12 +266,14 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("Input"),
                    "Input(Input) of LSTMP should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Hidden"),
-                   "Input(Hidden) of LSTMP should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Projection"),
+                   "Input(Projection) of LSTMP should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Cell"),
                    "Input(Cell) of LSTMP should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Weight"),
                    "Input(Weight) of LSTMP should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("ProjWeight"),
+                   "Input(ProjWeight) of LSTMP should not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Bias"),
                    "Input(Bias) of LSTMP should not be null.");
@@ -268,6 +290,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
     SetOutGradDim("Input");
     SetOutGradDim("Weight");
+    SetOutGradDim("ProjWeight");
     SetOutGradDim("Bias");
     SetOutGradDim("H0");
     SetOutGradDim("C0");
......
@@ -13,18 +13,25 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
-#include "paddle/framework/op_registry.h"
+#include "paddle/operators/activation_op.h"
 #include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/operators/math/lstm_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"

+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
 namespace paddle {
 namespace operators {

 using LoDTensor = framework::LoDTensor;
 using Tensor = framework::Tensor;

+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
 template <typename DeviceContext, typename T>
 inline void ReorderInitState(const DeviceContext& ctx,
                              const framework::Tensor& src, const size_t* index,
@@ -37,6 +44,21 @@ inline void ReorderInitState(const DeviceContext& ctx,
 template <typename DeviceContext, typename T>
 class LSTMPKernel : public framework::OpKernel<T> {
  public:
+  template <typename Device, typename X, typename Y>
+  void ActCompute(const math::detail::ActivationType act_type, const Device& d,
+                  X x, Y y) const {
+    if (act_type == math::detail::ActivationType::kIdentity)
+      y.device(d) = x;
+    else if (act_type == math::detail::ActivationType::kSigmoid)
+      SigmoidFunctor<T>()(d, x, y);
+    else if (act_type == math::detail::ActivationType::kTanh)
+      TanhFunctor<T>()(d, x, y);
+    else if (act_type == math::detail::ActivationType::kReLU)
+      ReluFunctor<T>()(d, x, y);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* weight = ctx.Input<Tensor>("Weight");
@@ -44,6 +66,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
     auto* bias = ctx.Input<Tensor>("Bias");

     auto* hidden_t0 = ctx.Input<Tensor>("H0");
+    auto* ordered_proj0 = ctx.Output<Tensor>("OrderedP0");
     auto* cell_t0 = ctx.Input<Tensor>("C0");

     auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
@@ -97,12 +120,13 @@ class LSTMPKernel : public framework::OpKernel<T> {
     }

     // Use the local variable as here.
-    LoDTensor batch_hidden, batch_proj, batch_cell;
+    LoDTensor batch_proj, batch_cell;
     auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
-    batch_hidden.mutable_data<T>(dims, ctx.GetPlace());    // T x D
+    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
+
+    auto* batch_hidden = ctx.Output<LoDTensor>("BatchHidden");
+    batch_hidden->mutable_data<T>(dims, ctx.GetPlace());    // T x D
     batch_proj.mutable_data<T>(proj_dims, ctx.GetPlace());  // T x P
     batch_cell.mutable_data<T>(dims, ctx.GetPlace());       // T x D
-    batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());

     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
@@ -112,13 +136,15 @@ class LSTMPKernel : public framework::OpKernel<T> {
         ctx.Attr<std::string>("cell_activation"));
     auto cand_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
+    auto share_cell_act = ctx.Attr<bool>("share_cell_act");
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

     for (size_t n = 0; n < num_batch; n++) {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);

       Tensor gate_t = batch_gate->Slice(bstart, bend);
-      Tensor hidden_t = batch_hidden.Slice(bstart, bend);
+      Tensor hidden_t = batch_hidden->Slice(bstart, bend);
       Tensor proj_t = batch_proj.Slice(bstart, bend);
       Tensor cell_t = batch_cell.Slice(bstart, bend);
       Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
@@ -140,15 +166,19 @@ class LSTMPKernel : public framework::OpKernel<T> {
         // Since the batch computing for LSTMP reorders the input sequence
         // according to their length. The initialized hidden state also needs
         // to reorder.
-        Tensor ordered_h0, ordered_proj0;
-        ordered_proj0.Resize({1, proj_weight->dims()[1]});
-        ordered_proj0.mutable_data<T>(ctx.GetPlace());
+        Tensor ordered_h0;
+        ordered_proj0->mutable_data<T>(ctx.GetPlace());
         ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
                                            &ordered_h0, true);
         math::matmul<DeviceContext, T>(device_ctx, ordered_h0, false,
                                        *proj_weight, false, static_cast<T>(1.0),
-                                       &ordered_proj0, static_cast<T>(0.0));
-        math::matmul<DeviceContext, T>(device_ctx, ordered_proj0, false,
+                                       ordered_proj0, static_cast<T>(0.0));
+        if (share_cell_act) {
+          auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
+          ActCompute(cell_act, place, proj0_dev, proj0_dev);
+        }
+        math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, false,
                                        *weight, false, static_cast<T>(1.0),
                                        &gate_t, static_cast<T>(1.0));
       }
@@ -164,6 +194,10 @@ class LSTMPKernel : public framework::OpKernel<T> {
       math::matmul<DeviceContext, T>(device_ctx, hidden_t, false, *proj_weight,
                                      false, static_cast<T>(1.0), &proj_t,
                                      static_cast<T>(0.0));
+      if (share_cell_act) {
+        auto proj_t_dev = EigenMatrix<T>::From(proj_t);
+        ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
+      }
     }

     math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
@@ -180,9 +214,26 @@ class LSTMPKernel : public framework::OpKernel<T> {
 template <typename DeviceContext, typename T>
 class LSTMPGradKernel : public framework::OpKernel<T> {
  public:
+  template <typename Device, typename X, typename Y, typename DX, typename DY>
+  void ActGradCompute(const math::detail::ActivationType act_type,
+                      const Device& d, X x, Y y, DX dx, DY dy) const {
+    // x is dummy and won't be used even in Relu (use y instead)
+    if (act_type == math::detail::ActivationType::kIdentity)
+      dx.device(d) = dy;
+    else if (act_type == math::detail::ActivationType::kSigmoid)
+      SigmoidGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == math::detail::ActivationType::kTanh)
+      TanhGradFunctor<T>()(d, x, y, dy, dx);
+    else if (act_type == math::detail::ActivationType::kReLU)
+      ReluGradFunctor<T>()(d, x, y, dy, dx);
+    else
+      PADDLE_THROW("unsupported activation type");
+  }
+
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* input = ctx.Input<LoDTensor>("Input");
     auto* weight = ctx.Input<Tensor>("Weight");
+    auto* proj_weight = ctx.Input<Tensor>("ProjWeight");
     auto* bias = ctx.Input<Tensor>("Bias");

     auto* proj_out = ctx.Input<LoDTensor>("Projection");
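The comment above notes that `x` is a dummy argument: for the supported activations the derivative can be expressed purely in terms of the forward output `y`. A small illustrative numpy sketch of the same dispatch (not Paddle API, names are assumptions):

```python
import numpy as np

def act_grad(act_type, y, dy):
    """Return dL/dx given the forward output y and the upstream grad dy."""
    if act_type == "identity":
        return dy
    if act_type == "sigmoid":
        return dy * y * (1.0 - y)  # sigmoid'(x) = y * (1 - y)
    if act_type == "tanh":
        return dy * (1.0 - y * y)  # tanh'(x) = 1 - y^2
    if act_type == "relu":
        return dy * (y > 0)        # relu'(x) = 1 where the output is positive
    raise ValueError("unsupported activation type")

y = np.tanh(np.random.randn(3))
dy = np.ones_like(y)
dx = act_grad("tanh", y, dy)
```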
@@ -190,14 +241,19 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
     auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
     auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
+    auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden");

-    auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Projection"));
+    auto* projection_g =
+        ctx.Input<LoDTensor>(framework::GradVarName("Projection"));

     auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
     auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
+    auto* proj_weight_g =
+        ctx.Output<Tensor>(framework::GradVarName("ProjWeight"));
     auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));

     auto* h0 = ctx.Input<Tensor>("H0");
+    auto* ordered_proj0 = ctx.Input<Tensor>("OrderedP0");
     auto* c0 = ctx.Input<Tensor>("C0");

     auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
@@ -209,6 +265,10 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       weight_g->mutable_data<T>(ctx.GetPlace());
       zero(device_ctx, weight_g, static_cast<T>(0.0));
     }
+    if (proj_weight_g) {
+      proj_weight_g->mutable_data<T>(ctx.GetPlace());
+      zero(device_ctx, proj_weight_g, static_cast<T>(0.0));
+    }

     // ordered_h0/c0 is the reordered hidden/cell initialization.
     // ordered_h0_g/c0_g is the reordered gradient of hidden/cell
@@ -224,7 +284,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
     }

     auto in_dims = input->dims();
-    auto out_dims = hidden_g->dims();
+    auto out_dims = cell_out->dims();
+    framework::DDim proj_dims({in_dims[0], proj_weight->dims()[1]});
     int frame_size = static_cast<int>(in_dims[1] / 4);
     PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
@@ -267,10 +328,11 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       to_batch(ctx, src, dst, false);
     };

-    LoDTensor batch_proj, batch_proj_g, batch_cell;
-    ToBatch(device_ctx, *proj_out, out_dims, batch_proj);
-    ToBatch(device_ctx, *hidden_g, out_dims, batch_proj_g);
-    ToBatch(device_ctx, *cell_out, out_dims, batch_cell);
+    LoDTensor batch_hidden_g, batch_proj, batch_proj_g, batch_cell;
+    batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
+    ToBatch(device_ctx, *proj_out, proj_dims, batch_proj);        // T x P
+    ToBatch(device_ctx, *projection_g, proj_dims, batch_proj_g);  // T x P
+    ToBatch(device_ctx, *cell_out, out_dims, batch_cell);         // T x D

     LoDTensor batch_cell_g, batch_gate_g;
     batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
@@ -286,6 +348,8 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
         ctx.Attr<std::string>("cell_activation"));
     auto cand_act = math::detail::GetActivationType(
         ctx.Attr<std::string>("candidate_activation"));
+    auto share_cell_act = ctx.Attr<bool>("share_cell_act");
+    auto& place = *ctx.template device_context<DeviceContext>().eigen_device();

     auto batch_starts = batch_gate->lod()[0];
     size_t num_batch = batch_starts.size() - 1;
@@ -293,6 +357,19 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       int bstart = static_cast<int>(batch_starts[n]);
       int bend = static_cast<int>(batch_starts[n + 1]);

+      Tensor cur_proj = batch_proj.Slice(bstart, bend);
+      Tensor proj_g = batch_proj_g.Slice(bstart, bend);
+      if (share_cell_act) {
+        auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
+        auto proj_g_dev = EigenMatrix<T>::From(proj_g);
+        ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
+                       proj_g_dev);
+      }
+      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
+      math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
+                                     true, static_cast<T>(1.0), &out_g,
+                                     static_cast<T>(0.0));
+
       Tensor gate = batch_gate->Slice(bstart, bend);
       Tensor cell = batch_cell.Slice(bstart, bend);
       Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
@@ -300,7 +377,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
       lstmp_value.state_value = cell.data<T>();
       lstmp_value.state_active_value = cell_pre_act.data<T>();

-      Tensor out_g = batch_proj_g.Slice(bstart, bend);
       Tensor gate_g = batch_gate_g.Slice(bstart, bend);
       Tensor cell_g = batch_cell_g.Slice(bstart, bend);
       lstmp_grad.state_grad = cell_g.data<T>();
@@ -337,19 +413,48 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                                        false, static_cast<T>(1.0), weight_g,
                                        static_cast<T>(1.0));
       }
+      if (proj_weight_g) {
+        /* backward proj weight */
+        Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+        math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
+                                       false, static_cast<T>(1.0),
+                                       proj_weight_g, static_cast<T>(1.0));
+      }
     } else {
       if (h0 && weight_g) {
         ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
                                            &ordered_h0, true);
-        math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true, gate_g,
-                                       false, static_cast<T>(1.0), weight_g,
-                                       static_cast<T>(1.0));
+        if (weight_g) {
+          math::matmul<DeviceContext, T>(device_ctx, *ordered_proj0, true,
+                                         gate_g, false, static_cast<T>(1.0),
+                                         weight_g, static_cast<T>(1.0));
+        }
       }
-      if (h0 && h0_g) {
+      if (h0 && (h0_g || proj_weight_g)) {
         ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
+        Tensor proj0_g;
+        proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
+        proj0_g.mutable_data<T>(ctx.GetPlace());
         math::matmul<DeviceContext, T>(device_ctx, gate_g, false, *weight,
-                                       true, static_cast<T>(1.0),
-                                       &ordered_h0_g, static_cast<T>(0.0));
+                                       true, static_cast<T>(1.0), &proj0_g,
+                                       static_cast<T>(0.0));
+        if (share_cell_act) {
+          auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
+          auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
+          ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
+                         proj0_g_dev);
+        }
+        // Tensor proj0_g = proj_g.Slice(bstart, bend);
+        if (h0_g) {
+          math::matmul<DeviceContext, T>(
+              device_ctx, proj0_g, false, *proj_weight, true,
+              static_cast<T>(1.0), &ordered_h0_g, static_cast<T>(0.0));
+        }
+        if (proj_weight_g) {
+          math::matmul<DeviceContext, T>(device_ctx, ordered_h0, true,
+                                         proj0_g, false, static_cast<T>(1.0),
+                                         proj_weight_g, static_cast<T>(1.0));
+        }
       }
     }
   }
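The backward matmuls above follow the standard rule for a linear projection r = h * W_rh: the hidden-state gradient is d_r * W_rh^T and the projection-weight gradient is h^T * d_r, with the activation gradient applied to d_r first when `share_cell_act` is on. A hedged numpy sketch with illustrative variable names (not part of the operator code):

```python
import numpy as np

def projection_backward(h, w_rh, d_r):
    """Backward of r = h.dot(w_rh) for one batch slice.

    h:    (T x D) batched hidden states
    w_rh: (D x P) projection weight
    d_r:  (T x P) gradient w.r.t. the projection output
    """
    d_h = d_r.dot(w_rh.T)   # analogous to matmul(proj_g, ProjWeight^T) -> out_g
    d_w_rh = h.T.dot(d_r)   # analogous to matmul(hidden_t^T, proj_g) -> proj_weight_g
    return d_h, d_w_rh

h = np.random.randn(5, 4)
w_rh = np.random.randn(4, 2)
d_r = np.random.randn(5, 2)
d_h, d_w_rh = projection_backward(h, w_rh, d_r)
```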
......
@@ -62,7 +62,8 @@ def lstmp(
         is_reverse=False,
         act_gate=None,
         act_cell=None,
-        act_cand=None):
+        act_cand=None,
+        share_cell_act=True):
    def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand):
        g = np.dot(r_pre, w_r)  # 1 x 4D
        g = g + x
@@ -85,6 +86,8 @@ def lstmp(
        h = g_o * act_cell(c)
        # projection
        r = np.dot(h, w_rh)
+        if share_cell_act:
+            r = act_cell(r)
        return r, c

    def _reverse(x, lod):
@@ -107,6 +110,8 @@ def lstmp(
        seq_len = offset[i + 1] - offset[i]
        x = input[offset[i]:offset[i + 1], :]
        r_pre = np.dot(h0[i], w_rh)  # 1 x P
+        if share_cell_act:
+            r_pre = act_cell(r_pre)
        c_pre = c0[i]  # 1 x D
        for j in range(seq_len):
            # compute one step
@@ -138,6 +143,7 @@ class TestLstmOp(OpTest):
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

+        self.share_cell_act = True
        self.has_initial_state = False
        self.is_reverse = False
        self.use_peepholes = True
@@ -167,7 +173,7 @@ class TestLstmOp(OpTest):
        w_rh = np.random.normal(size=(self.D, self.P)).astype('float64')
        r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse,
                     ACTVATION[self.act_gate], ACTVATION[self.act_cell],
-                     ACTVATION[self.act_cand])
+                     ACTVATION[self.act_cand], self.share_cell_act)

        self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh}
@@ -192,28 +198,30 @@ class TestLstmOp(OpTest):
    def test_check_output(self):
        self.check_output(atol=1e-8)

-    """
    def test_check_grad(self):
        # TODO(qingqing) remove following lines after the check_grad is refined.
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
-    """
-    """
+            ['Input', 'Weight', 'Bias'], ['Projection'],
+            max_relative_error=5e-3)


 class TestLstmOpHasInitial(TestLstmOp):
    def set_argument(self):
        self.lod = [[0, 2, 5, 7]]
        self.D = 16
+        self.P = 5

        self.act_gate = 'sigmoid'
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

+        self.share_cell_act = True
        self.has_initial_state = True
        self.is_reverse = True
        self.use_peepholes = True
@@ -221,63 +229,74 @@ class TestLstmOpHasInitial(TestLstmOp):
    def test_check_grad(self):
        # TODO(qingqing) remove following lines after the check_grad is refined.
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
-            max_relative_error=5e-4)
+            ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'],
+            max_relative_error=5e-3)

    def test_check_grad_ingore_bias(self):
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Input', 'Weight'], ['Hidden'],
-            max_relative_error=5e-4,
+            ['Input', 'Weight'], ['Projection'],
+            max_relative_error=5e-3,
            no_grad_set=set('Bias'))

    def test_check_grad_ingore_weight(self):
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Input', 'Bias'], ['Hidden'],
-            max_relative_error=5e-4,
+            ['Input', 'Bias'], ['Projection'],
+            max_relative_error=5e-3,
            no_grad_set=set('Weight'))

    def test_check_grad_ingore_input(self):
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Weight', 'Bias'], ['Hidden'],
-            max_relative_error=5e-4,
+            ['Weight', 'Bias'], ['Projection'],
+            max_relative_error=5e-3,
            no_grad_set=set('Input'))

    def test_check_grad_ingore_h0(self):
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
-            max_relative_error=5e-4,
+            ['Input', 'Weight', 'Bias', 'C0'], ['Projection'],
+            max_relative_error=5e-3,
            no_grad_set=set('H0'))

    def test_check_grad_ingore_c0(self):
        N = len(self.lod[0]) - 1
+        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-            ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
-            max_relative_error=5e-4,
+            ['Input', 'Weight', 'Bias', 'H0'], ['Projection'],
+            max_relative_error=5e-3,
            no_grad_set=set('C0'))
-"""


 class TestLstmOpRerverse(TestLstmOp):
@@ -290,6 +309,7 @@ class TestLstmOpRerverse(TestLstmOp):
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

+        self.share_cell_act = True
        self.has_initial_state = False
        self.is_reverse = True
        self.use_peepholes = True
@@ -305,6 +325,7 @@ class TestLstmOpNotUsePeepholes(TestLstmOp):
        self.act_cell = 'tanh'
        self.act_cand = 'tanh'

+        self.share_cell_act = True
        self.has_initial_state = False
        self.is_reverse = True
        self.use_peepholes = False
......