未验证 提交 41d05339 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #5429 from qingqing01/lstm_fix

Enable hidden/cell state initialization and enhance unit testing in LSTM operator.
...@@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null."); "Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"), PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null."); "Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"), PADDLE_ENFORCE(ctx->HasOutput("Cell"),
...@@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel {
"The second dimension of Input(Weight) " "The second dimension of Input(Weight) "
"should be 4 * %d.", "should be 4 * %d.",
frame_size); frame_size);
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2."); PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1, PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1."); "The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("usePeepholes")) {
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size, PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be " "The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection", "7 * %d if enable peepholes connection",
...@@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection", "4 * %d if disable peepholes connection",
frame_size); frame_size);
} }
framework::DDim out_dims({in_dims[0], frame_size}); framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims); ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims); ctx->SetOutputDim("Cell", out_dims);
...@@ -118,14 +126,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -118,14 +126,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Bias", AddInput("Bias",
"(Tensor) the learnable weights, which contains two parts: " "(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if " "input-hidden bias weight and peephole connections weight if "
"setting `usePeepholes` True. " "setting `use_peepholes` True. "
"1. `usePeepholes = False` " "1. `use_peepholes = False` "
" - The shape is (1 x 4D). " " - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}." " - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` " "2. `use_peepholes = True` "
" - The shape is (1 x 7D). " " - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.") " - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
.AsDispensable();
AddOutput("Hidden", AddOutput("Hidden",
"(LoDTensor) the hidden state of LSTM operator. " "(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`."); "The shape is (T x D), and lod is the same with the `Input`.");
...@@ -145,29 +152,32 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -145,29 +152,32 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) This LoDTensor is obtained in the forward and used " "(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.") "in the backward.")
.AsIntermediate(); .AsIntermediate();
AddAttr<bool>("usePeepholes", AddAttr<bool>("use_peepholes",
"(bool, default True) " "(bool, defalut: True) "
"whether to enable diagonal/peephole connections.") "whether to enable diagonal/peephole connections.")
.SetDefault(true); .SetDefault(true);
AddAttr<bool>("isReverse", AddAttr<bool>("is_reverse",
"(bool, default False) " "(bool, defalut: False) "
"whether to compute reversed LSTM.") "whether to compute reversed LSTM.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::string>( AddAttr<std::string>(
"gateActivation", "gate_activation",
"(string, default sigmoid)" "(string, default: sigmoid)"
"The activation for input gate, forget gate and output " "The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.") "gate, `sigmoid` by default.")
.SetDefault("sigmoid"); .SetDefault("sigmoid")
AddAttr<std::string>("cellActivation", .InEnum({"sigmoid", "tanh", "relu", "identity"});
"(string, default tanh)" AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.") "The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh"); .SetDefault("tanh")
AddAttr<std::string>("candidateActivation", .InEnum({"sigmoid", "tanh", "relu", "identity"});
"(string, default tanh)" AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, " "The activation for candidate hidden state, "
"`tanh` by default.") "`tanh` by default.")
.SetDefault("tanh"); .SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC( AddComment(R"DOC(
Long-Short Term Memory (LSTM) Operator. Long-Short Term Memory (LSTM) Operator.
...@@ -203,7 +213,7 @@ are the cell input and cell output activation functions and `tanh` is usually ...@@ -203,7 +213,7 @@ are the cell input and cell output activation functions and `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state, used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state. which is computed based on the current input and the previous hidden state.
Set usePeepholes False to disable peephole connection Set `use_peepholes` False to disable peephole connection
(http://www.bioinf.jku.at/publications/older/2604.pdf). The formula (http://www.bioinf.jku.at/publications/older/2604.pdf). The formula
is omitted here. is omitted here.
...@@ -226,23 +236,27 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -226,23 +236,27 @@ class LSTMGradOp : public framework::OperatorWithKernel {
"Input(Hidden) of LSTM should not be null."); "Input(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"), PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTM should not be null."); "Input(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"), PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null."); "Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"), PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null."); "Input(BatchGate) of LSTM should not be null.");
auto in_g_name = framework::GradVarName("Input"); auto SetOutGradDim = [&ctx](const std::string& name) {
if (ctx->HasOutput(in_g_name)) auto g_name = framework::GradVarName(name);
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input")); if (ctx->HasOutput(g_name))
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
auto w_g_name = framework::GradVarName("Weight"); };
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight")); SetOutGradDim("Input");
SetOutGradDim("Weight");
auto b_g_name = framework::GradVarName("Bias"); SetOutGradDim("Bias");
if (ctx->HasOutput(b_g_name)) SetOutGradDim("H0");
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias")); SetOutGradDim("C0");
} }
protected: protected:
......
...@@ -28,6 +28,15 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -28,6 +28,15 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// Reorders the rows of an initial-state tensor (e.g. H0 / C0).
//
// `dst` is first allocated with the same dims as `src` on the place reported
// by `ctx`; the row copy itself is delegated to math::CopyMatrixRowsFunctor,
// which receives `index` and `indexed_src` unchanged.
// NOTE(review): `indexed_src` is forwarded as the functor's `is_src_index`
// argument -- presumably true means `index` selects rows of `src` and false
// means it selects rows of `dst`; confirm against CopyMatrixRowsFunctor.
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
  // The destination buffer must exist before any row is written into it.
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  math::CopyMatrixRowsFunctor<Place, T> copy_rows;
  copy_rows(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T> template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> { class LSTMKernel : public framework::OpKernel<T> {
public: public:
...@@ -36,6 +45,9 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -36,6 +45,9 @@ class LSTMKernel : public framework::OpKernel<T> {
auto* weight = ctx.Input<Tensor>("Weight"); auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate"); auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace()); batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<LoDTensor>("Hidden"); auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
...@@ -43,12 +55,7 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -43,12 +55,7 @@ class LSTMKernel : public framework::OpKernel<T> {
auto* cell_out = ctx.Output<LoDTensor>("Cell"); auto* cell_out = ctx.Output<LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace()); cell_out->mutable_data<T>(ctx.GetPlace());
// Now the function ShareLoD in InferShape is not implemented. bool is_reverse = ctx.Attr<bool>("is_reverse");
// So copy LoD here.
ctx.ShareLoD("Input", "Hidden");
ctx.ShareLoD("Input", "Cell");
bool is_reverse = ctx.Attr<bool>("isReverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch; math::LoDTensor2BatchFunctor<Place, T> to_batch;
auto& device_ctx = ctx.device_context(); auto& device_ctx = ctx.device_context();
to_batch(device_ctx, *input, *batch_gate, true, is_reverse); to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
...@@ -71,7 +78,7 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -71,7 +78,7 @@ class LSTMKernel : public framework::OpKernel<T> {
} }
math::LstmMetaValue<T> lstm_value; math::LstmMetaValue<T> lstm_value;
if (bias) { if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>()); T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later. // the code style in LstmMetaValue will be updated later.
...@@ -84,6 +91,16 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -84,6 +91,16 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.checkOg = nullptr; lstm_value.checkOg = nullptr;
} }
lstm_value.prevStateValue = nullptr; lstm_value.prevStateValue = nullptr;
Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) {
// Since the batch computation for LSTM reorders the input sequences
// according to their lengths, the initial cell state also needs
// to be reordered.
ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prevStateValue = ordered_c0.data<T>();
}
// Use the local variable as here. // Use the local variable as here.
LoDTensor batch_hidden, batch_cell; LoDTensor batch_hidden, batch_cell;
...@@ -94,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -94,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> {
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation"); auto gate_act = ctx.Attr<std::string>("gate_activation");
auto cell_act = ctx.Attr<std::string>("cellActivation"); auto cell_act = ctx.Attr<std::string>("cell_activation");
auto cand_act = ctx.Attr<std::string>("candidateActivation"); auto cand_act = ctx.Attr<std::string>("candidate_activation");
for (size_t n = 0; n < num_batch; n++) { for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]); int bstart = static_cast<int>(batch_starts[n]);
...@@ -109,15 +126,28 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -109,15 +126,28 @@ class LSTMKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
if (n != 0) { if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]); int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size; int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end); auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false, math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
static_cast<T>(1.0), &gate_t, static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0)); static_cast<T>(1.0));
} else if (hidden_t0) {
// If n == 0 and there is no initial hidden state, i.e. H0 is all zeros,
// the calculation W_h * H0 is skipped.
// If n == 0 and there is an initial hidden state, calculate W_h * H0.
// Since the batch computation for LSTM reorders the input sequences
// according to their lengths, the initial hidden state also needs
// to be reordered.
Tensor ordered_h0;
ReorderInitState<Place, T>(device_ctx, *hidden_t0, order, &ordered_h0,
true);
math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
} }
// else if : FIXME support the initial hidden and cell
lstm_value.gateValue = gate_t.data<T>(); lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>(); lstm_value.outputValue = out_t.data<T>();
...@@ -160,6 +190,12 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -160,6 +190,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight")); auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));
auto& device_ctx = ctx.device_context(); auto& device_ctx = ctx.device_context();
math::SetConstant<Place, T> zero; math::SetConstant<Place, T> zero;
if (weight_g) { if (weight_g) {
...@@ -167,13 +203,25 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -167,13 +203,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
zero(device_ctx, weight_g, static_cast<T>(0.0)); zero(device_ctx, weight_g, static_cast<T>(0.0));
} }
// ordered_h0/c0 is the reordered hidden/cell initialization.
// ordered_h0_g/c0_g is the reordered gradient of hidden/cell
// initialization.
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
const size_t* order = batch_gate->lod()[2].data();
if (c0) {
ReorderInitState<Place, T>(device_ctx, *c0, order, &ordered_c0, true);
}
if (c0 && c0_g) {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
}
auto in_dims = input->dims(); auto in_dims = input->dims();
auto out_dims = hidden_g->dims(); auto out_dims = hidden_g->dims();
int frame_size = static_cast<int>(in_dims[1] / 4); int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]); PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
math::LstmMetaValue<T> lstm_value; math::LstmMetaValue<T> lstm_value;
if (bias) { if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>()); T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size; lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size; lstm_value.checkFg = lstm_value.checkIg + frame_size;
...@@ -185,9 +233,13 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -185,9 +233,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
} }
math::LstmMetaGrad<T> lstm_grad; math::LstmMetaGrad<T> lstm_grad;
if (bias && bias_g) { if (bias && bias_g) {
T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace())); bias_g->mutable_data<T>(ctx.GetPlace());
zero(device_ctx, bias_g, static_cast<T>(0.0)); zero(device_ctx, bias_g, static_cast<T>(0.0));
}
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
T* bias_g_data = bias_g->data<T>();
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size; lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size; lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size; lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
...@@ -199,36 +251,30 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -199,36 +251,30 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LoDTensor2BatchFunctor<Place, T> to_batch; math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here. auto ToBatch = [&batch_gate, &to_batch](
LoDTensor batch_hidden; const platform::DeviceContext& ctx, const framework::LoDTensor& src,
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace()); const framework::DDim& dims, framework::LoDTensor& dst) {
batch_hidden.set_lod(batch_gate->lod()); dst.mutable_data<T>(dims, ctx.GetPlace());
to_batch(device_ctx, *hidden_out, batch_hidden, false); dst.set_lod(batch_gate->lod());
to_batch(ctx, src, dst, false);
};
LoDTensor batch_hidden_g; LoDTensor batch_hidden, batch_hidden_g, batch_cell;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace()); ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
batch_hidden_g.set_lod(batch_gate->lod()); ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
to_batch(device_ctx, *hidden_g, batch_hidden_g, false); ToBatch(device_ctx, *cell_out, out_dims, batch_cell);
LoDTensor batch_cell; LoDTensor batch_cell_g, batch_gate_g;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_out, batch_cell, false);
LoDTensor batch_cell_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace()); batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
// TODO(qingqing) support the case output cell has gradient. // TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false); // to_batch(device_ctx, *cell_g, batch_cell_g, false);
zero(device_ctx, &batch_cell_g, static_cast<T>(0.0)); zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace()); batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod()); batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gateActivation"); auto gate_act = ctx.Attr<std::string>("gate_activation");
auto cell_act = ctx.Attr<std::string>("cellActivation"); auto cell_act = ctx.Attr<std::string>("cell_activation");
auto cand_act = ctx.Attr<std::string>("candidateActivation"); auto cand_act = ctx.Attr<std::string>("candidate_activation");
auto batch_starts = batch_gate->lod()[0]; auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1; size_t num_batch = batch_starts.size() - 1;
...@@ -250,15 +296,15 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -250,15 +296,15 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_grad.gateGrad = gate_g.data<T>(); lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>(); lstm_grad.outputGrad = out_g.data<T>();
if (n) { if (n > 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]); int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart); Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart); Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>(); lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>(); lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else { } else {
lstm_value.prevStateValue = nullptr; lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prevStateGrad = nullptr; lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
} }
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
...@@ -266,7 +312,7 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -266,7 +312,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act); gate_act, cell_act, cand_act);
if (n != 0) { if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]); int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size; int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end); auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
...@@ -280,6 +326,19 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -280,6 +326,19 @@ class LSTMGradKernel : public framework::OpKernel<T> {
static_cast<T>(1.0), weight_g, static_cast<T>(1.0), weight_g,
static_cast<T>(1.0)); static_cast<T>(1.0));
} }
} else {
if (h0 && weight_g) {
ReorderInitState<Place, T>(device_ctx, *h0, order, &ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false,
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
if (h0 && h0_g) {
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &ordered_h0_g,
static_cast<T>(0.0));
}
} }
} }
...@@ -302,6 +361,13 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -302,6 +361,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(), math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(),
ones.data<T>(), 0., bias_g->data<T>()); ones.data<T>(), 0., bias_g->data<T>());
} }
if (h0 && h0_g) {
ReorderInitState<Place, T>(device_ctx, ordered_h0_g, order, h0_g, false);
}
if (c0 && c0_g) {
ReorderInitState<Place, T>(device_ctx, ordered_c0_g, order, c0_g, false);
}
} }
}; };
......
...@@ -52,9 +52,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -52,9 +52,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
rValueIg = valueIg[i]; rValueIg = valueIg[i];
rValueFg = valueFg[i]; rValueFg = valueFg[i];
rValueOg = valueOg[i]; rValueOg = valueOg[i];
rCheckI = value.checkIg[i]; rCheckI = value.checkIg ? value.checkIg[i] : 0;
rCheckF = value.checkFg[i]; rCheckF = value.checkFg ? value.checkFg[i] : 0;
rCheckO = value.checkOg[i]; rCheckO = value.checkOg ? value.checkOg[i] : 0;
if (value.prevStateValue) { if (value.prevStateValue) {
rPrevState = value.prevStateValue[i]; rPrevState = value.prevStateValue[i];
...@@ -114,9 +114,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -114,9 +114,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rValueIg = valueIg[i]; rValueIg = valueIg[i];
rValueFg = valueFg[i]; rValueFg = valueFg[i];
rValueOg = valueOg[i]; rValueOg = valueOg[i];
rCheckI = value.checkIg[i]; rCheckI = value.checkIg ? value.checkIg[i] : 0;
rCheckF = value.checkFg[i]; rCheckF = value.checkFg ? value.checkFg[i] : 0;
rCheckO = value.checkOg[i]; rCheckO = value.checkOg ? value.checkOg[i] : 0;
rState = value.stateValue[i]; rState = value.stateValue[i];
rStateAtv = value.stateActiveValue[i]; rStateAtv = value.stateActiveValue[i];
rOutputGrad = grad.outputGrad[i]; rOutputGrad = grad.outputGrad[i];
...@@ -155,9 +155,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -155,9 +155,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
__m256 rValueIg; __m256 rValueIg;
__m256 rValueFg; __m256 rValueFg;
__m256 rValueOg; __m256 rValueOg;
__m256 rCheckI; __m256 rCheckI = _mm256_set1_ps(0.0f);
__m256 rCheckF; __m256 rCheckF = _mm256_set1_ps(0.0f);
__m256 rCheckO; __m256 rCheckO = _mm256_set1_ps(0.0f);
__m256 rState; __m256 rState;
__m256 rPrevState = _mm256_set1_ps(0.0f); __m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rStateAtv; __m256 rStateAtv;
...@@ -173,9 +173,11 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -173,9 +173,11 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
rValueIg = valueIg[i]; rValueIg = valueIg[i];
rValueFg = valueFg[i]; rValueFg = valueFg[i];
rValueOg = valueOg[i]; rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i]; if (value.checkIg) {
rCheckF = ((__m256 *)value.checkFg)[i]; rCheckI = ((__m256 *)value.checkIg)[i];
rCheckO = ((__m256 *)value.checkOg)[i]; rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
}
if (value.prevStateValue) { if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i]; rPrevState = ((__m256 *)value.prevStateValue)[i];
...@@ -216,9 +218,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -216,9 +218,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
__m256 rState; __m256 rState;
__m256 rStateAtv; __m256 rStateAtv;
__m256 rOutputGrad; __m256 rOutputGrad;
__m256 rCheckI; __m256 rCheckI = _mm256_set1_ps(0.0f);
__m256 rCheckF; __m256 rCheckF = _mm256_set1_ps(0.0f);
__m256 rCheckO; __m256 rCheckO = _mm256_set1_ps(0.0f);
__m256 rCheckIGrad; __m256 rCheckIGrad;
__m256 rCheckFGrad; __m256 rCheckFGrad;
__m256 rCheckOGrad; __m256 rCheckOGrad;
...@@ -237,9 +239,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -237,9 +239,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rValueIg = valueIg[i]; rValueIg = valueIg[i];
rValueFg = valueFg[i]; rValueFg = valueFg[i];
rValueOg = valueOg[i]; rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i]; if (value.checkIg) {
rCheckF = ((__m256 *)value.checkFg)[i]; rCheckI = ((__m256 *)value.checkIg)[i];
rCheckO = ((__m256 *)value.checkOg)[i]; rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
}
rState = ((__m256 *)value.stateValue)[i]; rState = ((__m256 *)value.stateValue)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i]; rStateAtv = ((__m256 *)value.stateActiveValue)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i]; rOutputGrad = ((__m256 *)grad.outputGrad)[i];
......
...@@ -55,9 +55,10 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize, ...@@ -55,9 +55,10 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
T rValueIg; T rValueIg;
T rValueFg; T rValueFg;
T rValueOg; T rValueOg;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx]; T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
T rCheckO = value.checkOg[frameIdx]; T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
rValueIn = value.gateValue[frameIdx]; rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize]; rValueIg = value.gateValue[frameIdx + frameSize];
...@@ -121,9 +122,10 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -121,9 +122,10 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
T rStateGrad; T rStateGrad;
T rStateAtv; T rStateAtv;
T rOutputGrad; T rOutputGrad;
T rCheckI = value.checkIg[frameIdx]; T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
T rCheckF = value.checkFg[frameIdx]; T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
T rCheckO = value.checkOg[frameIdx]; T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
T rCheckIGrad; T rCheckIGrad;
T rCheckFGrad; T rCheckFGrad;
T rCheckOGrad; T rCheckOGrad;
......
...@@ -22,8 +22,8 @@ template <typename T> ...@@ -22,8 +22,8 @@ template <typename T>
class CopyMatrixRowsFunctor<platform::CPUPlace, T> { class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index, const framework::Tensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) { framework::Tensor& dst, bool is_src_index) {
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst.dims(); auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL, PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
......
...@@ -41,8 +41,8 @@ template <typename T> ...@@ -41,8 +41,8 @@ template <typename T>
class CopyMatrixRowsFunctor<platform::GPUPlace, T> { class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
public: public:
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index, const framework::Tensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) { framework::Tensor& dst, bool is_src_index) {
auto src_dims = src.dims(); auto src_dims = src.dims();
auto dst_dims = dst.dims(); auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2, PADDLE_ENFORCE_EQ(src_dims.size(), 2,
......
...@@ -30,8 +30,8 @@ class CopyMatrixRowsFunctor { ...@@ -30,8 +30,8 @@ class CopyMatrixRowsFunctor {
// copy the input src to the indexed rows of output dst. // copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index. // The indexed rows are based on the input index.
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index, const framework::Tensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index); framework::Tensor& dst, bool is_src_index);
}; };
template <typename Place, typename T> template <typename Place, typename T>
...@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor { ...@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const { bool is_reverse = false) const {
if (!is_cal_batch_lod) { if (!is_cal_batch_lod) {
auto lods = batch.lod(); auto lods = batch.lod();
PADDLE_ENFORCE_EQ(lods.size(), 2UL); PADDLE_ENFORCE_GT(lods.size(), 2UL);
PADDLE_ENFORCE_EQ(lods[1].size(), PADDLE_ENFORCE_EQ(lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0])); static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_batch; CopyMatrixRowsFunctor<Place, T> to_batch;
...@@ -66,8 +66,8 @@ class LoDTensor2BatchFunctor { ...@@ -66,8 +66,8 @@ class LoDTensor2BatchFunctor {
} }
auto lods = lod_tensor.lod(); auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0]; auto lod = lods[0];
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
std::vector<SeqInfo> seq_info; std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) { for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
...@@ -78,8 +78,7 @@ class LoDTensor2BatchFunctor { ...@@ -78,8 +78,7 @@ class LoDTensor2BatchFunctor {
std::sort(seq_info.begin(), seq_info.end(), std::sort(seq_info.begin(), seq_info.end(),
[](SeqInfo a, SeqInfo b) { return a.length > b.length; }); [](SeqInfo a, SeqInfo b) { return a.length > b.length; });
// calculate the start position of each batch // Calculate the start position of each batch.
// (numBatch equal the maxLength of sequences)
// example: sequences = {s0, s1, s2} // example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2 // s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// num_batch = 5, // num_batch = 5,
...@@ -95,19 +94,25 @@ class LoDTensor2BatchFunctor { ...@@ -95,19 +94,25 @@ class LoDTensor2BatchFunctor {
// 6, 2, 11, // 6, 2, 11,
// 7, 3, // 7, 3,
// 8} // 8}
// The batch number represents batch size after rearranging the // seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
// The num_batch represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence. // input LodTensor. It is also the maximum length of input sequence.
paddle::framework::LoD batch_lods; paddle::framework::LoD batch_lods;
batch_lods.emplace_back(std::vector<size_t>{0}); batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0}); batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor // batch_lods[0] is the start positions for batch LoDTensor
int num_batch = seq_info[0].length; int num_batch = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(num_batch + 1)); batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// batch_lods[1] is the raw index in the input LoDTensor // batch_lods[1] is the raw index in the input LoDTensor
auto dims = lod_tensor.dims(); batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
batch_lods[1].resize(static_cast<size_t>(dims[0])); // batch_lods[2] is the sort order for the input LoDTensor.
batch_lods[2].resize(seq_info.size());
size_t* batch_starts = batch_lods[0].data(); size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data(); size_t* seq2batch_idx = batch_lods[1].data();
...@@ -127,6 +132,10 @@ class LoDTensor2BatchFunctor { ...@@ -127,6 +132,10 @@ class LoDTensor2BatchFunctor {
} }
batch_starts[n + 1] = static_cast<size_t>(batch_id); batch_starts[n + 1] = static_cast<size_t>(batch_id);
} }
size_t* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx;
}
batch.set_lod(batch_lods); batch.set_lod(batch_lods);
CopyMatrixRowsFunctor<Place, T> to_batch; CopyMatrixRowsFunctor<Place, T> to_batch;
...@@ -141,8 +150,7 @@ class Batch2LoDTensorFunctor { ...@@ -141,8 +150,7 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch, const framework::LoDTensor& batch,
framework::LoDTensor& lod_tensor) const { framework::LoDTensor& lod_tensor) const {
auto in_lod = batch.lod(); auto in_lod = batch.lod();
PADDLE_ENFORCE_EQ(in_lod.size(), 2UL, PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
"The LoD size of input `batch` should be 2.");
PADDLE_ENFORCE_EQ(in_lod[1].size(), PADDLE_ENFORCE_EQ(in_lod[1].size(),
static_cast<size_t>(lod_tensor.dims()[0])); static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_seq; CopyMatrixRowsFunctor<Place, T> to_seq;
......
...@@ -117,8 +117,9 @@ class TestLstmOp(OpTest): ...@@ -117,8 +117,9 @@ class TestLstmOp(OpTest):
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.has_initial_state = True self.has_initial_state = False
self.is_reverse = False self.is_reverse = False
self.use_peepholes = True
def setUp(self): def setUp(self):
self.set_argument() self.set_argument()
...@@ -128,18 +129,28 @@ class TestLstmOp(OpTest): ...@@ -128,18 +129,28 @@ class TestLstmOp(OpTest):
N = len(self.lod[0]) - 1 N = len(self.lod[0]) - 1
x = np.random.normal(size=(T, 4 * self.D)).astype('float64') x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
h0 = np.zeros((N, self.D)).astype('float64') if self.has_initial_state:
c0 = np.zeros((N, self.D)).astype('float64') h0 = np.random.normal(size=(N, self.D)).astype('float64')
c0 = np.random.normal(size=(N, self.D)).astype('float64')
else:
h0 = np.zeros((N, self.D)).astype('float64')
c0 = np.zeros((N, self.D)).astype('float64')
w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64') w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
b = np.random.normal(size=(1, 7 * self.D)).astype('float64') if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
w_b = b[:, 0:4 * self.D] w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:] w_c = b[:, 4 * self.D:] if self.use_peepholes else None
h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse, h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell], ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand]) ACTVATION[self.act_cand])
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b} self.inputs = {'Input': (x, self.lod), 'Weight': w}
self.inputs['Bias'] = b
if self.has_initial_state: if self.has_initial_state:
self.inputs['H0'] = h0 self.inputs['H0'] = h0
self.inputs['C0'] = c0 self.inputs['C0'] = c0
...@@ -149,17 +160,16 @@ class TestLstmOp(OpTest): ...@@ -149,17 +160,16 @@ class TestLstmOp(OpTest):
'Cell': (c, self.lod), 'Cell': (c, self.lod),
} }
self.attrs = { self.attrs = {
'usePeepholes': True, 'use_peepholes': self.use_peepholes,
'isReverse': self.is_reverse, 'is_reverse': self.is_reverse,
'gateActivation': self.act_gate, 'gate_activation': self.act_gate,
'cellActivation': self.act_cell, 'cell_activation': self.act_cell,
'candidateActivation': self.act_cand 'candidate_activation': self.act_cand
} }
def test_check_output(self): def test_check_output(self):
self.check_output(atol=1e-8) self.check_output(atol=1e-8)
#TODO(qingqing) add more unit testing case
def test_check_grad(self): def test_check_grad(self):
# TODO(qingqing) remove folowing lines after the check_grad is refined. # TODO(qingqing) remove folowing lines after the check_grad is refined.
N = len(self.lod[0]) - 1 N = len(self.lod[0]) - 1
...@@ -170,7 +180,7 @@ class TestLstmOp(OpTest): ...@@ -170,7 +180,7 @@ class TestLstmOp(OpTest):
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4) ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
class TestLstmOpHasNoInitial(TestLstmOp): class TestLstmOpHasInitial(TestLstmOp):
def set_argument(self): def set_argument(self):
self.lod = [[0, 2, 5, 7]] self.lod = [[0, 2, 5, 7]]
self.D = 16 self.D = 16
...@@ -179,8 +189,69 @@ class TestLstmOpHasNoInitial(TestLstmOp): ...@@ -179,8 +189,69 @@ class TestLstmOpHasNoInitial(TestLstmOp):
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.has_initial_state = False self.has_initial_state = True
self.is_reverse = True self.is_reverse = True
self.use_peepholes = True
def test_check_grad(self):
# TODO(qingqing) remove folowing lines after the check_grad is refined.
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
max_relative_error=5e-4)
def test_check_grad_ingore_bias(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Weight'))
def test_check_grad_ingore_input(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Weight', 'Bias'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('Input'))
def test_check_grad_ingore_h0(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('H0'))
def test_check_grad_ingore_c0(self):
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
max_relative_error=5e-4,
no_grad_set=set('C0'))
class TestLstmOpRerverse(TestLstmOp): class TestLstmOpRerverse(TestLstmOp):
...@@ -192,8 +263,23 @@ class TestLstmOpRerverse(TestLstmOp): ...@@ -192,8 +263,23 @@ class TestLstmOpRerverse(TestLstmOp):
self.act_cell = 'tanh' self.act_cell = 'tanh'
self.act_cand = 'tanh' self.act_cand = 'tanh'
self.has_initial_state = True self.has_initial_state = False
self.is_reverse = True
self.use_peepholes = True
class TestLstmOpNotUsePeepholes(TestLstmOp):
    """LSTM test case configured with use_peepholes turned off."""

    def set_argument(self):
        """Set the attributes that describe this test configuration."""
        # Flags for this configuration.
        self.use_peepholes = False
        self.is_reverse = True
        self.has_initial_state = False

        # Activation names.
        self.act_gate = 'sigmoid'
        self.act_cell = self.act_cand = 'tanh'

        # LoD offsets and the size attribute D.
        self.lod = [[0, 2, 5, 7]]
        self.D = 16
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -102,7 +102,7 @@ class Momentum(Optimizer): ...@@ -102,7 +102,7 @@ class Momentum(Optimizer):
.. math:: .. math::
v_{t} &= k * v_{t-1} - \\gamma_t / (g_{t} + \\lambda w_{t-1}) \\\\ v_{t} &= k * v_{t-1} - \\gamma_t (g_{t} + \\lambda w_{t-1}) \\\\
w_{t} &= w_{t-1} + v_{t} \\\\ w_{t} &= w_{t-1} + v_{t} \\\\
where, :math:`k` is momentum, :math:`\\lambda` is decay rate, where, :math:`k` is momentum, :math:`\\lambda` is decay rate,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册