提交 3d8b6ebc 编写于 作者: D dangqingqing

Add LSTM backward implenmentation.

上级 3f1062d7
...@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null."); "Input(Input) of LSTM should not be null.");
...@@ -30,8 +29,8 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -30,8 +29,8 @@ class LSTMOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null."); "Output(Cell) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("Input"); auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2."); PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) { if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE(ctx->HasInput("C0"),
...@@ -44,7 +43,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -44,7 +43,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same."); "should be the same.");
} }
int frame_size = x_dims[1] / 4; int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight"); auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2, PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2."); "The rank of Input(Weight) should be 2.");
...@@ -71,9 +70,11 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -71,9 +70,11 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection", "4 * %d if disable peepholes connection",
frame_size); frame_size);
} }
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size}); framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Cell", {x_dims[0], frame_size}); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("BatchGate", x_dims); ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->ShareLoD("Input", "Hidden"); ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell"); ctx->ShareLoD("Input", "Cell");
} }
...@@ -86,7 +87,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -86,7 +87,7 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input", AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support " "(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in " "variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the " "this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size."); "total time steps in this mini-batch, D is the hidden size.");
AddInput("H0", AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional " "(Tensor, optional) the initial hidden state is an optional "
...@@ -110,21 +111,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -110,21 +111,25 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"2. `usePeepholes = True` " "2. `usePeepholes = True` "
" - The shape is (1 x 7D). " " - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}."); " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("BatchGate", AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate " "(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This " "and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which " "LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first " "is also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the " "LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence " "indexes, which denote the position of reorganized sequence "
"in the raw input.") "in the raw input.")
.AsIntermediate(); .AsIntermediate();
AddOutput("Hidden", AddOutput("BatchCellPreAct",
"(LoDTensor) the hidden state lod tensor of LSTM operator. " "(LoDTensor) This LoDTensor is get in the forward and used "
"The shape and lod is the same with the `Input`."); "in the backward.")
AddOutput("Cell", .AsIntermediate();
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddAttr<bool>("usePeepholes", AddAttr<bool>("usePeepholes",
"(bool, defalut: True) " "(bool, defalut: True) "
"whether to enable diagonal/peephole connections.") "whether to enable diagonal/peephole connections.")
...@@ -202,15 +207,28 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -202,15 +207,28 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(Hidden@GRAD) should not be null"); "Input(Hidden@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
"Input(Cell@GRAD) should not be null"); "Input(Cell@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
if (ctx->HasInput("Weight")) {
ctx->SetOutputDim(framework::GradVarName("Weight"), ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight")); ctx->GetInputDim("Weight"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias")); }
if (ctx->HasInput("Bias")) {
ctx->SetOutputDim(framework::GradVarName("Bias"),
ctx->GetInputDim("Bias"));
}
if (ctx->HasInput("H0")) {
ctx->SetOutputDim(framework::GradVarName("H0"), ctx->GetInputDim("H0"));
}
if (ctx->HasInput("C0")) {
ctx->SetOutputDim(framework::GradVarName("C0"), ctx->GetInputDim("C0"));
}
} }
}; };
......
...@@ -21,8 +21,9 @@ limitations under the License. */ ...@@ -21,8 +21,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using framework::Tensor; using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
...@@ -31,15 +32,15 @@ template <typename Place, typename T> ...@@ -31,15 +32,15 @@ template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> { class LSTMKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input"); auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<framework::Tensor>("Weight"); auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<framework::Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate"); auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace()); batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden"); auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace()); hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<framework::LoDTensor>("Cell"); auto* cell_out = ctx.Output<LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace()); cell_out->mutable_data<T>(ctx.GetPlace());
// Now the function ShareLoD in InferShape is not implemented. // Now the function ShareLoD in InferShape is not implemented.
...@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> {
bool is_reverse = ctx.Attr<bool>("isReverse"); bool is_reverse = ctx.Attr<bool>("isReverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch; math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse); auto& device_ctx = ctx.device_context();
to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
auto in_dims = input->dims(); auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4); int frame_size = static_cast<int>(in_dims[1] / 4);
...@@ -69,15 +71,23 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -69,15 +71,23 @@ class LSTMKernel : public framework::OpKernel<T> {
} }
math::LstmMetaValue<T> lstm_value; math::LstmMetaValue<T> lstm_value;
if (bias) {
T* bias_data = const_cast<T*>(bias->data<T>()); T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later. // the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size; lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size; lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size; lstm_value.checkOg = lstm_value.checkFg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
}
lstm_value.prevStateValue = nullptr; lstm_value.prevStateValue = nullptr;
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act; // Use the local variable as here.
batch_out.mutable_data<T>(dims, ctx.GetPlace()); LoDTensor batch_hidden, batch_cell;
auto batch_cell_pre_act = *(ctx.Output<LoDTensor>("BatchCellPreAct"));
batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace()); batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace()); batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
...@@ -92,7 +102,7 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -92,7 +102,7 @@ class LSTMKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]); int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice(bstart, bend); Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor out_t = batch_out.Slice(bstart, bend); Tensor out_t = batch_hidden.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend); Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend); Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);
...@@ -101,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -101,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> {
if (n != 0) { if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]); int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size; int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end); auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false, math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
*weight, false, static_cast<T>(1.0), &gate_t, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0)); static_cast<T>(1.0));
} }
// else if : FIXME support the initial hidden and cell // else if : FIXME support the initial hidden and cell
...@@ -112,27 +122,181 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -112,27 +122,181 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.outputValue = out_t.data<T>(); lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>(); lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>(); lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value, math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
frame_size, cur_batch_size, frame_size, cur_batch_size,
gate_act, cell_act, cand_act); gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue; lstm_value.prevStateValue = lstm_value.stateValue;
} }
math::Batch2LoDTensorFunctor<Place, T> to_seq; math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_out.set_lod(batch_gate->lod()); batch_hidden.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden // restore the output hidden in LoDTensor from the batch hidden
to_seq(ctx.device_context(), batch_out, *hidden_out); to_seq(device_ctx, batch_hidden, *hidden_out);
batch_cell.set_lod(batch_gate->lod()); batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell // restore the output cell state in LoDTensor from the batch cell
to_seq(ctx.device_context(), batch_cell, *cell_out); to_seq(device_ctx, batch_cell, *cell_out);
} }
}; };
template <typename Place, typename T> template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> { class LSTMGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override {} void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
auto* cell_g = ctx.Input<LoDTensor>(framework::GradVarName("Cell"));
auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto& device_ctx = ctx.device_context();
if (weight_g) {
math::SetConstant<Place, T> zero;
zero(device_ctx, weight_g, static_cast<T>(0.0));
}
auto in_dims = input->dims();
auto out_dims = hidden_g->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
math::LstmMetaValue<T> lstm_value;
if (bias) {
T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
}
math::LstmMetaGrad<T> lstm_grad;
if (bias && bias_g) {
T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
} else {
lstm_grad.checkIgGrad = nullptr;
lstm_grad.checkFgGrad = nullptr;
lstm_grad.checkOgGrad = nullptr;
}
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here.
LoDTensor batch_hidden;
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_out, batch_hidden, false);
LoDTensor batch_hidden_g;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_g, batch_hidden_g, false);
LoDTensor batch_cell;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_out, batch_cell, false);
LoDTensor batch_cell_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_g, batch_cell_g, false);
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (int n = static_cast<int>(num_batch); n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
lstm_value.gateValue = gate.data<T>();
lstm_value.stateValue = cell.data<T>();
lstm_value.stateActiveValue = cell_pre_act.data<T>();
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
Tensor gate_g = batch_gate_g.Slice(bstart, bend);
Tensor cell_g = batch_cell_g.Slice(bstart, bend);
lstm_grad.stateGrad = cell_g.data<T>();
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();
if (n != 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else {
lstm_value.prevStateValue = nullptr;
lstm_grad.prevStateGrad = nullptr;
}
int cur_batch_size = bend - bstart;
math::LstmUnitGradFunctor<Place, T>::compute(
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &pre_hidden_g,
static_cast<T>(1.0));
if (weight_g) {
/* backward weight */
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden, true, gate_g, false,
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
}
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
if (in_g) {
/* backward data */
to_seq(device_ctx, batch_gate_g, *in_g);
}
if (bias && bias_g) {
/* backward bias */
bias_g->mutable_data<T>(ctx.GetPlace());
auto bias_g_e = EigenMatrix<T>::From(*bias_g);
auto gate_g_e = EigenMatrix<T>::From(batch_gate_g);
Eigen::array<int, 2> extents({{1, 4 * frame_size}});
Eigen::array<int, 2> offsets({{0, 0}});
auto bg = bias_g_e.slice(offsets, extents)
.reshape(Eigen::array<int, 2>({{1, frame_size * 4}}));
bg.device(ctx.GetEigenDevice<Place>()) =
gate_g_e.sum(Eigen::array<int, 1>({{0}}));
}
}
}; };
} // namespace operators } // namespace operators
......
...@@ -53,7 +53,17 @@ class LoDTensor2BatchFunctor { ...@@ -53,7 +53,17 @@ class LoDTensor2BatchFunctor {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor, const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const { framework::LoDTensor& batch, bool is_cal_batch_lod,
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch.lod();
PADDLE_ENFORCE_EQ(lods.size(), 2UL);
PADDLE_ENFORCE_EQ(lods[1].size(), lod_tensor.dims()[1]);
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, lods[1].data(), batch, true);
return;
}
auto lods = lod_tensor.lod(); auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now."); PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0]; auto lod = lods[0];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册