未验证 提交 41d05339 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #5429 from qingqing01/lstm_fix

Enable hidden/cell state initialization and enhance unit testing in LSTM operator.
......@@ -24,6 +24,11 @@ class LSTMOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Hidden"),
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
......@@ -59,11 +64,13 @@ class LSTMOp : public framework::OperatorWithKernel {
"The second dimension of Input(Weight) "
"should be 4 * %d.",
frame_size);
auto b_dims = ctx->GetInputDim("Bias");
PADDLE_ENFORCE_EQ(b_dims.size(), 2, "The rank of Input(Bias) should be 2.");
PADDLE_ENFORCE_EQ(b_dims[0], 1,
"The first dimension of Input(Bias) should be 1.");
if (ctx->Attrs().Get<bool>("usePeepholes")) {
if (ctx->Attrs().Get<bool>("use_peepholes")) {
PADDLE_ENFORCE_EQ(b_dims[1], 7 * frame_size,
"The second dimension of Input(Bias) should be "
"7 * %d if enable peepholes connection",
......@@ -74,6 +81,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection",
frame_size);
}
framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
......@@ -118,14 +126,13 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Bias",
"(Tensor) the learnable weights, which contains two parts: "
"input-hidden bias weight and peephole connections weight if "
"setting `usePeepholes` True. "
"1. `usePeepholes = False` "
"setting `use_peepholes` True. "
"1. `use_peepholes = False` "
" - The shape is (1 x 4D). "
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
"2. `use_peepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
.AsDispensable();
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
AddOutput("Hidden",
"(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
......@@ -145,29 +152,32 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) This LoDTensor is obtained in the forward and used "
"in the backward.")
.AsIntermediate();
AddAttr<bool>("usePeepholes",
"(bool, default True) "
AddAttr<bool>("use_peepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
.SetDefault(true);
AddAttr<bool>("isReverse",
"(bool, default False) "
AddAttr<bool>("is_reverse",
"(bool, defalut: False) "
"whether to compute reversed LSTM.")
.SetDefault(false);
AddAttr<std::string>(
"gateActivation",
"(string, default sigmoid)"
"gate_activation",
"(string, default: sigmoid)"
"The activation for input gate, forget gate and output "
"gate, `sigmoid` by default.")
.SetDefault("sigmoid");
AddAttr<std::string>("cellActivation",
"(string, default tanh)"
.SetDefault("sigmoid")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("cell_activation",
"(string, default: tanh)"
"The activation for cell output, `tanh` by defalut.")
.SetDefault("tanh");
AddAttr<std::string>("candidateActivation",
"(string, default tanh)"
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddAttr<std::string>("candidate_activation",
"(string, default: tanh)"
"The activation for candidate hidden state, "
"`tanh` by default.")
.SetDefault("tanh");
.SetDefault("tanh")
.InEnum({"sigmoid", "tanh", "relu", "identity"});
AddComment(R"DOC(
Long-Short Term Memory (LSTM) Operator.
......@@ -203,7 +213,7 @@ are the cell input and cell output activation functions and `tanh` is usually
used for them. \f$\tilde{c_t}\f$ is also called candidate hidden state,
which is computed based on the current input and the previous hidden state.
Set usePeepholes False to disable peephole connection
Set `use_peepholes` False to disable peephole connection
(http://www.bioinf.jku.at/publications/older/2604.pdf). The formula
is omitted here.
......@@ -226,23 +236,27 @@ class LSTMGradOp : public framework::OperatorWithKernel {
"Input(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Weight"),
"Input(Weight) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Bias"),
"Input(Bias) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");
auto in_g_name = framework::GradVarName("Input");
if (ctx->HasOutput(in_g_name))
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));
auto w_g_name = framework::GradVarName("Weight");
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));
auto b_g_name = framework::GradVarName("Bias");
if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
auto SetOutGradDim = [&ctx](const std::string& name) {
auto g_name = framework::GradVarName(name);
if (ctx->HasOutput(g_name))
ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
};
SetOutGradDim("Input");
SetOutGradDim("Weight");
SetOutGradDim("Bias");
SetOutGradDim("H0");
SetOutGradDim("C0");
}
protected:
......
......@@ -28,6 +28,15 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
// Reorders the rows of an initial-state tensor.
//
// `dst` is allocated on the place of `ctx` with the same dims as `src`, then
// the rows are copied by CopyMatrixRowsFunctor using `index`. `indexed_src`
// is forwarded unchanged as that functor's `is_src_index` flag.
template <typename Place, typename T>
inline void ReorderInitState(const platform::DeviceContext& ctx,
                             const framework::Tensor& src, const size_t* index,
                             framework::Tensor* dst, bool indexed_src) {
  // Allocate the destination before the functor writes into it.
  dst->mutable_data<T>(src.dims(), ctx.GetPlace());
  math::CopyMatrixRowsFunctor<Place, T> copy_rows;
  copy_rows(ctx, src, index, *dst, indexed_src);
}
template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
......@@ -36,6 +45,9 @@ class LSTMKernel : public framework::OpKernel<T> {
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* cell_t0 = ctx.Input<Tensor>("C0");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
......@@ -43,12 +55,7 @@ class LSTMKernel : public framework::OpKernel<T> {
auto* cell_out = ctx.Output<LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
// Now the function ShareLoD in InferShape is not implemented.
// So copy LoD here.
ctx.ShareLoD("Input", "Hidden");
ctx.ShareLoD("Input", "Cell");
bool is_reverse = ctx.Attr<bool>("isReverse");
bool is_reverse = ctx.Attr<bool>("is_reverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
auto& device_ctx = ctx.device_context();
to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
......@@ -71,7 +78,7 @@ class LSTMKernel : public framework::OpKernel<T> {
}
math::LstmMetaValue<T> lstm_value;
if (bias) {
if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
......@@ -84,6 +91,16 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.checkOg = nullptr;
}
lstm_value.prevStateValue = nullptr;
Tensor ordered_c0;
const size_t* order = batch_gate->lod()[2].data();
if (cell_t0) {
// The batch computation for LSTM reorders the input sequences
// according to their lengths, so the initial cell state must be
// reordered in the same way.
ReorderInitState<Place, T>(device_ctx, *cell_t0, order, &ordered_c0,
true);
lstm_value.prevStateValue = ordered_c0.data<T>();
}
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
......@@ -94,9 +111,9 @@ class LSTMKernel : public framework::OpKernel<T> {
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
auto gate_act = ctx.Attr<std::string>("gate_activation");
auto cell_act = ctx.Attr<std::string>("cell_activation");
auto cand_act = ctx.Attr<std::string>("candidate_activation");
for (size_t n = 0; n < num_batch; n++) {
int bstart = static_cast<int>(batch_starts[n]);
......@@ -109,15 +126,28 @@ class LSTMKernel : public framework::OpKernel<T> {
int cur_batch_size = bend - bstart;
if (n != 0) {
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
} else if (hidden_t0) {
// If n == 0 and there is no initial hidden state, i.e. H0 is all
// zeros, the calculation W_h * H0 is skipped.
// If n == 0 and there is an initial hidden state, calculate W_h * H0.
// The batch computation for LSTM reorders the input sequences
// according to their lengths, so the initial hidden state must be
// reordered in the same way.
Tensor ordered_h0;
ReorderInitState<Place, T>(device_ctx, *hidden_t0, order, &ordered_h0,
true);
math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
// else if : FIXME support the initial hidden and cell
lstm_value.gateValue = gate_t.data<T>();
lstm_value.outputValue = out_t.data<T>();
......@@ -160,6 +190,12 @@ class LSTMGradKernel : public framework::OpKernel<T> {
auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* h0 = ctx.Input<Tensor>("H0");
auto* c0 = ctx.Input<Tensor>("C0");
auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
auto* c0_g = ctx.Output<Tensor>(framework::GradVarName("C0"));
auto& device_ctx = ctx.device_context();
math::SetConstant<Place, T> zero;
if (weight_g) {
......@@ -167,13 +203,25 @@ class LSTMGradKernel : public framework::OpKernel<T> {
zero(device_ctx, weight_g, static_cast<T>(0.0));
}
// ordered_h0/c0 is the reordered hidden/cell initialization.
// ordered_h0_g/c0_g is the reordered gradient of hidden/cell
// initialization.
Tensor ordered_h0, ordered_c0, ordered_h0_g, ordered_c0_g;
const size_t* order = batch_gate->lod()[2].data();
if (c0) {
ReorderInitState<Place, T>(device_ctx, *c0, order, &ordered_c0, true);
}
if (c0 && c0_g) {
ordered_c0_g.mutable_data<T>(c0_g->dims(), ctx.GetPlace());
}
auto in_dims = input->dims();
auto out_dims = hidden_g->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
math::LstmMetaValue<T> lstm_value;
if (bias) {
if (bias && ctx.Attr<bool>("use_peepholes")) {
T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
......@@ -185,9 +233,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
}
math::LstmMetaGrad<T> lstm_grad;
if (bias && bias_g) {
T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
bias_g->mutable_data<T>(ctx.GetPlace());
zero(device_ctx, bias_g, static_cast<T>(0.0));
}
if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
T* bias_g_data = bias_g->data<T>();
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
......@@ -199,36 +251,30 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here.
LoDTensor batch_hidden;
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_out, batch_hidden, false);
auto ToBatch = [&batch_gate, &to_batch](
const platform::DeviceContext& ctx, const framework::LoDTensor& src,
const framework::DDim& dims, framework::LoDTensor& dst) {
dst.mutable_data<T>(dims, ctx.GetPlace());
dst.set_lod(batch_gate->lod());
to_batch(ctx, src, dst, false);
};
LoDTensor batch_hidden_g;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_g, batch_hidden_g, false);
LoDTensor batch_hidden, batch_hidden_g, batch_cell;
ToBatch(device_ctx, *hidden_out, out_dims, batch_hidden);
ToBatch(device_ctx, *hidden_g, out_dims, batch_hidden_g);
ToBatch(device_ctx, *cell_out, out_dims, batch_cell);
LoDTensor batch_cell;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_out, batch_cell, false);
LoDTensor batch_cell_g;
LoDTensor batch_cell_g, batch_gate_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
// TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false);
zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
auto gate_act = ctx.Attr<std::string>("gate_activation");
auto cell_act = ctx.Attr<std::string>("cell_activation");
auto cand_act = ctx.Attr<std::string>("candidate_activation");
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
......@@ -250,15 +296,15 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();
if (n) {
if (n > 0) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else {
lstm_value.prevStateValue = nullptr;
lstm_grad.prevStateGrad = nullptr;
lstm_value.prevStateValue = c0 ? ordered_c0.data<T>() : nullptr;
lstm_grad.prevStateGrad = c0_g ? ordered_c0_g.data<T>() : nullptr;
}
int cur_batch_size = bend - bstart;
......@@ -266,7 +312,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
if (n != 0) {
if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
......@@ -280,6 +326,19 @@ class LSTMGradKernel : public framework::OpKernel<T> {
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
} else {
if (h0 && weight_g) {
ReorderInitState<Place, T>(device_ctx, *h0, order, &ordered_h0, true);
math::matmul<Place, T>(device_ctx, ordered_h0, true, gate_g, false,
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
if (h0 && h0_g) {
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &ordered_h0_g,
static_cast<T>(0.0));
}
}
}
......@@ -302,6 +361,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(),
ones.data<T>(), 0., bias_g->data<T>());
}
if (h0 && h0_g) {
ReorderInitState<Place, T>(device_ctx, ordered_h0_g, order, h0_g, false);
}
if (c0 && c0_g) {
ReorderInitState<Place, T>(device_ctx, ordered_c0_g, order, c0_g, false);
}
}
};
......
......@@ -52,9 +52,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
rCheckI = value.checkIg ? value.checkIg[i] : 0;
rCheckF = value.checkFg ? value.checkFg[i] : 0;
rCheckO = value.checkOg ? value.checkOg[i] : 0;
if (value.prevStateValue) {
rPrevState = value.prevStateValue[i];
......@@ -114,9 +114,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = value.checkIg[i];
rCheckF = value.checkFg[i];
rCheckO = value.checkOg[i];
rCheckI = value.checkIg ? value.checkIg[i] : 0;
rCheckF = value.checkFg ? value.checkFg[i] : 0;
rCheckO = value.checkOg ? value.checkOg[i] : 0;
rState = value.stateValue[i];
rStateAtv = value.stateActiveValue[i];
rOutputGrad = grad.outputGrad[i];
......@@ -155,9 +155,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
__m256 rValueIg;
__m256 rValueFg;
__m256 rValueOg;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rCheckI = _mm256_set1_ps(0.0f);
__m256 rCheckF = _mm256_set1_ps(0.0f);
__m256 rCheckO = _mm256_set1_ps(0.0f);
__m256 rState;
__m256 rPrevState = _mm256_set1_ps(0.0f);
__m256 rStateAtv;
......@@ -173,9 +173,11 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
if (value.checkIg) {
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
}
if (value.prevStateValue) {
rPrevState = ((__m256 *)value.prevStateValue)[i];
......@@ -216,9 +218,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
__m256 rState;
__m256 rStateAtv;
__m256 rOutputGrad;
__m256 rCheckI;
__m256 rCheckF;
__m256 rCheckO;
__m256 rCheckI = _mm256_set1_ps(0.0f);
__m256 rCheckF = _mm256_set1_ps(0.0f);
__m256 rCheckO = _mm256_set1_ps(0.0f);
__m256 rCheckIGrad;
__m256 rCheckFGrad;
__m256 rCheckOGrad;
......@@ -237,9 +239,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rValueIg = valueIg[i];
rValueFg = valueFg[i];
rValueOg = valueOg[i];
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
if (value.checkIg) {
rCheckI = ((__m256 *)value.checkIg)[i];
rCheckF = ((__m256 *)value.checkFg)[i];
rCheckO = ((__m256 *)value.checkOg)[i];
}
rState = ((__m256 *)value.stateValue)[i];
rStateAtv = ((__m256 *)value.stateActiveValue)[i];
rOutputGrad = ((__m256 *)grad.outputGrad)[i];
......
......@@ -55,9 +55,10 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
T rValueIg;
T rValueFg;
T rValueOg;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
rValueIn = value.gateValue[frameIdx];
rValueIg = value.gateValue[frameIdx + frameSize];
......@@ -121,9 +122,10 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
T rStateGrad;
T rStateAtv;
T rOutputGrad;
T rCheckI = value.checkIg[frameIdx];
T rCheckF = value.checkFg[frameIdx];
T rCheckO = value.checkOg[frameIdx];
T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
T rCheckIGrad;
T rCheckFGrad;
T rCheckOGrad;
......
......@@ -22,8 +22,8 @@ template <typename T>
class CopyMatrixRowsFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) {
const framework::Tensor& src, const size_t* index,
framework::Tensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2UL,
......
......@@ -41,8 +41,8 @@ template <typename T>
class CopyMatrixRowsFunctor<platform::GPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index) {
const framework::Tensor& src, const size_t* index,
framework::Tensor& dst, bool is_src_index) {
auto src_dims = src.dims();
auto dst_dims = dst.dims();
PADDLE_ENFORCE_EQ(src_dims.size(), 2,
......
......@@ -30,8 +30,8 @@ class CopyMatrixRowsFunctor {
// copy the input src to the indexed rows of output dst.
// The indexed rows are based on the input index.
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& src, const size_t* index,
framework::LoDTensor& dst, bool is_src_index);
const framework::Tensor& src, const size_t* index,
framework::Tensor& dst, bool is_src_index);
};
template <typename Place, typename T>
......@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor {
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch.lod();
PADDLE_ENFORCE_EQ(lods.size(), 2UL);
PADDLE_ENFORCE_GT(lods.size(), 2UL);
PADDLE_ENFORCE_EQ(lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_batch;
......@@ -66,8 +66,8 @@ class LoDTensor2BatchFunctor {
}
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
std::vector<SeqInfo> seq_info;
for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
......@@ -78,8 +78,7 @@ class LoDTensor2BatchFunctor {
std::sort(seq_info.begin(), seq_info.end(),
[](SeqInfo a, SeqInfo b) { return a.length > b.length; });
// calculate the start position of each batch
// (numBatch equal the maxLength of sequences)
// Calculate the start position of each batch.
// example: sequences = {s0, s1, s2}
// s0: 0 0 0 0, s1: 1 1 1 1 1, s2: 2 2 2
// num_batch = 5,
......@@ -95,19 +94,25 @@ class LoDTensor2BatchFunctor {
// 6, 2, 11,
// 7, 3,
// 8}
// The batch number represents batch size after rearranging the
// seq_order = {1, 0, 2}, the sort order.
// where 1 is the second sequence,
// 0 is the first sequence,
// 2 is the third sequence.
// The num_batch represents batch size after rearranging the
// input LodTensor. It is also the maximum length of input sequence.
paddle::framework::LoD batch_lods;
batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
batch_lods.emplace_back(std::vector<size_t>{0});
// batch_lods[0] is the start positions for batch LoDTensor
int num_batch = seq_info[0].length;
batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
// batch_lods[1] is the raw index in the input LoDTensor
auto dims = lod_tensor.dims();
batch_lods[1].resize(static_cast<size_t>(dims[0]));
batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
// batch_lods[2] is the sort order for the input LoDTensor.
batch_lods[2].resize(seq_info.size());
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
......@@ -127,6 +132,10 @@ class LoDTensor2BatchFunctor {
}
batch_starts[n + 1] = static_cast<size_t>(batch_id);
}
size_t* seq_order = batch_lods[2].data();
for (size_t i = 0; i < seq_info.size(); ++i) {
seq_order[i] = seq_info[i].seq_idx;
}
batch.set_lod(batch_lods);
CopyMatrixRowsFunctor<Place, T> to_batch;
......@@ -141,8 +150,7 @@ class Batch2LoDTensorFunctor {
const framework::LoDTensor& batch,
framework::LoDTensor& lod_tensor) const {
auto in_lod = batch.lod();
PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
"The LoD size of input `batch` should be 2.");
PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
PADDLE_ENFORCE_EQ(in_lod[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_seq;
......
......@@ -117,8 +117,9 @@ class TestLstmOp(OpTest):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.has_initial_state = True
self.has_initial_state = False
self.is_reverse = False
self.use_peepholes = True
def setUp(self):
self.set_argument()
......@@ -128,18 +129,28 @@ class TestLstmOp(OpTest):
N = len(self.lod[0]) - 1
x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
h0 = np.zeros((N, self.D)).astype('float64')
c0 = np.zeros((N, self.D)).astype('float64')
if self.has_initial_state:
h0 = np.random.normal(size=(N, self.D)).astype('float64')
c0 = np.random.normal(size=(N, self.D)).astype('float64')
else:
h0 = np.zeros((N, self.D)).astype('float64')
c0 = np.zeros((N, self.D)).astype('float64')
w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
if self.use_peepholes:
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
else:
b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:]
w_c = b[:, 4 * self.D:] if self.use_peepholes else None
h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand])
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
self.inputs = {'Input': (x, self.lod), 'Weight': w}
self.inputs['Bias'] = b
if self.has_initial_state:
self.inputs['H0'] = h0
self.inputs['C0'] = c0
......@@ -149,17 +160,16 @@ class TestLstmOp(OpTest):
'Cell': (c, self.lod),
}
self.attrs = {
'usePeepholes': True,
'isReverse': self.is_reverse,
'gateActivation': self.act_gate,
'cellActivation': self.act_cell,
'candidateActivation': self.act_cand
'use_peepholes': self.use_peepholes,
'is_reverse': self.is_reverse,
'gate_activation': self.act_gate,
'cell_activation': self.act_cell,
'candidate_activation': self.act_cand
}
def test_check_output(self):
self.check_output(atol=1e-8)
#TODO(qingqing) add more unit testing case
def test_check_grad(self):
# TODO(qingqing) remove following lines after the check_grad is refined.
N = len(self.lod[0]) - 1
......@@ -170,7 +180,7 @@ class TestLstmOp(OpTest):
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
class TestLstmOpHasNoInitial(TestLstmOp):
class TestLstmOpHasInitial(TestLstmOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
......@@ -179,8 +189,69 @@ class TestLstmOpHasNoInitial(TestLstmOp):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.has_initial_state = False
self.has_initial_state = True
self.is_reverse = True
self.use_peepholes = True
def test_check_grad(self):
    """Check gradients of Hidden w.r.t. Input, Weight, Bias, H0 and C0."""
    # TODO(qingqing) remove the placeholder outputs below after the
    # check_grad is refined.
    num_seqs = len(self.lod[0]) - 1
    gate_shape = (num_seqs, 4 * self.D)
    self.outputs['BatchGate'] = np.zeros(gate_shape).astype('float64')
    cell_shape = (num_seqs, self.D)
    self.outputs['BatchCellPreAct'] = np.zeros(cell_shape).astype('float64')

    inputs_to_check = ['Input', 'Weight', 'Bias', 'H0', 'C0']
    self.check_grad(inputs_to_check, ['Hidden'], max_relative_error=5e-4)
def test_check_grad_ingore_bias(self):
    """Check gradients of Input and Weight with Bias in the no-grad set."""
    N = len(self.lod[0]) - 1
    # Zero-filled placeholders for the intermediate outputs.
    self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
    self.outputs['BatchCellPreAct'] = np.zeros(
        (N, self.D)).astype('float64')
    # set('Bias') would iterate the string and yield {'B', 'i', 'a', 's'};
    # wrap the name in a list so the set holds the single entry 'Bias'.
    self.check_grad(
        ['Input', 'Weight'], ['Hidden'],
        max_relative_error=5e-4,
        no_grad_set=set(['Bias']))
def test_check_grad_ingore_weight(self):
    """Check gradients of Input and Bias with Weight in the no-grad set."""
    N = len(self.lod[0]) - 1
    # Zero-filled placeholders for the intermediate outputs.
    self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
    self.outputs['BatchCellPreAct'] = np.zeros(
        (N, self.D)).astype('float64')
    # set('Weight') would iterate the string and yield single characters;
    # wrap the name in a list so the set holds the single entry 'Weight'.
    self.check_grad(
        ['Input', 'Bias'], ['Hidden'],
        max_relative_error=5e-4,
        no_grad_set=set(['Weight']))
def test_check_grad_ingore_input(self):
    """Check gradients of Weight and Bias with Input in the no-grad set."""
    N = len(self.lod[0]) - 1
    # Zero-filled placeholders for the intermediate outputs.
    self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
    self.outputs['BatchCellPreAct'] = np.zeros(
        (N, self.D)).astype('float64')
    # set('Input') would iterate the string and yield single characters;
    # wrap the name in a list so the set holds the single entry 'Input'.
    self.check_grad(
        ['Weight', 'Bias'], ['Hidden'],
        max_relative_error=5e-4,
        no_grad_set=set(['Input']))
def test_check_grad_ingore_h0(self):
    """Check gradients of Input, Weight, Bias and C0 with H0 in the no-grad set."""
    N = len(self.lod[0]) - 1
    # Zero-filled placeholders for the intermediate outputs.
    self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
    self.outputs['BatchCellPreAct'] = np.zeros(
        (N, self.D)).astype('float64')
    # set('H0') would iterate the string and yield {'H', '0'};
    # wrap the name in a list so the set holds the single entry 'H0'.
    self.check_grad(
        ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
        max_relative_error=5e-4,
        no_grad_set=set(['H0']))
def test_check_grad_ingore_c0(self):
    """Check gradients of Input, Weight, Bias and H0 with C0 in the no-grad set."""
    N = len(self.lod[0]) - 1
    # Zero-filled placeholders for the intermediate outputs.
    self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
    self.outputs['BatchCellPreAct'] = np.zeros(
        (N, self.D)).astype('float64')
    # set('C0') would iterate the string and yield {'C', '0'};
    # wrap the name in a list so the set holds the single entry 'C0'.
    self.check_grad(
        ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
        max_relative_error=5e-4,
        no_grad_set=set(['C0']))
class TestLstmOpRerverse(TestLstmOp):
......@@ -192,8 +263,23 @@ class TestLstmOpRerverse(TestLstmOp):
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.has_initial_state = True
self.has_initial_state = False
self.is_reverse = True
self.use_peepholes = True
class TestLstmOpNotUsePeepholes(TestLstmOp):
    """LSTM test configuration with `use_peepholes` set to False."""

    def set_argument(self):
        # Sequence offsets and frame size.
        self.lod = [[0, 2, 5, 7]]
        self.D = 16

        # Activation names for the gates, the cell output and the candidate.
        self.act_gate, self.act_cell, self.act_cand = 'sigmoid', 'tanh', 'tanh'

        # Reversed LSTM, no H0/C0 inputs, peephole connections disabled.
        self.has_initial_state = False
        self.is_reverse = True
        self.use_peepholes = False
if __name__ == '__main__':
......
......@@ -102,7 +102,7 @@ class Momentum(Optimizer):
.. math::
v_{t} &= k * v_{t-1} - \\gamma_t / (g_{t} + \\lambda w_{t-1}) \\\\
v_{t} &= k * v_{t-1} - \\gamma_t (g_{t} + \\lambda w_{t-1}) \\\\
w_{t} &= w_{t-1} + v_{t} \\\\
where, :math:`k` is momentum, :math:`\\lambda` is decay rate,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册