未验证 提交 36d20609 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #5115 from qingqing01/lstm_bp

Add backward implementation for LSTM operator. 
......@@ -36,8 +36,8 @@ TEST(LoDTensor, LoDInGPU) {
lod_tensor.mutable_data<float>(place);
lod_tensor.set_lod(src_lod);
CHECK_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
CHECK_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
EXPECT_EQ(lod_tensor.lod_element(0, 2).first, 4UL);
EXPECT_EQ(lod_tensor.lod_element(0, 4).first, 8UL);
auto lod = lod_tensor.lod();
......@@ -45,6 +45,6 @@ TEST(LoDTensor, LoDInGPU) {
cudaDeviceSynchronize();
for (size_t i = 0; i < src_lod[0].size(); ++i) {
CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
EXPECT_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
}
}
\ No newline at end of file
}
......@@ -21,7 +21,6 @@ class LSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
......@@ -29,9 +28,13 @@ class LSTMOp : public framework::OperatorWithKernel {
"Output(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Cell"),
"Output(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchGate"),
"Output(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("BatchCellPreAct"),
"Output(BatchGate) of LSTM should not be null.");
auto x_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank must be 2.");
auto in_dims = ctx->GetInputDim("Input");
PADDLE_ENFORCE_EQ(in_dims.size(), 2, "Input(X)'s rank must be 2.");
if (ctx->HasInput("H0")) {
PADDLE_ENFORCE(ctx->HasInput("C0"),
......@@ -44,7 +47,7 @@ class LSTMOp : public framework::OperatorWithKernel {
"should be the same.");
}
int frame_size = x_dims[1] / 4;
int frame_size = in_dims[1] / 4;
auto w_dims = ctx->GetInputDim("Weight");
PADDLE_ENFORCE_EQ(w_dims.size(), 2,
"The rank of Input(Weight) should be 2.");
......@@ -71,12 +74,21 @@ class LSTMOp : public framework::OperatorWithKernel {
"4 * %d if disable peepholes connection",
frame_size);
}
ctx->SetOutputDim("Hidden", {x_dims[0], frame_size});
ctx->SetOutputDim("Cell", {x_dims[0], frame_size});
ctx->SetOutputDim("BatchGate", x_dims);
framework::DDim out_dims({in_dims[0], frame_size});
ctx->SetOutputDim("Hidden", out_dims);
ctx->SetOutputDim("Cell", out_dims);
ctx->SetOutputDim("BatchGate", in_dims);
ctx->SetOutputDim("BatchCellPreAct", out_dims);
ctx->ShareLoD("Input", "Hidden");
ctx->ShareLoD("Input", "Cell");
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
}
};
class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -86,16 +98,18 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Input",
"(LoDTensor) the first input is a LodTensor, which support "
"variable-time length input sequence. The underlying tensor in "
"this LoDTensor is a matrix with shape (T X 4D), where, T is the "
"this LoDTensor is a matrix with shape (T X 4D), where T is the "
"total time steps in this mini-batch, D is the hidden size.");
AddInput("H0",
"(Tensor, optional) the initial hidden state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size, D is the hidden size.");
"batch size, D is the hidden size.")
.AsDispensable();
AddInput("C0",
"(Tensor, optional) the initial cell state is an optional "
"input. This is a tensor with shape (N x D), where N is the "
"batch size. `H0` and `C0` can be NULL but only at the same time");
"batch size. `H0` and `C0` can be NULL but only at the same time")
.AsDispensable();
AddInput("Weight",
"(Tensor) the learnable hidden-hidden weights."
" - The shape is (D x 4D), where D is the hidden size. "
......@@ -109,22 +123,27 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
" - Bias = {b_c, b_i, b_f, b_o}."
"2. `usePeepholes = True` "
" - The shape is (1 x 7D). "
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.");
" - Bias = {b_c, b_i, b_f, b_o, W_ic, W_fc, W_oc}.")
.AsDispensable();
AddOutput("Hidden",
"(LoDTensor) the hidden state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state of LSTM operator. "
"The shape is (T x D), and lod is the same with the `Input`.");
AddOutput("BatchGate",
"(LoDTensor) This LoDTensor contains input gate, forget gate "
"and output gate after the nonlinear computation. This "
"LoDTensor has the same shape with the reorganized input, which "
"was also be called batch input. The LoD size is 2. The first "
"is also be called batch input. The LoD size is 2. The first "
"LoD is the batch offsets and the second LoD contains the "
"indexes, which denote the position of reorganized sequence "
"in the raw input.")
.AsIntermediate();
AddOutput("Hidden",
"(LoDTensor) the hidden state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("Cell",
"(LoDTensor) the cell state lod tensor of LSTM operator. "
"The shape and lod is the same with the `Input`.");
AddOutput("BatchCellPreAct",
"(LoDTensor) This LoDTensor is got in the forward and used "
"in the backward.")
.AsIntermediate();
AddAttr<bool>("usePeepholes",
"(bool, defalut: True) "
"whether to enable diagonal/peephole connections.")
......@@ -202,15 +221,37 @@ class LSTMGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(Hidden@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Cell")),
"Input(Cell@GRAD) should not be null");
ctx->SetOutputDim(framework::GradVarName("Weight"),
ctx->GetInputDim("Weight"));
ctx->SetOutputDim(framework::GradVarName("Bias"), ctx->GetInputDim("Bias"));
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(Hidden) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Cell"),
"Input(Cell) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchGate"),
"Input(BatchGate) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("BatchCellPreAct"),
"Input(BatchGate) of LSTM should not be null.");
auto in_g_name = framework::GradVarName("Input");
if (ctx->HasOutput(in_g_name))
ctx->SetOutputDim(in_g_name, ctx->GetInputDim("Input"));
auto w_g_name = framework::GradVarName("Weight");
if (ctx->HasOutput(w_g_name))
ctx->SetOutputDim(w_g_name, ctx->GetInputDim("Weight"));
auto b_g_name = framework::GradVarName("Bias");
if (ctx->HasOutput(b_g_name))
ctx->SetOutputDim(b_g_name, ctx->GetInputDim("Bias"));
}
protected:
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return framework::ToDataType(
ctx.Input<framework::LoDTensor>("Input")->type());
}
};
......
......@@ -21,8 +21,9 @@ limitations under the License. */
namespace paddle {
namespace operators {
using framework::LoDTensor;
using framework::Tensor;
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
......@@ -31,15 +32,15 @@ template <typename Place, typename T>
class LSTMKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<framework::LoDTensor>("Input");
auto* weight = ctx.Input<framework::Tensor>("Weight");
auto* bias = ctx.Input<framework::Tensor>("Bias");
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* batch_gate = ctx.Output<framework::LoDTensor>("BatchGate");
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace());
auto* hidden_out = ctx.Output<framework::LoDTensor>("Hidden");
auto* hidden_out = ctx.Output<LoDTensor>("Hidden");
hidden_out->mutable_data<T>(ctx.GetPlace());
auto* cell_out = ctx.Output<framework::LoDTensor>("Cell");
auto* cell_out = ctx.Output<LoDTensor>("Cell");
cell_out->mutable_data<T>(ctx.GetPlace());
// Now the function ShareLoD in InferShape is not implemented.
......@@ -49,7 +50,8 @@ class LSTMKernel : public framework::OpKernel<T> {
bool is_reverse = ctx.Attr<bool>("isReverse");
math::LoDTensor2BatchFunctor<Place, T> to_batch;
to_batch(ctx.device_context(), *input, *batch_gate, is_reverse);
auto& device_ctx = ctx.device_context();
to_batch(device_ctx, *input, *batch_gate, true, is_reverse);
auto in_dims = input->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
......@@ -69,17 +71,26 @@ class LSTMKernel : public framework::OpKernel<T> {
}
math::LstmMetaValue<T> lstm_value;
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
if (bias) {
T* bias_data = const_cast<T*>(bias->data<T>());
// the code style in LstmMetaValue will be updated later.
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
}
lstm_value.prevStateValue = nullptr;
framework::LoDTensor batch_out, batch_cell, batch_cell_pre_act;
batch_out.mutable_data<T>(dims, ctx.GetPlace());
// Use the local variable as here.
LoDTensor batch_hidden, batch_cell;
auto* batch_cell_pre_act = ctx.Output<LoDTensor>("BatchCellPreAct");
batch_hidden.mutable_data<T>(dims, ctx.GetPlace());
batch_cell.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act.mutable_data<T>(dims, ctx.GetPlace());
batch_cell_pre_act->mutable_data<T>(dims, ctx.GetPlace());
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
......@@ -92,18 +103,18 @@ class LSTMKernel : public framework::OpKernel<T> {
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate_t = batch_gate->Slice(bstart, bend);
Tensor out_t = batch_out.Slice(bstart, bend);
Tensor out_t = batch_hidden.Slice(bstart, bend);
Tensor cell_t = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act.Slice(bstart, bend);
Tensor cell_pre_act_t = batch_cell_pre_act->Slice(bstart, bend);
int cur_batch_size = bend - bstart;
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_t = batch_out.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(ctx.device_context(), pre_hidden_t, false,
*weight, false, static_cast<T>(1.0), &gate_t,
auto pre_hidden_t = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden_t, false, *weight, false,
static_cast<T>(1.0), &gate_t,
static_cast<T>(1.0));
}
// else if : FIXME support the initial hidden and cell
......@@ -112,27 +123,186 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.outputValue = out_t.data<T>();
lstm_value.stateValue = cell_t.data<T>();
lstm_value.stateActiveValue = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<Place, T>::compute(ctx.device_context(), lstm_value,
math::LstmUnitFunctor<Place, T>::compute(device_ctx, lstm_value,
frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
lstm_value.prevStateValue = lstm_value.stateValue;
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
batch_out.set_lod(batch_gate->lod());
batch_hidden.set_lod(batch_gate->lod());
// restore the output hidden in LoDTensor from the batch hidden
to_seq(ctx.device_context(), batch_out, *hidden_out);
to_seq(device_ctx, batch_hidden, *hidden_out);
batch_cell.set_lod(batch_gate->lod());
// restore the output cell state in LoDTensor from the batch cell
to_seq(ctx.device_context(), batch_cell, *cell_out);
to_seq(device_ctx, batch_cell, *cell_out);
}
};
template <typename Place, typename T>
class LSTMGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {}
void Compute(const framework::ExecutionContext& ctx) const override {
auto* input = ctx.Input<LoDTensor>("Input");
auto* weight = ctx.Input<Tensor>("Weight");
auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_out = ctx.Input<LoDTensor>("Hidden");
auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* hidden_g = ctx.Input<LoDTensor>(framework::GradVarName("Hidden"));
auto* in_g = ctx.Output<LoDTensor>(framework::GradVarName("Input"));
auto* weight_g = ctx.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto& device_ctx = ctx.device_context();
math::SetConstant<Place, T> zero;
if (weight_g) {
weight_g->mutable_data<T>(ctx.GetPlace());
zero(device_ctx, weight_g, static_cast<T>(0.0));
}
auto in_dims = input->dims();
auto out_dims = hidden_g->dims();
int frame_size = static_cast<int>(in_dims[1] / 4);
PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
math::LstmMetaValue<T> lstm_value;
if (bias) {
T* bias_data = const_cast<T*>(bias->data<T>());
lstm_value.checkIg = bias_data + 4 * frame_size;
lstm_value.checkFg = lstm_value.checkIg + frame_size;
lstm_value.checkOg = lstm_value.checkFg + frame_size;
} else {
lstm_value.checkIg = nullptr;
lstm_value.checkFg = nullptr;
lstm_value.checkOg = nullptr;
}
math::LstmMetaGrad<T> lstm_grad;
if (bias && bias_g) {
T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
zero(device_ctx, bias_g, static_cast<T>(0.0));
lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
} else {
lstm_grad.checkIgGrad = nullptr;
lstm_grad.checkFgGrad = nullptr;
lstm_grad.checkOgGrad = nullptr;
}
math::LoDTensor2BatchFunctor<Place, T> to_batch;
// use the local variable as here.
LoDTensor batch_hidden;
batch_hidden.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_out, batch_hidden, false);
LoDTensor batch_hidden_g;
batch_hidden_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_hidden_g.set_lod(batch_gate->lod());
to_batch(device_ctx, *hidden_g, batch_hidden_g, false);
LoDTensor batch_cell;
batch_cell.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell.set_lod(batch_gate->lod());
to_batch(device_ctx, *cell_out, batch_cell, false);
LoDTensor batch_cell_g;
batch_cell_g.mutable_data<T>(out_dims, ctx.GetPlace());
batch_cell_g.set_lod(batch_gate->lod());
// TODO(qingqing) support the case output cell has gradient.
// to_batch(device_ctx, *cell_g, batch_cell_g, false);
zero(device_ctx, &batch_cell_g, static_cast<T>(0.0));
LoDTensor batch_gate_g;
batch_gate_g.mutable_data<T>(batch_gate->dims(), ctx.GetPlace());
batch_gate_g.set_lod(batch_gate->lod());
auto gate_act = ctx.Attr<std::string>("gateActivation");
auto cell_act = ctx.Attr<std::string>("cellActivation");
auto cand_act = ctx.Attr<std::string>("candidateActivation");
auto batch_starts = batch_gate->lod()[0];
size_t num_batch = batch_starts.size() - 1;
for (int n = static_cast<int>(num_batch) - 1; n >= 0; n--) {
int bstart = static_cast<int>(batch_starts[n]);
int bend = static_cast<int>(batch_starts[n + 1]);
Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend);
Tensor cell_pre_act = batch_cell_pre_act->Slice(bstart, bend);
lstm_value.gateValue = gate.data<T>();
lstm_value.stateValue = cell.data<T>();
lstm_value.stateActiveValue = cell_pre_act.data<T>();
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
Tensor gate_g = batch_gate_g.Slice(bstart, bend);
Tensor cell_g = batch_cell_g.Slice(bstart, bend);
lstm_grad.stateGrad = cell_g.data<T>();
lstm_grad.gateGrad = gate_g.data<T>();
lstm_grad.outputGrad = out_g.data<T>();
if (n) {
int bstart_pre = static_cast<int>(batch_starts[n - 1]);
Tensor cell_pre = batch_cell.Slice(bstart_pre, bstart);
Tensor cell_pre_g = batch_cell_g.Slice(bstart_pre, bstart);
lstm_value.prevStateValue = cell_pre.data<T>();
lstm_grad.prevStateGrad = cell_pre_g.data<T>();
} else {
lstm_value.prevStateValue = nullptr;
lstm_grad.prevStateGrad = nullptr;
}
int cur_batch_size = bend - bstart;
math::LstmUnitGradFunctor<Place, T>::compute(
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act);
if (n != 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]);
int pre_h_end = pre_h_start + cur_batch_size;
auto pre_hidden_g = batch_hidden_g.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, gate_g, false, *weight, true,
static_cast<T>(1.0), &pre_hidden_g,
static_cast<T>(1.0));
if (weight_g) {
/* backward weight */
auto pre_hidden = batch_hidden.Slice(pre_h_start, pre_h_end);
math::matmul<Place, T>(device_ctx, pre_hidden, true, gate_g, false,
static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
}
}
math::Batch2LoDTensorFunctor<Place, T> to_seq;
if (in_g) {
/* backward data */
in_g->mutable_data<T>(ctx.GetPlace());
to_seq(device_ctx, batch_gate_g, *in_g);
}
if (bias && bias_g) {
/* backward bias */
int m = static_cast<int>(batch_gate_g.dims()[0]);
int n = static_cast<int>(batch_gate_g.dims()[1]);
Tensor ones;
ones.mutable_data<T>({m}, ctx.GetPlace());
math::SetConstant<Place, T> set;
set(device_ctx, &ones, static_cast<T>(1.0));
math::gemv<Place, T>(device_ctx, true, m, n, 1., batch_gate_g.data<T>(),
ones.data<T>(), 0., bias_g->data<T>());
}
}
};
} // namespace operators
......
......@@ -26,10 +26,7 @@ namespace detail {
template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int frameSize) {
T rValueIn;
T rValueIg;
T rValueFg;
......@@ -60,10 +57,8 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[i];
}
hppl::cpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
rOut, rCheckI, rCheckF, rCheckO);
valueIn[i] = rValueIn;
valueIg[i] = rValueIg;
......@@ -77,10 +72,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
LstmMetaGrad<T> grad, int frameSize) {
T rValueIn;
T rValueIg;
T rValueFg;
......@@ -127,11 +119,10 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[i];
}
hppl::cpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg,
rGradOg, rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv,
rOutputGrad, rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad,
rCheckOGrad, act(active_node), act(active_gate), act(active_state));
rCheckOGrad);
gradIn[i] = rGradIn;
gradIg[i] = rGradIg;
......@@ -283,8 +274,7 @@ void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frameSize,
avx_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_forward_one_sequence<T>(op, value, frameSize, active_node,
active_gate, active_state);
naive_lstm_forward_one_sequence<T>(op, value, frameSize);
}
}
......@@ -297,8 +287,7 @@ void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
avx_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
} else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize, active_node,
active_gate, active_state);
naive_lstm_backward_one_sequence<T>(op, value, grad, frameSize);
}
}
......
......@@ -32,9 +32,7 @@ namespace detail {
*/
template <class T, class Op, bool isBatch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int batchSize) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
......@@ -70,10 +68,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rPrevState, rState, rStateAtv,
rOut, rCheckI, rCheckF, rCheckO, act(active_node), act(active_gate),
act(active_state));
rOut, rCheckI, rCheckF, rCheckO);
value.gateValue[frameIdx] = rValueIn;
value.gateValue[frameIdx + frameSize] = rValueIg;
......@@ -92,9 +88,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
template <class T, class Op, bool isBatch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frameSize,
int batchSize, activation_mode_t active_node,
activation_mode_t active_gate,
activation_mode_t active_state) {
int batchSize) {
const int frameIdx = blockIdx.x * blockDim.x + threadIdx.x;
if (frameIdx >= frameSize) return;
......@@ -145,11 +139,9 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
rPrevState = value.prevStateValue[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
op(rValueIn, rValueIg, rValueFg, rValueOg, rGradIn, rGradIg, rGradFg, rGradOg,
rPrevState, rPrevStateGrad, rState, rStateGrad, rStateAtv, rOutputGrad,
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad,
act(active_node), act(active_gate), act(active_state));
rCheckI, rCheckF, rCheckO, rCheckIGrad, rCheckFGrad, rCheckOGrad);
grad.gateGrad[frameIdx] = rGradIn;
grad.gateGrad[frameIdx + frameSize] = rGradIg;
......@@ -205,13 +197,11 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) {
KeLstmForward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, frameSize, batchSize);
} else {
KeLstmForward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, frameSize, batchSize);
}
}
......@@ -240,13 +230,11 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if (batchSize == 1) {
KeLstmBackward<T, Op,
/* isBatch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, grad, frameSize, batchSize);
} else {
KeLstmBackward<T, Op,
/* isBatch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frameSize, batchSize, active_node, active_gate,
active_state);
op, value, grad, frameSize, batchSize);
}
}
......
......@@ -24,15 +24,29 @@ namespace detail {
namespace forward {
template <typename T>
DEVICE inline T sigmoid(const T a) {
const T min = SIGMOID_THRESHOLD_MIN;
const T max = SIGMOID_THRESHOLD_MAX;
T tmp = (a < min) ? min : ((a > max) ? max : a);
return static_cast<T>(1.0) / (static_cast<T>(1.0) + exp(-tmp));
}
template <typename T>
DEVICE inline T tanh(const T a) {
T tmp = -2.0 * a;
tmp = (tmp > EXP_MAX_INPUT) ? EXP_MAX_INPUT : tmp;
return (2.0 / (1.0 + exp(tmp))) - 1.0;
}
template <class T>
class lstm {
public:
HOSTDEVICE void operator()(T &valueIn, T &valueIg, T &valueFg, T &valueOg,
T &prevState, T &state, T &stateAtv, T &output,
T &checkI, T &checkF, T &checkO,
typename hppl::ForwardActType<T>::type actInput,
typename hppl::ForwardActType<T>::type actGate,
typename hppl::ForwardActType<T>::type actState) {
T &checkI, T &checkF, T &checkO) {
#if 0
// TODO(qingqing) support to activation speficed by users
valueIn = actInput(valueIn);
valueIg = actGate(valueIg + prevState * checkI);
valueFg = actGate(valueFg + prevState * checkF);
......@@ -40,6 +54,15 @@ class lstm {
valueOg = actGate(valueOg + state * checkO);
stateAtv = actState(state);
output = valueOg * stateAtv;
#else
valueIn = tanh<T>(valueIn);
valueIg = sigmoid<T>(valueIg + prevState * checkI);
valueFg = sigmoid<T>(valueFg + prevState * checkF);
state = valueIn * valueIg + prevState * valueFg;
valueOg = sigmoid<T>(valueOg + state * checkO);
stateAtv = tanh<T>(state);
output = valueOg * stateAtv;
#endif
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
......@@ -72,6 +95,16 @@ class lstm {
namespace backward {
template <typename T>
DEVICE inline T sigmoid(const T a, const T b) {
return a * b * (1.0 - b);
}
template <typename T>
DEVICE inline T tanh(const T a, const T b) {
return a * (1.0 - b * b);
}
template <class T>
class lstm {
public:
......@@ -80,10 +113,9 @@ class lstm {
T &prevState, T &prevStateGrad, T &state,
T &stateGrad, T &stateAtv, T &outputGrad,
T &checkI, T &checkF, T &checkO, T &checkIGrad,
T &checkFGrad, T &checkOGrad,
typename hppl::BackwardActType<T>::type actInput,
typename hppl::BackwardActType<T>::type actGate,
typename hppl::BackwardActType<T>::type actState) {
T &checkFGrad, T &checkOGrad) {
#if 0
// TODO(qingqing) support to activation speficed by users
gradOg = actGate(outputGrad * stateAtv, valueOg);
stateGrad += actState(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = actInput(stateGrad * valueIg, valueIn);
......@@ -93,6 +125,17 @@ class lstm {
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
#else
gradOg = sigmoid<T>(outputGrad * stateAtv, valueOg);
stateGrad += tanh<T>(outputGrad * valueOg, stateAtv) + gradOg * checkO;
gradIn = tanh<T>(stateGrad * valueIg, valueIn);
gradIg = sigmoid<T>(stateGrad * valueIn, valueIg);
gradFg = sigmoid<T>(stateGrad * prevState, valueFg);
prevStateGrad = gradIg * checkI + gradFg * checkF + stateGrad * valueFg;
checkIGrad = gradIg * prevState;
checkFGrad = gradFg * prevState;
checkOGrad = gradOg * state;
#endif
}
#ifndef __NVCC__
#ifndef __AVX__ // If not compiled with AVX instructs. Disable AVX by default
......
......@@ -211,6 +211,26 @@ void batched_gemm<platform::CPUPlace, double>(
}
#endif
template <>
void gemv<platform::CPUPlace, float>(const platform::DeviceContext& context,
const bool trans_a, const int M,
const int N, const float alpha,
const float* A, const float* B,
const float beta, float* C) {
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
cblas_sgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template <>
void gemv<platform::CPUPlace, double>(const platform::DeviceContext& context,
const bool trans_a, const int M,
const int N, const double alpha,
const double* A, const double* B,
const double beta, double* C) {
CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans;
cblas_dgemv(CblasRowMajor, transA, M, N, alpha, A, N, B, 1, beta, C, 1);
}
template struct SetConstant<platform::CPUPlace, float>;
} // namespace math
......
......@@ -203,6 +203,33 @@ void batched_gemm<platform::GPUPlace, double>(
&beta, C, ldc, strideC, batchCount));
}
template <>
void gemv<platform::GPUPlace, float>(const platform::DeviceContext& context,
const bool trans_a, const int M,
const int N, const float alpha,
const float* A, const float* B,
const float beta, float* C) {
cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE(platform::dynload::cublasSgemv(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1));
}
template <>
void gemv<platform::GPUPlace, double>(const platform::DeviceContext& context,
const bool trans_a, const int M,
const int N, const double alpha,
const double* A, const double* B,
const double beta, double* C) {
cublasOperation_t cuTransA = (trans_a == false) ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE(platform::dynload::cublasDgemv(
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.cublas_handle(),
cuTransA, N, M, &alpha, A, N, B, 1, &beta, C, 1));
}
template struct SetConstant<platform::GPUPlace, float>;
} // namespace math
......
......@@ -93,6 +93,11 @@ void batched_gemm(const platform::DeviceContext& context,
const T* A, const T* B, const T beta, T* C,
const int batchCount, const int strideA, const int strideB);
template <typename Place, typename T>
void gemv(const platform::DeviceContext& context, const bool trans_a,
const int M, const int N, const T alpha, const T* A, const T* B,
const T beta, T* C);
template <typename Place, typename T>
struct SetConstant {
void operator()(const platform::DeviceContext& context,
......
......@@ -89,3 +89,53 @@ TEST(math_function, zero) {
EXPECT_EQ(t[2], 1);
EXPECT_EQ(t[3], 1);
}
template <typename T>
void GemvTest(int m, int n, bool trans) {
paddle::framework::Tensor mat_a;
paddle::framework::Tensor vec_b;
paddle::framework::Tensor vec_c;
auto* cpu_place = new paddle::platform::CPUPlace();
int b_num = trans ? m : n;
int c_num = trans ? n : m;
T* data_a = mat_a.mutable_data<T>({m, n}, *cpu_place);
T* data_b = vec_b.mutable_data<T>({b_num}, *cpu_place);
T* data_c = vec_c.mutable_data<T>({c_num}, *cpu_place);
for (int i = 0; i < mat_a.numel(); ++i) {
data_a[i] = static_cast<T>(i);
}
for (int i = 0; i < vec_b.numel(); ++i) {
data_b[i] = static_cast<T>(i);
}
paddle::platform::CPUDeviceContext context(*cpu_place);
paddle::operators::math::gemv<paddle::platform::CPUPlace, T>(
context, trans, static_cast<int>(m), static_cast<int>(n), 1., data_a,
data_b, 0., data_c);
if (!trans) {
for (int i = 0; i < m; ++i) {
T sum = 0.0;
for (int j = 0; j < n; ++j) {
sum += data_a[i * n + j] * data_b[j];
}
ASSERT_FLOAT_EQ(data_c[i], sum);
}
} else {
for (int i = 0; i < n; ++i) {
T sum = 0.0;
for (int j = 0; j < m; ++j) {
sum += data_a[j * n + i] * data_b[j];
}
ASSERT_FLOAT_EQ(data_c[i], sum);
}
}
}
TEST(math_function, gemv) {
GemvTest<float>(3, 13, false);
GemvTest<double>(4, 5, false);
GemvTest<float>(12, 7, true);
GemvTest<double>(7, 9, true);
}
......@@ -177,3 +177,65 @@ TEST(math_function, gemm_trans_cublas) {
EXPECT_EQ(input3_ptr[7], 99);
delete gpu_place;
}
template <typename T>
void GemvTest(int m, int n, bool trans) {
paddle::framework::Tensor mat_a;
paddle::framework::Tensor vec_b;
paddle::framework::Tensor vec_c;
auto* cpu_place = new paddle::platform::CPUPlace();
T* data_a = mat_a.mutable_data<T>({m, n}, *cpu_place);
T* data_b = vec_b.mutable_data<T>({trans ? m : n}, *cpu_place);
T* data_c = vec_c.mutable_data<T>({trans ? n : m}, *cpu_place);
auto* gpu_place = new paddle::platform::GPUPlace(0);
paddle::framework::Tensor g_mat_a;
paddle::framework::Tensor g_vec_b;
paddle::framework::Tensor g_vec_c;
T* g_data_a = g_mat_a.mutable_data<T>(mat_a.dims(), *gpu_place);
T* g_data_b = g_vec_b.mutable_data<T>(vec_b.dims(), *gpu_place);
T* g_data_c = g_vec_c.mutable_data<T>(vec_c.dims(), *gpu_place);
for (int i = 0; i < mat_a.numel(); ++i) {
data_a[i] = static_cast<T>(i);
}
for (int i = 0; i < vec_b.numel(); ++i) {
data_b[i] = static_cast<T>(i);
}
paddle::platform::CUDADeviceContext context(*gpu_place);
g_mat_a.CopyFrom(mat_a, *gpu_place, context);
g_vec_b.CopyFrom(vec_b, *gpu_place, context);
paddle::operators::math::gemv<paddle::platform::GPUPlace, T>(
context, trans, static_cast<int>(m), static_cast<int>(n), 1., g_data_a,
g_data_b, 0., g_data_c);
vec_c.CopyFrom(g_vec_c, paddle::platform::CPUPlace(), context);
if (!trans) {
for (int i = 0; i < m; ++i) {
T sum = 0.0;
for (int j = 0; j < n; ++j) {
sum += data_a[i * n + j] * data_b[j];
}
ASSERT_FLOAT_EQ(data_c[i], sum);
}
} else {
for (int i = 0; i < n; ++i) {
T sum = 0.0;
for (int j = 0; j < m; ++j) {
sum += data_a[j * n + i] * data_b[j];
}
ASSERT_FLOAT_EQ(data_c[i], sum);
}
}
}
TEST(math_function, gemv) {
GemvTest<float>(3, 13, false);
GemvTest<double>(3, 13, false);
GemvTest<float>(3, 13, true);
GemvTest<double>(3, 13, true);
}
......@@ -53,7 +53,18 @@ class LoDTensor2BatchFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::LoDTensor& lod_tensor,
framework::LoDTensor& batch, bool is_reverse) const {
framework::LoDTensor& batch, bool is_cal_batch_lod,
bool is_reverse = false) const {
if (!is_cal_batch_lod) {
auto lods = batch.lod();
PADDLE_ENFORCE_EQ(lods.size(), 2UL);
PADDLE_ENFORCE_EQ(lods[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_batch;
to_batch(context, lod_tensor, lods[1].data(), batch, true);
return;
}
auto lods = lod_tensor.lod();
PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
auto lod = lods[0];
......@@ -101,10 +112,10 @@ class LoDTensor2BatchFunctor {
size_t* batch_starts = batch_lods[0].data();
size_t* seq2batch_idx = batch_lods[1].data();
batch_starts[0] = 0;
for (size_t n = 0; n < num_batch; n++) {
for (int n = 0; n < num_batch; n++) {
auto batch_id = static_cast<int>(batch_starts[n]);
for (size_t i = 0; i < seq_info.size(); ++i) {
size_t seq_len = seq_info[i].length;
int seq_len = seq_info[i].length;
int start = seq_info[i].start;
if (n < seq_len) {
seq2batch_idx[batch_id] =
......@@ -132,11 +143,8 @@ class Batch2LoDTensorFunctor {
auto in_lod = batch.lod();
PADDLE_ENFORCE_EQ(in_lod.size(), 2UL,
"The LoD size of input `batch` should be 2.");
auto out_lod = lod_tensor.lod()[0];
auto num = out_lod[out_lod.size() - 1];
PADDLE_ENFORCE_EQ(num, lod_tensor.dims()[0]);
PADDLE_ENFORCE_EQ(num, in_lod[1].size());
PADDLE_ENFORCE_EQ(num, batch.dims()[0]);
PADDLE_ENFORCE_EQ(in_lod[1].size(),
static_cast<size_t>(lod_tensor.dims()[0]));
CopyMatrixRowsFunctor<Place, T> to_seq;
size_t* index = in_lod[1].data();
to_seq(context, batch, index, lod_tensor, false);
......
......@@ -52,7 +52,7 @@ def lstm(
g = np.dot(h_pre, w_h) # 1 x 4D
g = g + x
g = np.reshape(g, (1, g.size))
c_tmp, g_i, g_f, g_o = np.split(g, 4, axis=1)
c, g_i, g_f, g_o = np.split(g, 4, axis=1)
if w_c is None:
g_i = act_gate(g_i) # 1 x D
g_f = act_gate(g_f) # 1 x D
......@@ -60,7 +60,7 @@ def lstm(
w_ic, w_fc, w_oc = np.split(w_c, 3, axis=1)
g_i = act_gate(g_i + w_ic * c_pre) # 1 x D
g_f = act_gate(g_f + w_fc * c_pre) # 1 x D
c = g_f * c_pre + g_i * act_cand(c_tmp) # 1 x D
c = g_f * c_pre + g_i * act_cand(c) # 1 x D
if w_c is None:
g_o = act_gate(g_o) # 1 x D
......@@ -68,8 +68,7 @@ def lstm(
_, _, w_oc = np.split(w_c, 3, axis=1)
g_o = act_gate(g_o + w_oc * c) # 1 x D
h = g_o * act_cell(c)
bg = np.concatenate((act_cand(c_tmp), g_i, g_f, g_o), axis=1)
return h, c, bg
return h, c
def _reverse(x, lod):
y = np.zeros_like(x)
......@@ -82,7 +81,6 @@ def lstm(
batch_size = len(offset) - 1
hidden = []
cell = []
gate = []
input = _reverse(input, offset) if is_reverse else input
if w_b is not None:
input = input + np.tile(w_b, (offset[-1], 1))
......@@ -94,96 +92,109 @@ def lstm(
c_pre = c0[i] # 1 x D
for j in range(seq_len):
# compute one step
h_pre, c_pre, g_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
act_cell, act_cand)
h_pre, c_pre = _step(x[j], w_h, w_c, h_pre, c_pre, act_gate,
act_cell, act_cand)
hidden.append(h_pre.flatten())
cell.append(c_pre.flatten())
gate.append(g_pre.flatten())
hidden = np.array(hidden).astype("float64")
cell = np.array(cell).astype("float64")
gate = np.array(gate).astype("float64")
hidden = np.array(hidden).astype('float64')
cell = np.array(cell).astype('float64')
hidden = _reverse(hidden, offset) if is_reverse else hidden
cell = _reverse(cell, offset) if is_reverse else cell
assert gate.shape == input.shape
assert hidden.shape == (input.shape[0], input.shape[1] / 4)
assert cell.shape == (input.shape[0], input.shape[1] / 4)
return hidden, cell, gate
return hidden, cell
class TestLstmOp(OpTest):
def set_data(self):
self.lod = [[0, 2, 6, 9]]
self.D = 64
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.act_gate = "sigmoid"
self.act_cell = "tanh"
self.act_cand = "tanh"
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.has_initial_state = True
self.is_reverse = False
def setUp(self):
self.set_data()
self.op_type = "lstm"
self.set_argument()
self.op_type = 'lstm'
T = self.lod[0][-1]
N = len(self.lod[0]) - 1
x = np.random.normal(size=(T, 4 * self.D)).astype("float64")
h0 = np.zeros((N, self.D)).astype("float64")
c0 = np.zeros((N, self.D)).astype("float64")
w = np.random.normal(size=(self.D, 4 * self.D)).astype("float64")
b = np.random.normal(size=(1, 7 * self.D)).astype("float64")
x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
h0 = np.zeros((N, self.D)).astype('float64')
c0 = np.zeros((N, self.D)).astype('float64')
w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:]
h, c, g = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand])
g_sort = np.zeros_like(x)
for i, j in enumerate(self.sort_idx):
g_sort[i, :] = g[j, :]
self.inputs = {
'Input': (x, self.lod),
'H0': h0,
'C0': c0,
'Weight': w,
'Bias': b
}
h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
ACTVATION[self.act_gate], ACTVATION[self.act_cell],
ACTVATION[self.act_cand])
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'Bias': b}
if self.has_initial_state:
self.inputs['H0'] = h0
self.inputs['C0'] = c0
self.outputs = {
'Hidden': (h, self.lod),
'Cell': (c, self.lod),
'BatchGate': g_sort
}
self.attrs = {
'usePeepholes': True,
'isReverse': self.is_reverse,
'gateActivation': 'sigmoid',
'cellActivation': 'tanh',
'candidateActivation': 'tanh'
'gateActivation': self.act_gate,
'cellActivation': self.act_cell,
'candidateActivation': self.act_cand
}
def test_check_output(self):
self.check_output()
self.check_output(atol=1e-8)
#TODO(qingqing) add more unit testing case
def test_check_grad(self):
# TODO(qingqing) remove folowing lines after the check_grad is refined.
N = len(self.lod[0]) - 1
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
class TestLstmOpHasNoInitial(TestLstmOp):
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.has_initial_state = False
self.is_reverse = True
class TestLstmOpRerverse(TestLstmOp):
def set_data(self):
self.lod = [[0, 2, 6, 9]]
self.D = 64
self.sort_idx = [2, 6, 0, 3, 7, 1, 4, 8, 5]
def set_argument(self):
self.lod = [[0, 2, 5, 7]]
self.D = 16
self.act_gate = "sigmoid"
self.act_cell = "tanh"
self.act_cand = "tanh"
self.act_gate = 'sigmoid'
self.act_cell = 'tanh'
self.act_cand = 'tanh'
self.has_initial_state = True
self.is_reverse = True
if __name__ == "__main__":
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册