未验证 提交 81870723 编写于 作者: X xuezhong 提交者: GitHub

Merge pull request #15605 from xuezhong/fix_bug_for_lstmp

Fix bug for lstmp
...@@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v ...@@ -71,7 +71,7 @@ paddle.fluid.initializer.NumpyArrayInitializer.__init__ ArgSpec(args=['self', 'v
paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None)) paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param_attr', 'bias_attr', 'act', 'is_test', 'name'], varargs=None, keywords=None, defaults=(1, None, None, None, False, None))
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32')) paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None)) paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None)) paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name', 'h_0', 'c_0', 'cell_clip', 'proj_clip'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None, None, None, None, None))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False)) paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False)) paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
......
...@@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel<T> { ...@@ -151,9 +151,10 @@ class LSTMKernel : public framework::OpKernel<T> {
lstm_value.output_value = out_t.data<T>(); lstm_value.output_value = out_t.data<T>();
lstm_value.state_value = cell_t.data<T>(); lstm_value.state_value = cell_t.data<T>();
lstm_value.state_active_value = cell_pre_act_t.data<T>(); lstm_value.state_active_value = cell_pre_act_t.data<T>();
T cell_clip = 0.0;
math::LstmUnitFunctor<DeviceContext, T>::compute( math::LstmUnitFunctor<DeviceContext, T>::compute(
device_ctx, lstm_value, frame_size, cur_batch_size, gate_act, device_ctx, lstm_value, frame_size, cur_batch_size, cell_clip,
cell_act, cand_act); gate_act, cell_act, cand_act);
lstm_value.prev_state_value = lstm_value.state_value; lstm_value.prev_state_value = lstm_value.state_value;
} }
...@@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel<T> { ...@@ -316,9 +317,10 @@ class LSTMGradKernel : public framework::OpKernel<T> {
lstm_value.output_value = nullptr; lstm_value.output_value = nullptr;
lstm_grad.state_active_grad = nullptr; lstm_grad.state_active_grad = nullptr;
int cur_batch_size = bend - bstart; int cur_batch_size = bend - bstart;
T cell_clip = 0.0;
math::LstmUnitGradFunctor<DeviceContext, T>::compute( math::LstmUnitGradFunctor<DeviceContext, T>::compute(
device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size, device_ctx, lstm_value, lstm_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act); cell_clip, gate_act, cell_act, cand_act);
if (n > 0) { if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]); int pre_h_start = static_cast<int>(batch_starts[n - 1]);
......
...@@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -73,12 +73,6 @@ class LSTMPOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(ctx->HasInput("C0"), PADDLE_ENFORCE(ctx->HasInput("C0"),
"Input(C0) of LSTMP operator should not be null after " "Input(C0) of LSTMP operator should not be null after "
"Input(H0) provided."); "Input(H0) provided.");
auto h_dims = ctx->GetInputDim("H0");
auto c_dims = ctx->GetInputDim("C0");
PADDLE_ENFORCE(h_dims == c_dims,
"The dimension of Input(H0) and Input(C0) "
"should be the same.");
ctx->SetOutputDim("OrderedP0", {h_dims[0], proj_dims[1]});
} }
auto b_dims = ctx->GetInputDim("Bias"); auto b_dims = ctx->GetInputDim("Bias");
...@@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -180,11 +174,6 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"This LoDTensor is obtained in the forward and used in the " "This LoDTensor is obtained in the forward and used in the "
"backward.") "backward.")
.AsIntermediate(); .AsIntermediate();
AddOutput("OrderedP0",
"(Tensor) the projection of the initial hidden state "
"H0. This is a tensor with shape (N x P), where N is the "
"batch size and P is the hidden size.")
.AsIntermediate();
AddAttr<bool>("use_peepholes", AddAttr<bool>("use_peepholes",
"(bool, defalut: True) " "(bool, defalut: True) "
"whether to enable diagonal/peephole connections.") "whether to enable diagonal/peephole connections.")
...@@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -193,6 +182,16 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) " "(bool, defalut: False) "
"whether to compute reversed LSTMP.") "whether to compute reversed LSTMP.")
.SetDefault(false); .SetDefault(false);
AddAttr<float>("cell_clip",
"(float, defalut: 0.0) "
"Clip for Tensor for cell state tensor when clip value is "
"greater than 0.0")
.SetDefault(0.0);
AddAttr<float>("proj_clip",
"(float, defalut: 0.0) "
"Clip for Tensor for projection tensor when clip value is "
"greater than 0.0")
.SetDefault(0.0);
AddAttr<std::string>( AddAttr<std::string>(
"gate_activation", "gate_activation",
"(string, default: sigmoid)" "(string, default: sigmoid)"
......
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
...@@ -21,17 +22,50 @@ limitations under the License. */ ...@@ -21,17 +22,50 @@ limitations under the License. */
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/operators/math/lstm_compute.h" #include "paddle/fluid/operators/math/lstm_compute.h"
#include "paddle/fluid/operators/math/sequence2batch.h" #include "paddle/fluid/operators/math/sequence2batch.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using platform::Transform;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
class _ClipFunctor {
public:
explicit _ClipFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x) const {
if (x < min_)
return min_;
else if (x > max_)
return max_;
else
return x;
}
private:
T min_;
T max_;
};
template <typename T>
class _ClipGradFunctor {
public:
explicit _ClipGradFunctor(const T min, const T max) : min_(min), max_(max) {}
HOSTDEVICE T operator()(const T& x, const T& y) const {
return (y > min_ && y < max_) ? x : 0;
}
private:
T min_;
T max_;
};
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
inline void ReorderInitState(const DeviceContext& ctx, inline void ReorderInitState(const DeviceContext& ctx,
const framework::Tensor& src, const framework::Tensor& src,
...@@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -67,9 +101,11 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto* bias = ctx.Input<Tensor>("Bias"); auto* bias = ctx.Input<Tensor>("Bias");
auto* hidden_t0 = ctx.Input<Tensor>("H0"); auto* hidden_t0 = ctx.Input<Tensor>("H0");
auto* ordered_proj0 = ctx.Output<Tensor>("OrderedP0");
auto* cell_t0 = ctx.Input<Tensor>("C0"); auto* cell_t0 = ctx.Input<Tensor>("C0");
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
auto* batch_gate = ctx.Output<LoDTensor>("BatchGate"); auto* batch_gate = ctx.Output<LoDTensor>("BatchGate");
batch_gate->mutable_data<T>(ctx.GetPlace()); batch_gate->mutable_data<T>(ctx.GetPlace());
auto* proj_out = ctx.Output<LoDTensor>("Projection"); auto* proj_out = ctx.Output<LoDTensor>("Projection");
...@@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -110,6 +146,7 @@ class LSTMPKernel : public framework::OpKernel<T> {
} }
lstmp_value.prev_state_value = nullptr; lstmp_value.prev_state_value = nullptr;
Tensor ordered_c0; Tensor ordered_c0;
Tensor ordered_h0;
framework::Vector<size_t> order(batch_gate->lod()[2]); framework::Vector<size_t> order(batch_gate->lod()[2]);
...@@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -169,18 +206,9 @@ class LSTMPKernel : public framework::OpKernel<T> {
// Since the batch computing for LSTMP reorders the input sequence // Since the batch computing for LSTMP reorders the input sequence
// according to their length. The initialized hidden state also needs // according to their length. The initialized hidden state also needs
// to reorder. // to reorder.
Tensor ordered_h0;
ordered_proj0->mutable_data<T>(ctx.GetPlace());
ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order, ReorderInitState<DeviceContext, T>(device_ctx, *hidden_t0, order,
&ordered_h0, true); &ordered_h0, true);
blas.MatMul(ordered_h0, false, *proj_weight, false, static_cast<T>(1.0), blas.MatMul(ordered_h0, false, *weight, false, static_cast<T>(1.0),
ordered_proj0, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
ActCompute(cell_act, place, proj0_dev, proj0_dev);
}
blas.MatMul(*ordered_proj0, false, *weight, false, static_cast<T>(1.0),
&gate_t, static_cast<T>(1.0)); &gate_t, static_cast<T>(1.0));
} }
...@@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -189,8 +217,8 @@ class LSTMPKernel : public framework::OpKernel<T> {
lstmp_value.state_value = cell_t.data<T>(); lstmp_value.state_value = cell_t.data<T>();
lstmp_value.state_active_value = cell_pre_act_t.data<T>(); lstmp_value.state_active_value = cell_pre_act_t.data<T>();
math::LstmUnitFunctor<DeviceContext, T>::compute( math::LstmUnitFunctor<DeviceContext, T>::compute(
device_ctx, lstmp_value, frame_size, cur_batch_size, gate_act, device_ctx, lstmp_value, frame_size, cur_batch_size, cell_clip,
cell_act, cand_act); gate_act, cell_act, cand_act);
lstmp_value.prev_state_value = lstmp_value.state_value; lstmp_value.prev_state_value = lstmp_value.state_value;
blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0), blas.MatMul(hidden_t, false, *proj_weight, false, static_cast<T>(1.0),
&proj_t, static_cast<T>(0.0)); &proj_t, static_cast<T>(0.0));
...@@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel<T> { ...@@ -198,6 +226,14 @@ class LSTMPKernel : public framework::OpKernel<T> {
auto proj_t_dev = EigenMatrix<T>::From(proj_t); auto proj_t_dev = EigenMatrix<T>::From(proj_t);
ActCompute(cell_act, place, proj_t_dev, proj_t_dev); ActCompute(cell_act, place, proj_t_dev, proj_t_dev);
} }
if (proj_clip && proj_clip > 0.0) {
T* x_data = proj_t.data<T>();
int64_t numel = proj_t.numel();
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), x_data,
x_data + numel, x_data,
_ClipFunctor<T>(-1.0 * proj_clip, proj_clip));
}
} }
math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq; math::Batch2LoDTensorFunctor<DeviceContext, T> to_seq;
...@@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -239,6 +275,9 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto* proj_out = ctx.Input<LoDTensor>("Projection"); auto* proj_out = ctx.Input<LoDTensor>("Projection");
auto* cell_out = ctx.Input<LoDTensor>("Cell"); auto* cell_out = ctx.Input<LoDTensor>("Cell");
auto proj_clip = static_cast<T>(ctx.Attr<float>("proj_clip"));
auto cell_clip = static_cast<T>(ctx.Attr<float>("cell_clip"));
auto* batch_gate = ctx.Input<LoDTensor>("BatchGate"); auto* batch_gate = ctx.Input<LoDTensor>("BatchGate");
auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct"); auto* batch_cell_pre_act = ctx.Input<LoDTensor>("BatchCellPreAct");
auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden"); auto* batch_hidden = ctx.Input<LoDTensor>("BatchHidden");
...@@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -253,7 +292,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias")); auto* bias_g = ctx.Output<Tensor>(framework::GradVarName("Bias"));
auto* h0 = ctx.Input<Tensor>("H0"); auto* h0 = ctx.Input<Tensor>("H0");
auto* ordered_proj0 = ctx.Input<Tensor>("OrderedP0");
auto* c0 = ctx.Input<Tensor>("C0"); auto* c0 = ctx.Input<Tensor>("C0");
auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0")); auto* h0_g = ctx.Output<Tensor>(framework::GradVarName("H0"));
...@@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -363,6 +401,17 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
Tensor cur_proj = batch_proj.Slice(bstart, bend); Tensor cur_proj = batch_proj.Slice(bstart, bend);
Tensor proj_g = batch_proj_g.Slice(bstart, bend); Tensor proj_g = batch_proj_g.Slice(bstart, bend);
if (proj_clip && proj_clip > 0.0) {
T* dx_data = proj_g.data<T>();
T* x_data = cur_proj.data<T>();
int64_t numel = proj_g.numel();
Transform<DeviceContext> trans;
trans(ctx.template device_context<DeviceContext>(), dx_data,
dx_data + numel, x_data, dx_data,
_ClipGradFunctor<T>(-1.0 * proj_clip, proj_clip));
}
if (proj_act != math::detail::ActivationType::kIdentity) { if (proj_act != math::detail::ActivationType::kIdentity) {
auto cur_proj_dev = EigenMatrix<T>::From(cur_proj); auto cur_proj_dev = EigenMatrix<T>::From(cur_proj);
auto proj_g_dev = EigenMatrix<T>::From(proj_g); auto proj_g_dev = EigenMatrix<T>::From(proj_g);
...@@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -412,7 +461,7 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
math::LstmUnitGradFunctor<DeviceContext, T>::compute( math::LstmUnitGradFunctor<DeviceContext, T>::compute(
device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size, device_ctx, lstmp_value, lstmp_grad, frame_size, cur_batch_size,
gate_act, cell_act, cand_act); cell_clip, gate_act, cell_act, cand_act);
if (n > 0) { if (n > 0) {
int pre_h_start = static_cast<int>(batch_starts[n - 1]); int pre_h_start = static_cast<int>(batch_starts[n - 1]);
...@@ -431,32 +480,15 @@ class LSTMPGradKernel : public framework::OpKernel<T> { ...@@ -431,32 +480,15 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order, ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
&ordered_h0, true); &ordered_h0, true);
if (weight_g) { if (weight_g) {
blas.MatMul(*ordered_proj0, true, gate_g, false, blas.MatMul(ordered_h0, true, gate_g, false, static_cast<T>(1.0),
static_cast<T>(1.0), weight_g, static_cast<T>(1.0)); weight_g, static_cast<T>(1.0));
} }
} }
if (h0 && (h0_g || proj_weight_g)) { if (h0 && (h0_g || proj_weight_g)) {
ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace()); ordered_h0_g.mutable_data<T>(h0_g->dims(), ctx.GetPlace());
Tensor proj0_g;
proj0_g.Resize({in_dims[0], proj_weight->dims()[1]});
proj0_g.mutable_data<T>(ctx.GetPlace());
blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0), blas.MatMul(gate_g, false, *weight, true, static_cast<T>(1.0),
&proj0_g, static_cast<T>(0.0));
if (proj_act != math::detail::ActivationType::kIdentity) {
auto proj0_dev = EigenMatrix<T>::From(*ordered_proj0);
auto proj0_g_dev = EigenMatrix<T>::From(proj0_g);
ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
proj0_g_dev);
}
if (h0_g) {
blas.MatMul(proj0_g, false, *proj_weight, true, static_cast<T>(1.0),
&ordered_h0_g, static_cast<T>(0.0)); &ordered_h0_g, static_cast<T>(0.0));
} }
if (proj_weight_g) {
blas.MatMul(ordered_h0, true, proj0_g, false, static_cast<T>(1.0),
proj_weight_g, static_cast<T>(1.0));
}
}
} }
} }
......
...@@ -32,7 +32,8 @@ namespace detail { ...@@ -32,7 +32,8 @@ namespace detail {
template <class T, class Op> template <class T, class Op>
void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, ActivationType active_node, int frame_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
T r_value_in; T r_value_in;
...@@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -67,7 +68,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_node, active_gate, active_state); &cell_clip, active_node, active_gate, active_state);
value_in[i] = r_value_in; value_in[i] = r_value_in;
value_ig[i] = r_value_ig; value_ig[i] = r_value_ig;
...@@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -82,7 +83,7 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
ActivationType active_node, T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
T r_value_in; T r_value_in;
...@@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -135,7 +136,7 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
&r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
&r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
&r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
active_node, active_gate, active_state); &cell_clip, active_node, active_gate, active_state);
grad_in[i] = r_grad_in; grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig; grad_ig[i] = r_grad_ig;
...@@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -154,7 +155,8 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
int frame_size, ActivationType active_node, int frame_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
#ifdef __AVX__ #ifdef __AVX__
...@@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -194,7 +196,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_node, active_gate, active_state); &cell_clip, active_node, active_gate, active_state);
value_in[i] = r_value_in; value_in[i] = r_value_in;
value_ig[i] = r_value_ig; value_ig[i] = r_value_ig;
...@@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -210,7 +212,7 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
ActivationType active_node, T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
#ifdef __AVX__ #ifdef __AVX__
...@@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -268,7 +270,7 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
&r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
&r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
&r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
active_node, active_gate, active_state); &cell_clip, active_node, active_gate, active_state);
grad_in[i] = r_grad_in; grad_in[i] = r_grad_in;
grad_ig[i] = r_grad_ig; grad_ig[i] = r_grad_ig;
...@@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value, ...@@ -292,27 +294,27 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size, void cpu_lstm_forward(Op op, LstmMetaValue<T> value, int frame_size,
ActivationType active_node, ActivationType active_gate, T cell_clip, ActivationType active_node,
ActivationType active_state) { ActivationType active_gate, ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_forward_one_sequence<T>(op, value, frame_size, active_node, avx_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
active_gate, active_state); active_node, active_gate, active_state);
} else { } else {
naive_lstm_forward_one_sequence<T>(op, value, frame_size, active_node, naive_lstm_forward_one_sequence<T>(op, value, frame_size, cell_clip,
active_gate, active_state); active_node, active_gate, active_state);
} }
} }
template <class T, class Op> template <class T, class Op>
void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad, void cpu_lstm_backward(Op op, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, ActivationType active_node, int frame_size, T cell_clip, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) { if (Op::avx && !(frame_size & (8 - 1)) && (std::is_same<T, float>::value)) {
avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, active_node, avx_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
active_gate, active_state); active_node, active_gate, active_state);
} else { } else {
naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size, naive_lstm_backward_one_sequence<T>(op, value, grad, frame_size, cell_clip,
active_node, active_gate, active_state); active_node, active_gate, active_state);
} }
} }
......
...@@ -31,7 +31,8 @@ namespace detail { ...@@ -31,7 +31,8 @@ namespace detail {
*/ */
template <class T, class Op, bool is_batch> template <class T, class Op, bool is_batch>
__global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node, int batch_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -72,7 +73,7 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
&r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO, &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
active_node, active_gate, active_state); &cell_clip, active_node, active_gate, active_state);
value.gate_value[frame_idx] = r_value_in; value.gate_value[frame_idx] = r_value_in;
value.gate_value[frame_idx + frame_size] = r_value_ig; value.gate_value[frame_idx + frame_size] = r_value_ig;
...@@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size, ...@@ -91,7 +92,8 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
template <class T, class Op, bool is_batch> template <class T, class Op, bool is_batch>
__global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, LstmMetaGrad<T> grad, int frame_size,
int batch_size, ActivationType active_node, int batch_size, T cell_clip,
ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x; const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
...@@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -148,8 +150,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig, op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
&r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
&r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
&r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, &cell_clip,
active_gate, active_state); active_node, active_gate, active_state);
grad.gate_grad[frame_idx] = r_grad_in; grad.gate_grad[frame_idx] = r_grad_in;
grad.gate_grad[frame_idx + frame_size] = r_grad_ig; grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
...@@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value, ...@@ -185,8 +187,8 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
template <class T, class Op> template <class T, class Op>
void gpu_lstm_forward(const platform::DeviceContext& context, Op op, void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
ActivationType active_node, ActivationType active_gate, T cell_clip, ActivationType active_node,
ActivationType active_state) { ActivationType active_gate, ActivationType active_state) {
dim3 threads; dim3 threads;
dim3 grid; dim3 grid;
if (batch_size == 1) { if (batch_size == 1) {
...@@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, ...@@ -205,12 +207,12 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
if (batch_size == 1) { if (batch_size == 1) {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* is_batch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, active_node, active_gate, op, value, frame_size, batch_size, cell_clip, active_node, active_gate,
active_state); active_state);
} else { } else {
KeLstmForward<T, Op, KeLstmForward<T, Op,
/* is_batch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, frame_size, batch_size, active_node, active_gate, op, value, frame_size, batch_size, cell_clip, active_node, active_gate,
active_state); active_state);
} }
} }
...@@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op, ...@@ -218,7 +220,7 @@ void gpu_lstm_forward(const platform::DeviceContext& context, Op op,
template <class T, class Op> template <class T, class Op>
void gpu_lstm_backward(const platform::DeviceContext& context, Op op, void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size, T cell_clip,
ActivationType active_node, ActivationType active_gate, ActivationType active_node, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
dim3 threads; dim3 threads;
...@@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op, ...@@ -239,13 +241,13 @@ void gpu_lstm_backward(const platform::DeviceContext& context, Op op,
if (batch_size == 1) { if (batch_size == 1) {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* is_batch= */ false><<<grid, threads, 0, stream>>>( /* is_batch= */ false><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, active_node, active_gate, op, value, grad, frame_size, batch_size, cell_clip, active_node,
active_state); active_gate, active_state);
} else { } else {
KeLstmBackward<T, Op, KeLstmBackward<T, Op,
/* is_batch= */ true><<<grid, threads, 0, stream>>>( /* is_batch= */ true><<<grid, threads, 0, stream>>>(
op, value, grad, frame_size, batch_size, active_node, active_gate, op, value, grad, frame_size, batch_size, cell_clip, active_node,
active_state); active_gate, active_state);
} }
} }
......
...@@ -29,7 +29,7 @@ class lstm { ...@@ -29,7 +29,7 @@ class lstm {
public: public:
HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og, HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
T *prev_state, T *state, T *state_atv, T *output, T *prev_state, T *state, T *state_atv, T *output,
T *checkI, T *checkF, T *checkO, T *checkI, T *checkF, T *checkO, T *cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -37,6 +37,15 @@ class lstm { ...@@ -37,6 +37,15 @@ class lstm {
*value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate); *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
*value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate); *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
*state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg); *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
if (*cell_clip > 0.0) {
if (*state < -1.0 * (*cell_clip)) {
*state = -1.0 * (*cell_clip);
}
if (*state > *cell_clip) {
*state = *cell_clip;
}
}
*value_og = activation(*value_og + (*state) * (*checkO), active_gate); *value_og = activation(*value_og + (*state) * (*checkO), active_gate);
*state_atv = activation(*state, active_state); *state_atv = activation(*state, active_state);
*output = (*value_og) * (*state_atv); *output = (*value_og) * (*state_atv);
...@@ -52,7 +61,7 @@ class lstm { ...@@ -52,7 +61,7 @@ class lstm {
__m256 *value_fg, __m256 *value_og, __m256 *value_fg, __m256 *value_og,
__m256 *prev_state, __m256 *state, __m256 *prev_state, __m256 *state,
__m256 *state_atv, __m256 *output, __m256 *checkI, __m256 *state_atv, __m256 *output, __m256 *checkI,
__m256 *checkF, __m256 *checkO, __m256 *checkF, __m256 *checkO, T *cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
...@@ -65,6 +74,13 @@ class lstm { ...@@ -65,6 +74,13 @@ class lstm {
active_gate); active_gate);
*state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig), *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
_mm256_mul_ps(*prev_state, *value_fg)); _mm256_mul_ps(*prev_state, *value_fg));
if (*cell_clip > 0.0f) {
__m256 min = _mm256_set1_ps(0.0f - *cell_clip);
__m256 max = _mm256_set1_ps(*cell_clip);
*state = _mm256_min_ps(max, *state);
*state = _mm256_max_ps(min, *state);
}
*value_og = activation( *value_og = activation(
_mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate); _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
*state_atv = activation(*state, active_state); *state_atv = activation(*state, active_state);
...@@ -86,15 +102,26 @@ class lstm { ...@@ -86,15 +102,26 @@ class lstm {
T *prev_state, T *prev_state_grad, T *state, T *prev_state, T *prev_state_grad, T *state,
T *state_grad, T *state_atv, T *output_grad, T *state_grad, T *state_atv, T *output_grad,
T *checkI, T *checkF, T *checkO, T *checkIGrad, T *checkI, T *checkF, T *checkO, T *checkIGrad,
T *checkFGrad, T *checkOGrad, T *checkFGrad, T *checkOGrad, T *cell_clip,
ActivationType active_node, ActivationType active_node,
ActivationType active_gate, ActivationType active_gate,
ActivationType active_state) { ActivationType active_state) {
*grad_og = *grad_og =
activation((*output_grad) * (*state_atv), *value_og, active_gate); activation((*output_grad) * (*state_atv), *value_og, active_gate);
if (*cell_clip > 0.0f) {
if (*state >= (*cell_clip) || *state <= (0.0f - (*cell_clip))) {
*state_grad = 0.0f;
} else {
*state_grad +=
activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO);
}
} else {
*state_grad += *state_grad +=
activation((*output_grad) * (*value_og), *state_atv, active_state) + activation((*output_grad) * (*value_og), *state_atv, active_state) +
(*grad_og) * (*checkO); (*grad_og) * (*checkO);
}
*grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node); *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
*grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate); *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
*grad_fg = *grad_fg =
...@@ -117,15 +144,24 @@ class lstm { ...@@ -117,15 +144,24 @@ class lstm {
__m256 *prev_state, __m256 *prev_state_grad, __m256 *state, __m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
__m256 *state_grad, __m256 *state_atv, __m256 *output_grad, __m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
__m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad, __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
__m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node, __m256 *checkFGrad, __m256 *checkOGrad, T *cell_clip,
ActivationType active_gate, ActivationType active_state) { ActivationType active_node, ActivationType active_gate,
ActivationType active_state) {
*grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og, *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
active_gate); active_gate);
if (*cell_clip > 0.0f) {
T *state_ = reinterpret_cast<T *>(state);
if (*state_ >= (*cell_clip) || *state_ <= (0.0f - (*cell_clip))) {
*state_grad = _mm256_set1_ps(0.0f);
} else {
*state_grad = *state_grad =
_mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og), _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
*state_atv, active_state), *state_atv, active_state),
*state_grad); *state_grad);
*state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad); *state_grad =
_mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
}
}
*grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in, *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
active_node); active_node);
*grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig, *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
......
...@@ -24,12 +24,12 @@ template <class T> ...@@ -24,12 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CPUDeviceContext, T> { struct LstmUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context, static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType& gate_act, T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size, detail::cpu_lstm_forward(detail::forward::lstm<T>(), value, frame_size,
cand_act, gate_act, cell_act); cell_clip, cand_act, gate_act, cell_act);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
value.state_active_value += frame_size; value.state_active_value += frame_size;
...@@ -45,13 +45,14 @@ template <class T> ...@@ -45,13 +45,14 @@ template <class T>
struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> { struct LstmUnitGradFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext& context, static void compute(const platform::CPUDeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act) {
for (int b = 0; b < batch_size; b++) { for (int b = 0; b < batch_size; b++) {
detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad, detail::cpu_lstm_backward(detail::backward::lstm<T>(), value, grad,
frame_size, cand_act, gate_act, cell_act); frame_size, cell_clip, cand_act, gate_act,
cell_act);
value.gate_value += frame_size * 4; value.gate_value += frame_size * 4;
value.state_value += frame_size; value.state_value += frame_size;
......
...@@ -24,12 +24,12 @@ template <class T> ...@@ -24,12 +24,12 @@ template <class T>
struct LstmUnitFunctor<platform::CUDADeviceContext, T> { struct LstmUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context, static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, int frame_size, int batch_size, LstmMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType& gate_act, T cell_clip, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act) {
detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value, detail::gpu_lstm_forward<T>(context, detail::forward::lstm<T>(), value,
frame_size, batch_size, cand_act, gate_act, frame_size, batch_size, cell_clip, cand_act,
cell_act); gate_act, cell_act);
} }
}; };
...@@ -37,13 +37,13 @@ template <class T> ...@@ -37,13 +37,13 @@ template <class T>
struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> { struct LstmUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext& context, static void compute(const platform::CUDADeviceContext& context,
LstmMetaValue<T> value, LstmMetaGrad<T> grad, LstmMetaValue<T> value, LstmMetaGrad<T> grad,
int frame_size, int batch_size, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType& gate_act, const detail::ActivationType& gate_act,
const detail::ActivationType& cell_act, const detail::ActivationType& cell_act,
const detail::ActivationType& cand_act) { const detail::ActivationType& cand_act) {
detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad, detail::gpu_lstm_backward(context, detail::backward::lstm<T>(), value, grad,
frame_size, batch_size, cand_act, gate_act, frame_size, batch_size, cell_clip, cand_act,
cell_act); gate_act, cell_act);
} }
}; };
......
...@@ -50,7 +50,7 @@ template <typename DeviceContext, typename T> ...@@ -50,7 +50,7 @@ template <typename DeviceContext, typename T>
class LstmUnitFunctor { class LstmUnitFunctor {
public: public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value, static void compute(const DeviceContext &context, LstmMetaValue<T> value,
int frame_size, int batch_size, int frame_size, int batch_size, T cell_clip,
const detail::ActivationType &gate_act, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act, const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act); const detail::ActivationType &cand_act);
...@@ -61,7 +61,7 @@ class LstmUnitGradFunctor { ...@@ -61,7 +61,7 @@ class LstmUnitGradFunctor {
public: public:
static void compute(const DeviceContext &context, LstmMetaValue<T> value, static void compute(const DeviceContext &context, LstmMetaValue<T> value,
LstmMetaGrad<T> grad, int frame_size, int batch_size, LstmMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType &gate_act, T cell_clip, const detail::ActivationType &gate_act,
const detail::ActivationType &cell_act, const detail::ActivationType &cell_act,
const detail::ActivationType &cand_act); const detail::ActivationType &cand_act);
}; };
......
...@@ -668,7 +668,11 @@ def dynamic_lstmp(input, ...@@ -668,7 +668,11 @@ def dynamic_lstmp(input,
candidate_activation='tanh', candidate_activation='tanh',
proj_activation='tanh', proj_activation='tanh',
dtype='float32', dtype='float32',
name=None): name=None,
h_0=None,
c_0=None,
cell_clip=None,
proj_clip=None):
""" """
**Dynamic LSTMP Layer** **Dynamic LSTMP Layer**
...@@ -785,6 +789,17 @@ def dynamic_lstmp(input, ...@@ -785,6 +789,17 @@ def dynamic_lstmp(input,
dtype(str): Data type. Choices = ["float32", "float64"], default "float32". dtype(str): Data type. Choices = ["float32", "float64"], default "float32".
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
h_0(Variable): The initial hidden state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size and D is the projection size.
c_0(Variable): The initial cell state is an optional input, default is zero.
This is a tensor with shape (N x D), where N is the
batch size. `h_0` and `c_0` can be NULL but only at the same time.
cell_clip(float): If provided the cell state is clipped
by this value prior to the cell output activation.
proj_clip(float): If `num_proj > 0` and `proj_clip` is
provided, then the projected values are clipped elementwise to within
`[-proj_clip, proj_clip]`.
Returns: Returns:
tuple: A tuple of two output variable: the projection of hidden state, \ tuple: A tuple of two output variable: the projection of hidden state, \
...@@ -831,25 +846,41 @@ def dynamic_lstmp(input, ...@@ -831,25 +846,41 @@ def dynamic_lstmp(input,
batch_hidden = helper.create_variable_for_type_inference(dtype) batch_hidden = helper.create_variable_for_type_inference(dtype)
batch_gate = helper.create_variable_for_type_inference(dtype) batch_gate = helper.create_variable_for_type_inference(dtype)
batch_cell_pre_act = helper.create_variable_for_type_inference(dtype) batch_cell_pre_act = helper.create_variable_for_type_inference(dtype)
inputs = {
helper.append_op(
type='lstmp',
inputs={
'Input': input, 'Input': input,
'Weight': weight, 'Weight': weight,
'ProjWeight': proj_weight, 'ProjWeight': proj_weight,
'Bias': bias 'Bias': bias
}, }
batch_size = input.shape[0]
if h_0:
assert h_0.shape == (batch_size, proj_size), \
'The shape of h0 should be (batch_size, %d)' % proj_size
inputs['H0'] = h_0
if c_0:
assert c_0.shape == (batch_size, size), \
'The shape of c0 should be (batch_size, %d)' % size
inputs['C0'] = c_0
if cell_clip:
assert cell_clip >= 0, "cell_clip should not be negtive."
if proj_clip:
assert proj_clip >= 0, "proj_clip should not be negtive."
helper.append_op(
type='lstmp',
inputs=inputs,
outputs={ outputs={
'Projection': projection, 'Projection': projection,
'Cell': cell, 'Cell': cell,
'OrderedP0': ordered_proj0,
'BatchHidden': batch_hidden, 'BatchHidden': batch_hidden,
'BatchGate': batch_gate, 'BatchGate': batch_gate,
'BatchCellPreAct': batch_cell_pre_act 'BatchCellPreAct': batch_cell_pre_act
}, },
attrs={ attrs={
'use_peepholes': use_peepholes, 'use_peepholes': use_peepholes,
'cell_clip': cell_clip,
'proj_clip': proj_clip,
'is_reverse': is_reverse, 'is_reverse': is_reverse,
'gate_activation': gate_activation, 'gate_activation': gate_activation,
'cell_activation': cell_activation, 'cell_activation': cell_activation,
......
...@@ -36,12 +36,14 @@ def lstmp( ...@@ -36,12 +36,14 @@ def lstmp(
w_b=None, # 1 x 4D w_b=None, # 1 x 4D
w_c=None, # 1 x 3D w_c=None, # 1 x 3D
is_reverse=False, is_reverse=False,
proj_clip=0.0,
cell_clip=0.0,
act_gate=None, act_gate=None,
act_cell=None, act_cell=None,
act_cand=None, act_cand=None,
act_proj=None): act_proj=None):
def _step(x, w_r, w_rh, w_c, r_pre, c_pre, act_gate, act_cell, act_cand, def _step(x, w_r, w_rh, w_c, r_pre, c_pre, proj_clip, cell_clip, act_gate,
act_proj): act_cell, act_cand, act_proj):
g = np.dot(r_pre, w_r) # 1 x 4D g = np.dot(r_pre, w_r) # 1 x 4D
g = g + x g = g + x
g = np.reshape(g, (1, g.size)) g = np.reshape(g, (1, g.size))
...@@ -55,6 +57,17 @@ def lstmp( ...@@ -55,6 +57,17 @@ def lstmp(
g_f = act_gate(g_f + w_fc * c_pre) # 1 x D g_f = act_gate(g_f + w_fc * c_pre) # 1 x D
c = g_f * c_pre + g_i * act_cand(c) # 1 x D c = g_f * c_pre + g_i * act_cand(c) # 1 x D
def array_clip(a, clip):
size = np.prod(a.shape)
new_a = np.reshape(a, (size))
for i in range(size):
new_a[i] = max(new_a[i], -1.0 * clip)
new_a[i] = min(new_a[i], clip)
new_a = np.reshape(new_a, a.shape)
return new_a
if cell_clip > 0.0:
c = array_clip(c, cell_clip)
if w_c is None: if w_c is None:
g_o = act_gate(g_o) # 1 x D g_o = act_gate(g_o) # 1 x D
else: else:
...@@ -64,6 +77,8 @@ def lstmp( ...@@ -64,6 +77,8 @@ def lstmp(
# projection # projection
r = np.dot(h, w_rh) r = np.dot(h, w_rh)
r = act_proj(r) r = act_proj(r)
if proj_clip > 0.0:
r = array_clip(r, proj_clip)
return r, c return r, c
def _reverse(x, offset): def _reverse(x, offset):
...@@ -87,13 +102,13 @@ def lstmp( ...@@ -87,13 +102,13 @@ def lstmp(
# compute one sequence # compute one sequence
seq_len = lod[0][i] seq_len = lod[0][i]
x = input[offset[i]:offset[i + 1], :] x = input[offset[i]:offset[i + 1], :]
r_pre = np.dot(h0[i], w_rh) # 1 x P r_pre = h0[i]
r_pre = act_proj(r_pre)
c_pre = c0[i] # 1 x D c_pre = c0[i] # 1 x D
for j in range(seq_len): for j in range(seq_len):
# compute one step # compute one step
r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, act_gate, r_pre, c_pre = _step(x[j], w_r, w_rh, w_c, r_pre, c_pre, proj_clip,
act_cell, act_cand, act_proj) cell_clip, act_gate, act_cell, act_cand,
act_proj)
projection.append(r_pre.flatten()) projection.append(r_pre.flatten())
cell.append(c_pre.flatten()) cell.append(c_pre.flatten())
...@@ -123,13 +138,12 @@ class TestLstmpOp(LstmTest.TestLstmOp): ...@@ -123,13 +138,12 @@ class TestLstmpOp(LstmTest.TestLstmOp):
T = sum(self.lod[0]) T = sum(self.lod[0])
N = len(self.lod[0]) N = len(self.lod[0])
x = np.random.normal(size=(T, 4 * self.D)).astype('float64') x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
if self.has_initial_state: if self.has_initial_state:
h0 = np.random.normal(size=(N, self.D)).astype('float64') h0 = np.random.normal(size=(N, self.P)).astype('float64')
c0 = np.random.normal(size=(N, self.D)).astype('float64') c0 = np.random.normal(size=(N, self.D)).astype('float64')
else: else:
h0 = np.zeros((N, self.D)).astype('float64') h0 = np.zeros((N, self.P)).astype('float64')
c0 = np.zeros((N, self.D)).astype('float64') c0 = np.zeros((N, self.D)).astype('float64')
w = np.random.normal(size=(self.P, 4 * self.D)).astype('float64') w = np.random.normal(size=(self.P, 4 * self.D)).astype('float64')
if self.use_peepholes: if self.use_peepholes:
...@@ -140,9 +154,12 @@ class TestLstmpOp(LstmTest.TestLstmOp): ...@@ -140,9 +154,12 @@ class TestLstmpOp(LstmTest.TestLstmOp):
w_b = b[:, 0:4 * self.D] w_b = b[:, 0:4 * self.D]
w_c = b[:, 4 * self.D:] if self.use_peepholes else None w_c = b[:, 4 * self.D:] if self.use_peepholes else None
w_rh = np.random.normal(size=(self.D, self.P)).astype('float64') w_rh = np.random.normal(size=(self.D, self.P)).astype('float64')
proj_clip = 0.1
cell_clip = 0.1
r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse, r, c = lstmp(x, self.lod, h0, c0, w, w_rh, w_b, w_c, self.is_reverse,
ACTIVATION[self.act_gate], ACTIVATION[self.act_cell], proj_clip, cell_clip, ACTIVATION[self.act_gate],
ACTIVATION[self.act_cand], ACTIVATION[self.act_proj]) ACTIVATION[self.act_cell], ACTIVATION[self.act_cand],
ACTIVATION[self.act_proj])
self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh} self.inputs = {'Input': (x, self.lod), 'Weight': w, 'ProjWeight': w_rh}
...@@ -159,6 +176,8 @@ class TestLstmpOp(LstmTest.TestLstmOp): ...@@ -159,6 +176,8 @@ class TestLstmpOp(LstmTest.TestLstmOp):
self.attrs = { self.attrs = {
'use_peepholes': self.use_peepholes, 'use_peepholes': self.use_peepholes,
'is_reverse': self.is_reverse, 'is_reverse': self.is_reverse,
'proj_clip': proj_clip,
'cell_clip': cell_clip,
'gate_activation': self.act_gate, 'gate_activation': self.act_gate,
'cell_activation': self.act_cell, 'cell_activation': self.act_cell,
'candidate_activation': self.act_cand, 'candidate_activation': self.act_cand,
...@@ -171,14 +190,14 @@ class TestLstmpOp(LstmTest.TestLstmOp): ...@@ -171,14 +190,14 @@ class TestLstmpOp(LstmTest.TestLstmOp):
def test_check_grad(self): def test_check_grad(self):
# TODO(qingqing) remove folowing lines after the check_grad is refined. # TODO(qingqing) remove folowing lines after the check_grad is refined.
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64') (N, self.D)).astype('float64')
self.check_grad( self.check_grad(
['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'], ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'],
max_relative_error=1e-2) max_relative_error=1e-2,
numeric_grad_delta=0.0000005)
class TestLstmpOpHasInitial(TestLstmpOp): class TestLstmpOpHasInitial(TestLstmpOp):
...@@ -188,7 +207,6 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -188,7 +207,6 @@ class TestLstmpOpHasInitial(TestLstmpOp):
def test_check_grad(self): def test_check_grad(self):
# TODO(qingqing) remove folowing lines after the check_grad is refined. # TODO(qingqing) remove folowing lines after the check_grad is refined.
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -196,11 +214,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -196,11 +214,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'], ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'],
['Projection'], ['Projection'],
numeric_grad_delta=0.0000005,
max_relative_error=1e-2) max_relative_error=1e-2)
def test_check_grad_ingore_bias(self): def test_check_grad_ingore_bias(self):
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -208,11 +226,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -208,11 +226,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Input', 'ProjWeight', 'Weight'], ['Projection'], ['Input', 'ProjWeight', 'Weight'], ['Projection'],
max_relative_error=1e-2, max_relative_error=1e-2,
numeric_grad_delta=0.0000005,
no_grad_set=set('Bias')) no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self): def test_check_grad_ingore_weight(self):
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -220,11 +238,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -220,11 +238,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Input', 'ProjWeight', 'Bias'], ['Projection'], ['Input', 'ProjWeight', 'Bias'], ['Projection'],
max_relative_error=1e-2, max_relative_error=1e-2,
numeric_grad_delta=0.0000005,
no_grad_set=set('Weight')) no_grad_set=set('Weight'))
def test_check_grad_ingore_proj_weight(self): def test_check_grad_ingore_proj_weight(self):
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -232,11 +250,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -232,11 +250,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Input', 'Weight', 'Bias'], ['Projection'], ['Input', 'Weight', 'Bias'], ['Projection'],
max_relative_error=1e-2, max_relative_error=1e-2,
numeric_grad_delta=0.0000005,
no_grad_set=set('ProjWeight')) no_grad_set=set('ProjWeight'))
def test_check_grad_ingore_input(self): def test_check_grad_ingore_input(self):
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -244,11 +262,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -244,11 +262,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Weight', 'ProjWeight', 'Bias'], ['Projection'], ['Weight', 'ProjWeight', 'Bias'], ['Projection'],
max_relative_error=1e-2, max_relative_error=1e-2,
numeric_grad_delta=0.0000005,
no_grad_set=set('Input')) no_grad_set=set('Input'))
def test_check_grad_ingore_h0(self): def test_check_grad_ingore_h0(self):
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -256,11 +274,11 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -256,11 +274,11 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'], ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'],
max_relative_error=1e-2, max_relative_error=1e-2,
numeric_grad_delta=0.0000005,
no_grad_set=set('H0')) no_grad_set=set('H0'))
def test_check_grad_ingore_c0(self): def test_check_grad_ingore_c0(self):
N = len(self.lod[0]) N = len(self.lod[0])
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64') self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64') self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
self.outputs['BatchCellPreAct'] = np.zeros( self.outputs['BatchCellPreAct'] = np.zeros(
...@@ -268,6 +286,7 @@ class TestLstmpOpHasInitial(TestLstmpOp): ...@@ -268,6 +286,7 @@ class TestLstmpOpHasInitial(TestLstmpOp):
self.check_grad( self.check_grad(
['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'], ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'],
max_relative_error=1e-2, max_relative_error=1e-2,
numeric_grad_delta=0.0000005,
no_grad_set=set('C0')) no_grad_set=set('C0'))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册