提交 3d8b6ebc 编写于 作者: D dangqingqing

Add LSTM backward implenmentation.

上级 3f1062d7
......@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
......@@ -30,8 +29,8 @@ class LSTMOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
......@@ -44,7 +43,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same.");
}
int frame_size = x_dims[1] / 4;
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
......@@ -71,9 +70,11 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
ctx->SetOutputDim("BatchGate", x_dims);
framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
......@@ -86,7 +87,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
......@@ -110,21 +111,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"is also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("BatchCellPreAct",
"(LoDTensor) This LoDTensor is get in the forward and used "
"in the backward.")
.AsIntermediate();
AddAttr<bool>("usePeepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
......@@ -202,15 +207,28 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(Hidden@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
"Input(Cell@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
if (ctx->HasInput("Weight")) {
ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight"));
}
if (ctx->HasInput("Bias")) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
if (ctx->HasInput("H0")) {
ctx->SetOutputDim(framework::GradVarName("H0"), ctx->GetInputDim("H0"));
}
if (ctx->HasInput("C0")) {
ctx->SetOutputDim(framework::GradVarName("C0"), ctx->GetInputDim("C0"));
}
}
};
......
......@@ -21,8 +21,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
......@@ -31,15 +32,15 @@ template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* weight = ctx.Input<framework::Tensor>("Weight");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<framework::LoDTensor>("Cell");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
// Now the function ShareLoD in InferShape is not implemented.
......@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> {
bool is_reverse = ctx.Attr<bool>("isReverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
auto& device_ctx = ctx.device_context();
to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
......@@ -69,15 +71,23 @@ class LSTMKernel : public framework::OpKernel<T> {
}
math::LstmMetaValue<T> lstm_value;
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
if (bias) {
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
}
lstm_value.prevStateValue = nullptr;
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
batch_out.mutable_data<T>(dims, ctx.GetPlace());
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
auto batch_cell_pre_act = *(ctx.Output<LoDTensor>("BatchCellPreAct"));
batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
......@@ -92,7 +102,7 @@ class LSTMKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor out_t = batch_out.Slice(bstart, bend);
Tensor out_t = batch_hidden.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);
......@@ -101,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> {
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t,
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
// else if : FIXME support the initial hidden and cell
......@@ -112,27 +122,181 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_out.set_lod(batch_gate->lod());
batch_hidden.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(ctx.device_context(), batch_out, *hidden_out);
to_seq(device_ctx, batch_hidden, *hidden_out);
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(ctx.device_context(), batch_cell, *cell_out);
to_seq(device_ctx, batch_cell, *cell_out);
}
};
template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));
auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto& device_ctx = ctx.device_context();
if (weight_g) {
math::SetConstant<Place, T> zero;
zero(device_ctx, weight_g, static_cast<T>(0.0));
}
auto in_dims = input->dims();
auto out_dims = hidden_g->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
math::LstmMetaValue<T> lstm_value;
if (bias) {
T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
}
math::LstmMetaGrad<T> lstm_grad;
if (bias && bias_g) {
T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
} else {
lstm_grad.checkIgGrad = nullptr;
lstm_grad.checkFgGrad = nullptr;
lstm_grad.checkOgGrad = nullptr;
}
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here.
LoDTensor batch_hidden;
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_out, batch_hidden, false);
LoDTensor batch_hidden_g;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_g, batch_hidden_g, false);
LoDTensor batch_cell;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_out, batch_cell, false);
LoDTensor batch_cell_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_g, batch_cell_g, false);
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (int n = static_cast<int>(num_batch); n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
lstm_value.gateValue = gate.data<T>();
lstm_value.stateValue = cell.data<T>();
lstm_value.stateActiveValue = cell_pre_act.data<T>();
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
Tensor gate_g = batch_gate_g.Slice(bstart, bend);
Tensor cell_g = batch_cell_g.Slice(bstart, bend);
lstm_grad.stateGrad = cell_g.data<T>();
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();
if (n != 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else {
lstm_value.prevStateValue = nullptr;
lstm_grad.prevStateGrad = nullptr;
}
int cur_batch_size = bend - bstart;
math::LstmUnitGradFunctor<Place, T>::compute(
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &pre_hidden_g,
static_cast<T>(1.0));
if (weight_g) {
/* backward weight */
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden, true, gate_g, false,
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
}
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
if (in_g) {
/* backward data */
to_seq(device_ctx, batch_gate_g, *in_g);
}
if (bias && bias_g) {
/* backward bias */
bias_g->mutable_data<T>(ctx.GetPlace());
auto bias_g_e = EigenMatrix<T>::From(*bias_g);
auto gate_g_e = EigenMatrix<T>::From(batch_gate_g);
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto bg = bias_g_e.slice(offsets, extents)
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}));
bg.device(ctx.GetEigenDevice<Place>()) =
gate_g_e.sum(Eigen::array<int, 1>({{0}}));
}
}
};
} // namespace operators
......
......@@ -53,7 +53,17 @@ class LoDTensor2BatchFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const {
framework::LoDTensor& batch, bool is_cal_batch_lod,
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch.lod();
PADDLE_ENFORCE_EQ(lods.size(), 2UL);
PADDLE_ENFORCE_EQ(lods[1].size(), lod_tensor.dims()[1]);
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, lods[1].data(), batch, true);
return;
}
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册