Commit d94c936b authored by dangqingqing

Enhance unit testing:

1. Users can disable peephole connections.
2. Allow skipping the gradient computation for selected inputs.

Parent d851dafe
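For background on point 1: with peepholes enabled, the LSTM gates receive extra diagonal connections from the cell state, and disabling the `use_peepholes` attribute simply drops those terms. A standard peephole formulation (general background, not copied from this patch; W_{ic}, W_{fc}, W_{oc} correspond to the checkIg/checkFg/checkOg pointers in the kernels below, and \sigma / act_{cand} / act_{cell} to the gate_activation / candidate_activation / cell_activation attributes):

    i_t = \sigma(W_{ix} x_t + W_{ih} h_{t-1} + W_{ic} \odot c_{t-1} + b_i)
    f_t = \sigma(W_{fx} x_t + W_{fh} h_{t-1} + W_{fc} \odot c_{t-1} + b_f)
    c_t = f_t \odot c_{t-1} + i_t \odot act_{cand}(W_{cx} x_t + W_{ch} h_{t-1} + b_c)
    o_t = \sigma(W_{ox} x_t + W_{oh} h_{t-1} + W_{oc} \odot c_t + b_o)
    h_t = o_t \odot act_{cell}(c_t)

With use_peepholes = false the W_{ic}, W_{fc}, W_{oc} terms are omitted. For point 2, the new tests pass no_grad_set to OpTest.check_grad, so the gradient with respect to the listed inputs (Bias, Weight, Input, H0 or C0) is neither computed nor checked.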
@@ -164,16 +164,19 @@ class LSTMOpMaker : public framework::OpProtoAndCheckerMaker {
                          "(string, default: sigmoid)"
                          "The activation for input gate, forget gate and output "
                          "gate, `sigmoid` by default.")
-        .SetDefault("sigmoid");
+        .SetDefault("sigmoid")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddAttr<std::string>("cell_activation",
                          "(string, default: tanh)"
                          "The activation for cell output, `tanh` by defalut.")
-        .SetDefault("tanh");
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddAttr<std::string>("candidate_activation",
                          "(string, default: tanh)"
                          "The activation for candidate hidden state, "
                          "`tanh` by default.")
-        .SetDefault("tanh");
+        .SetDefault("tanh")
+        .InEnum({"sigmoid", "tanh", "relu", "identity"});
     AddComment(R"DOC(
 Long-Short Term Memory (LSTM) Operator.
...
@@ -69,7 +69,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     }
     math::LstmMetaValue<T> lstm_value;
-    if (bias) {
+    if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       // the code style in LstmMetaValue will be updated later.
@@ -85,6 +85,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     Tensor ordered_c0;
     if (cell_t0) {
       math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
+      ordered_c0.mutable_data<T>(cell_t0->dims(), ctx.GetPlace());
       const size_t* order = batch_gate->lod()[2].data();
       row_shuffle(device_ctx, *cell_t0, order, ordered_c0, true);
       lstm_value.prevStateValue = ordered_c0.data<T>();
@@ -124,6 +125,7 @@ class LSTMKernel : public framework::OpKernel<T> {
     } else if (hidden_t0) {
       math::CopyMatrixRowsFunctor<Place, T> row_shuffle;
       Tensor ordered_h0;
+      ordered_h0.mutable_data<T>(hidden_t0->dims(), ctx.GetPlace());
       const size_t* order = batch_gate->lod()[2].data();
       row_shuffle(device_ctx, *hidden_t0, order, ordered_h0, true);
       math::matmul<Place, T>(device_ctx, ordered_h0, false, *weight, false,
@@ -199,7 +201,7 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(frame_size, out_dims[1]);
     math::LstmMetaValue<T> lstm_value;
-    if (bias) {
+    if (bias && ctx.Attr<bool>("use_peepholes")) {
       T* bias_data = const_cast<T*>(bias->data<T>());
       lstm_value.checkIg = bias_data + 4 * frame_size;
       lstm_value.checkFg = lstm_value.checkIg + frame_size;
@@ -211,9 +213,13 @@ class LSTMGradKernel : public framework::OpKernel<T> {
     }
     math::LstmMetaGrad<T> lstm_grad;
     if (bias && bias_g) {
-      T* bias_g_data = const_cast<T*>(bias_g->mutable_data<T>(ctx.GetPlace()));
+      bias_g->mutable_data<T>(ctx.GetPlace());
       zero(device_ctx, bias_g, static_cast<T>(0.0));
+    }
+    if (bias && bias_g && ctx.Attr<bool>("use_peepholes")) {
+      T* bias_g_data = bias_g->data<T>();
       lstm_grad.checkIgGrad = bias_g_data + 4 * frame_size;
       lstm_grad.checkFgGrad = lstm_grad.checkIgGrad + frame_size;
       lstm_grad.checkOgGrad = lstm_grad.checkFgGrad + frame_size;
...
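Background on the peephole guards above: when use_peepholes is true, the Bias input packs both the 4 * frame_size gate biases and the 3 * frame_size peephole weights, which is why checkIg starts at bias_data + 4 * frame_size. A minimal NumPy sketch of that layout (the variable names and frame_size value here are illustrative, not part of the patch):

    import numpy as np
    frame_size = 16                                # hypothetical hidden width
    b = np.zeros((1, 7 * frame_size))              # Bias layout when use_peepholes is true
    w_b  = b[:, 0:4 * frame_size]                  # input/forget/cell/output gate biases
    w_ic = b[:, 4 * frame_size:5 * frame_size]     # checkIg: peephole into the input gate
    w_fc = b[:, 5 * frame_size:6 * frame_size]     # checkFg: peephole into the forget gate
    w_oc = b[:, 6 * frame_size:7 * frame_size]     # checkOg: peephole into the output gate
    # Without peepholes, the test below samples Bias with only 4 * frame_size columns.

This matches the Python test further down, which builds b with 7 * self.D columns only when self.use_peepholes is set, and the kernels guard every checkIg/checkFg/checkOg access so a missing peephole block is treated as zero.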
@@ -52,9 +52,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       rValueIg = valueIg[i];
       rValueFg = valueFg[i];
       rValueOg = valueOg[i];
-      rCheckI = value.checkIg[i];
-      rCheckF = value.checkFg[i];
-      rCheckO = value.checkOg[i];
+      rCheckI = value.checkIg ? value.checkIg[i] : 0;
+      rCheckF = value.checkFg ? value.checkFg[i] : 0;
+      rCheckO = value.checkOg ? value.checkOg[i] : 0;
       if (value.prevStateValue) {
         rPrevState = value.prevStateValue[i];
@@ -114,9 +114,9 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       rValueIg = valueIg[i];
       rValueFg = valueFg[i];
       rValueOg = valueOg[i];
-      rCheckI = value.checkIg[i];
-      rCheckF = value.checkFg[i];
-      rCheckO = value.checkOg[i];
+      rCheckI = value.checkIg ? value.checkIg[i] : 0;
+      rCheckF = value.checkFg ? value.checkFg[i] : 0;
+      rCheckO = value.checkOg ? value.checkOg[i] : 0;
       rState = value.stateValue[i];
       rStateAtv = value.stateActiveValue[i];
       rOutputGrad = grad.outputGrad[i];
@@ -155,9 +155,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
   __m256 rValueIg;
   __m256 rValueFg;
   __m256 rValueOg;
-  __m256 rCheckI;
-  __m256 rCheckF;
-  __m256 rCheckO;
+  __m256 rCheckI = _mm256_set1_ps(0.0f);
+  __m256 rCheckF = _mm256_set1_ps(0.0f);
+  __m256 rCheckO = _mm256_set1_ps(0.0f);
   __m256 rState;
   __m256 rPrevState = _mm256_set1_ps(0.0f);
   __m256 rStateAtv;
@@ -173,9 +173,11 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value, int frameSize,
     rValueIg = valueIg[i];
     rValueFg = valueFg[i];
     rValueOg = valueOg[i];
-    rCheckI = ((__m256 *)value.checkIg)[i];
-    rCheckF = ((__m256 *)value.checkFg)[i];
-    rCheckO = ((__m256 *)value.checkOg)[i];
+    if (value.checkIg) {
+      rCheckI = ((__m256 *)value.checkIg)[i];
+      rCheckF = ((__m256 *)value.checkFg)[i];
+      rCheckO = ((__m256 *)value.checkOg)[i];
+    }
     if (value.prevStateValue) {
       rPrevState = ((__m256 *)value.prevStateValue)[i];
@@ -216,9 +218,9 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
   __m256 rState;
   __m256 rStateAtv;
   __m256 rOutputGrad;
-  __m256 rCheckI;
-  __m256 rCheckF;
-  __m256 rCheckO;
+  __m256 rCheckI = _mm256_set1_ps(0.0f);
+  __m256 rCheckF = _mm256_set1_ps(0.0f);
+  __m256 rCheckO = _mm256_set1_ps(0.0f);
   __m256 rCheckIGrad;
   __m256 rCheckFGrad;
   __m256 rCheckOGrad;
@@ -237,9 +239,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
     rValueIg = valueIg[i];
     rValueFg = valueFg[i];
     rValueOg = valueOg[i];
-    rCheckI = ((__m256 *)value.checkIg)[i];
-    rCheckF = ((__m256 *)value.checkFg)[i];
-    rCheckO = ((__m256 *)value.checkOg)[i];
+    if (value.checkIg) {
+      rCheckI = ((__m256 *)value.checkIg)[i];
+      rCheckF = ((__m256 *)value.checkFg)[i];
+      rCheckO = ((__m256 *)value.checkOg)[i];
+    }
     rState = ((__m256 *)value.stateValue)[i];
     rStateAtv = ((__m256 *)value.stateActiveValue)[i];
     rOutputGrad = ((__m256 *)grad.outputGrad)[i];
...
@@ -55,9 +55,10 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frameSize,
   T rValueIg;
   T rValueFg;
   T rValueOg;
-  T rCheckI = value.checkIg[frameIdx];
-  T rCheckF = value.checkFg[frameIdx];
-  T rCheckO = value.checkOg[frameIdx];
+
+  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
+  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
+  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
   rValueIn = value.gateValue[frameIdx];
   rValueIg = value.gateValue[frameIdx + frameSize];
@@ -121,9 +122,10 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
   T rStateGrad;
   T rStateAtv;
   T rOutputGrad;
-  T rCheckI = value.checkIg[frameIdx];
-  T rCheckF = value.checkFg[frameIdx];
-  T rCheckO = value.checkOg[frameIdx];
+  T rCheckI = value.checkIg ? value.checkIg[frameIdx] : 0;
+  T rCheckF = value.checkFg ? value.checkFg[frameIdx] : 0;
+  T rCheckO = value.checkOg ? value.checkOg[frameIdx] : 0;
   T rCheckIGrad;
   T rCheckFGrad;
   T rCheckOGrad;
...
@@ -31,7 +31,7 @@ class CopyMatrixRowsFunctor {
   // The indexed rows are based on the input index.
   void operator()(const platform::DeviceContext& context,
                   const framework::Tensor& src, const size_t* index,
-                  framework::Tensor* dst, bool is_src_index);
+                  framework::Tensor& dst, bool is_src_index);
 };
 template <typename Place, typename T>
@@ -57,7 +57,7 @@ class LoDTensor2BatchFunctor {
                   bool is_reverse = false) const {
     if (!is_cal_batch_lod) {
       auto lods = batch.lod();
-      PADDLE_ENFORCE_LE(lods.size(), 2UL);
+      PADDLE_ENFORCE_GT(lods.size(), 2UL);
       PADDLE_ENFORCE_EQ(lods[1].size(),
                         static_cast<size_t>(lod_tensor.dims()[0]));
       CopyMatrixRowsFunctor<Place, T> to_batch;
@@ -68,8 +68,6 @@ class LoDTensor2BatchFunctor {
     auto lods = lod_tensor.lod();
     auto lod = lods[0];
     PADDLE_ENFORCE_EQ(lods.size(), 1UL, "Only support one level sequence now.");
-    PADDLE_ENFORCE_EQ(lod_tensor.dims()[0],
-                      static_cast<int64_t>(lod.size() - 1));
     std::vector<SeqInfo> seq_info;
     for (size_t seq_id = 0; seq_id < lod.size() - 1; ++seq_id) {
@@ -112,7 +110,7 @@ class LoDTensor2BatchFunctor {
     int num_batch = seq_info[0].length;
     batch_lods[0].resize(static_cast<size_t>(num_batch + 1));
     // batch_lods[1] is the raw index in the input LoDTensor
-    batch_lods[1].resize(static_cast<size_t>(seq_info.size()));
+    batch_lods[1].resize(static_cast<size_t>(lod_tensor.dims()[0]));
     // batch_lods[2] is the sort order for the input LoDTensor.
     batch_lods[2].resize(seq_info.size());
@@ -152,8 +150,7 @@ class Batch2LoDTensorFunctor {
                   const framework::LoDTensor& batch,
                   framework::LoDTensor& lod_tensor) const {
     auto in_lod = batch.lod();
-    PADDLE_ENFORCE_LT(in_lod.size(), 2UL,
-                      "The LoD size of input `batch` should be 2.");
+    PADDLE_ENFORCE_GT(in_lod.size(), 2UL);
     PADDLE_ENFORCE_EQ(in_lod[1].size(),
                       static_cast<size_t>(lod_tensor.dims()[0]));
     CopyMatrixRowsFunctor<Place, T> to_seq;
...
@@ -117,9 +117,9 @@ class TestLstmOp(OpTest):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
-        self.has_initial_state = True
-        self.has_bias = True
+        self.has_initial_state = False
         self.is_reverse = False
+        self.use_peepholes = True
     def setUp(self):
         self.set_argument()
@@ -129,21 +129,27 @@ class TestLstmOp(OpTest):
         N = len(self.lod[0]) - 1
         x = np.random.normal(size=(T, 4 * self.D)).astype('float64')
-        h0 = np.zeros((N, self.D)).astype('float64')
-        c0 = np.zeros((N, self.D)).astype('float64')
+        if self.has_initial_state:
+            h0 = np.random.normal(size=(N, self.D)).astype('float64')
+            c0 = np.random.normal(size=(N, self.D)).astype('float64')
+        else:
+            h0 = np.zeros((N, self.D)).astype('float64')
+            c0 = np.zeros((N, self.D)).astype('float64')
         w = np.random.normal(size=(self.D, 4 * self.D)).astype('float64')
-        b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+        if self.use_peepholes:
+            b = np.random.normal(size=(1, 7 * self.D)).astype('float64')
+        else:
+            b = np.random.normal(size=(1, 4 * self.D)).astype('float64')
-        w_b = b[:, 0:4 * self.D] if self.has_bias else None
-        w_c = b[:, 4 * self.D:] if self.has_bias else None
+        w_b = b[:, 0:4 * self.D]
+        w_c = b[:, 4 * self.D:] if self.use_peepholes else None
         h, c = lstm(x, self.lod, h0, c0, w, w_b, w_c, self.is_reverse,
                     ACTVATION[self.act_gate], ACTVATION[self.act_cell],
                     ACTVATION[self.act_cand])
         self.inputs = {'Input': (x, self.lod), 'Weight': w}
-        if self.has_bias:
-            self.inputs['Bias'] = b
+        self.inputs['Bias'] = b
         if self.has_initial_state:
             self.inputs['H0'] = h0
@@ -154,18 +160,17 @@ class TestLstmOp(OpTest):
             'Cell': (c, self.lod),
         }
         self.attrs = {
-            'use_peepholes': True,
+            'use_peepholes': self.use_peepholes,
             'is_reverse': self.is_reverse,
             'gate_activation': self.act_gate,
             'cell_activation': self.act_cell,
             'candidate_activation': self.act_cand
         }
-    def not_test_check_output(self):
+    def test_check_output(self):
         self.check_output(atol=1e-8)
-    #TODO(qingqing) add more unit testing case
-    def not_test_check_grad(self):
+    def test_check_grad(self):
         # TODO(qingqing) remove folowing lines after the check_grad is refined.
         N = len(self.lod[0]) - 1
         self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
@@ -174,8 +179,38 @@ class TestLstmOp(OpTest):
         self.check_grad(
             ['Input', 'Weight', 'Bias'], ['Hidden'], max_relative_error=5e-4)
+    def test_check_grad_ingore_bias(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('Bias'))
+    def test_check_grad_ingore_weight(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Bias'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('Weight'))
+    def test_check_grad_ingore_input(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Weight', 'Bias'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('Input'))
-class TestLstmOpHasNoInitial(TestLstmOp):
+class TestLstmOpHasInitial(TestLstmOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -184,12 +219,52 @@ class TestLstmOpHasNoInitial(TestLstmOp):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
-        self.has_initial_state = False
+        self.has_initial_state = True
         self.is_reverse = True
-        self.has_bias = True
+        self.use_peepholes = True
+    def test_check_grad(self):
+        # TODO(qingqing) remove folowing lines after the check_grad is refined.
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Hidden'],
+            max_relative_error=5e-4)
+    # In order to speed up, skip following testing
+    def test_check_grad_ingore_bias(self):
+        return
+    def test_check_grad_ingore_weight(self):
+        return
+    def test_check_grad_ingore_input(self):
+        return
+    def test_check_grad_ingore_h0(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias', 'C0'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('H0'))
+    def test_check_grad_ingore_c0(self):
+        N = len(self.lod[0]) - 1
+        self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+        self.outputs['BatchCellPreAct'] = np.zeros(
+            (N, self.D)).astype('float64')
+        self.check_grad(
+            ['Input', 'Weight', 'Bias', 'H0'], ['Hidden'],
+            max_relative_error=5e-4,
+            no_grad_set=set('C0'))
-class TestLstmOpHasNoBias(TestLstmOp):
+class TestLstmOpRerverse(TestLstmOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -198,15 +273,22 @@ class TestLstmOpHasNoBias(TestLstmOp):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
-        self.has_initial_state = True
-        self.is_reverse = False
-        self.has_bias = False
-    def test_check_output(self):
-        self.check_output(atol=1e-8)
-class TestLstmOpRerverse(TestLstmOp):
+        self.has_initial_state = False
+        self.is_reverse = True
+        self.use_peepholes = True
+    # In order to speed up, skip following testing
+    def test_check_grad_ingore_bias(self):
+        return
+    def test_check_grad_ingore_weight(self):
+        return
+    def test_check_grad_ingore_input(self):
+        return
+class TestLstmOpNotUsePeepholes(TestLstmOp):
     def set_argument(self):
         self.lod = [[0, 2, 5, 7]]
         self.D = 16
@@ -215,9 +297,19 @@ class TestLstmOpRerverse(TestLstmOp):
         self.act_cell = 'tanh'
         self.act_cand = 'tanh'
-        self.has_initial_state = True
+        self.has_initial_state = False
         self.is_reverse = True
-        self.has_bias = True
+        self.use_peepholes = False
+    # In order to speed up, skip following testing
+    def test_check_grad_ingore_bias(self):
+        return
+    def test_check_grad_ingore_weight(self):
+        return
+    def test_check_grad_ingore_input(self):
+        return
 if __name__ == '__main__':
...
@@ -102,7 +102,7 @@ class Momentum(Optimizer):
     .. math::
-        v_{t} &= k * v_{t-1} - \\gamma_t / (g_{t} + \\lambda w_{t-1}) \\\\
+        v_{t} &= k * v_{t-1} - \\gamma_t (g_{t} + \\lambda w_{t-1}) \\\\
         w_{t} &= w_{t-1} + v_{t} \\\\
     where, :math:`k` is momentum, :math:`\\lambda` is decay rate,
...
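The corrected Momentum docstring drops the spurious division: the velocity update reads v_t = k * v_{t-1} - gamma_t * (g_t + lambda * w_{t-1}). As a quick sanity check with made-up numbers (k = 0.9, gamma_t = 0.1, lambda = 0.01, g_t = 0.5, w_{t-1} = 1.0, v_{t-1} = 0): v_t = 0.9 * 0 - 0.1 * (0.5 + 0.01 * 1.0) = -0.051, and w_t = 1.0 + (-0.051) = 0.949.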