Commit 7a5b8ffa authored by Yibing Liu

Pass grad checking for projection weight

Parent 552c9012
@@ -217,7 +217,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
Long-Short Term Memory with Recurrent Projection (LSTMP) Operator.
-LATMP is the standard LSTM appended with a recurrent projection layer to reduce the
+LSTMP is the standard LSTM appended with a recurrent projection layer to reduce the
number of parameters, especially when the output size is relatively large.
The formula is as follows:
@@ -232,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
h_t = o_t \odot act_h(c_t)
-r_t = act_h'(W_{rh}h_t)
+r_t = act_{h'}(W_{rh}h_t)
$$
where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix
......
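For reference, the projection step that distinguishes LSTMP from a standard LSTM can be sketched in a few lines of NumPy (shapes and the choice of `tanh` for the projection activation are illustrative assumptions, not taken from the kernel):

```python
import numpy as np

def lstmp_projection(h_t, W_rh, act=np.tanh):
    """Recurrent projection r_t = act_h'(W_rh h_t).

    h_t:  (batch, D) hidden state produced by the LSTM cell
    W_rh: (D, P)     projection weight, with P < D to cut parameters
    Returns r_t with shape (batch, P); r_t (not h_t) is what feeds
    back into the gate computations at the next time step.
    """
    return act(h_t @ W_rh)
```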
@@ -365,10 +365,18 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
proj_g_dev);
}
+/* hidden state backward */
Tensor out_g = batch_hidden_g.Slice(bstart, bend);
math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
true, static_cast<T>(1.0), &out_g,
static_cast<T>(0.0));
+/* projection weight backward */
+if (proj_weight_g) {
+Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
+false, static_cast<T>(1.0),
+proj_weight_g, static_cast<T>(1.0));
+}
Tensor gate = batch_gate->Slice(bstart, bend);
Tensor cell = batch_cell.Slice(bstart, bend);
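The two `matmul` calls above implement the backward pass of that projection. A minimal NumPy sketch of the same math (variable names mirror the kernel; the activation gradient is assumed to have already been folded into `proj_g`, as `ActGradCompute` does above):

```python
import numpy as np

def projection_backward(hidden_t, proj_weight, proj_g, proj_weight_g):
    """Backward through r_t = W_rh h_t for one time-step slice.

    hidden_t:      (batch, D) slice of the batched hidden states
    proj_weight:   (D, P)     W_rh
    proj_g:        (batch, P) gradient w.r.t. the projection output
    proj_weight_g: (D, P)     accumulator for dW_rh across time steps
    """
    # First matmul: out_g = proj_g * W_rh^T (beta = 0 overwrites out_g).
    out_g = proj_g @ proj_weight.T
    # Second matmul: dW_rh += h_t^T * proj_g (beta = 1 accumulates,
    # which is why this must run once per time step inside the loop).
    proj_weight_g += hidden_t.T @ proj_g
    return out_g
```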
@@ -407,19 +415,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
static_cast<T>(1.0), &pre_proj_g,
static_cast<T>(1.0));
if (weight_g) {
-/* backward weight */
+/* weight backward */
auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
false, static_cast<T>(1.0), weight_g,
static_cast<T>(1.0));
}
-if (proj_weight_g) {
-/* backward proj weigh */
-Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
-false, static_cast<T>(1.0),
-proj_weight_g, static_cast<T>(1.0));
-}
} else {
if (h0 && weight_g) {
ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
@@ -444,7 +445,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
proj0_g_dev);
}
-// Tensor proj0_g = proj_g.Slice(bstart, bend);
if (h0_g) {
math::matmul<DeviceContext, T>(
device_ctx, proj0_g, false, *proj_weight, true,
......
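The commit title refers to Paddle's numeric gradient checking, which compares these analytic gradients against finite differences. A rough sketch of the idea (simplified for illustration; not Paddle's actual `OpTest` machinery):

```python
import numpy as np

def numeric_grad(f, x, eps=1e-4):
    """Central-difference gradient of a scalar function f at x."""
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = x[idx]
        x[idx] = orig + eps
        f_plus = f(x)
        x[idx] = orig - eps
        f_minus = f(x)
        x[idx] = orig  # restore the perturbed element
        grad[idx] = (f_plus - f_minus) / (2.0 * eps)
        it.iternext()
    return grad

def max_relative_error(analytic, numeric):
    """The quantity the tests below bound via max_relative_error."""
    denom = np.maximum(np.abs(analytic), np.abs(numeric))
    denom[denom == 0] = 1.0  # guard against division by zero
    return np.max(np.abs(analytic - numeric) / denom)
```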
@@ -207,8 +207,8 @@ class TestLstmOp(OpTest):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Input', 'Weight', 'Bias'], ['Projection'],
-max_relative_error=5e-3)
+['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'],
+max_relative_error=1e-2)
class TestLstmOpHasInitial(TestLstmOp):
@@ -235,8 +235,9 @@ class TestLstmOpHasInitial(TestLstmOp):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'],
-max_relative_error=5e-3)
+['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'],
+['Projection'],
+max_relative_error=1e-2)
def test_check_grad_ingore_bias(self):
N = len(self.lod[0]) - 1
@@ -246,8 +247,8 @@ class TestLstmOpHasInitial(TestLstmOp):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Input', 'Weight'], ['Projection'],
-max_relative_error=5e-3,
+['Input', 'ProjWeight', 'Weight'], ['Projection'],
+max_relative_error=1e-2,
no_grad_set=set('Bias'))
def test_check_grad_ingore_weight(self):
@@ -258,10 +259,22 @@ class TestLstmOpHasInitial(TestLstmOp):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Input', 'Bias'], ['Projection'],
-max_relative_error=5e-3,
+['Input', 'ProjWeight', 'Bias'], ['Projection'],
+max_relative_error=1e-2,
no_grad_set=set('Weight'))
+def test_check_grad_ingore_proj_weight(self):
+N = len(self.lod[0]) - 1
+self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
+self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
+self.outputs['BatchCellPreAct'] = np.zeros(
+(N, self.D)).astype('float64')
+self.check_grad(
+['Input', 'Weight', 'Bias'], ['Projection'],
+max_relative_error=1e-2,
+no_grad_set=set('ProjWeight'))
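Like the other `ingore` variants in this class, the new `test_check_grad_ingore_proj_weight` case re-runs the gradient check with one input excluded via `no_grad_set`, so a failure can be attributed to a specific gradient path. Note also that every check in this file loosens `max_relative_error` from 5e-3 to 1e-2; presumably the projection-weight gradient, accumulated across all time steps, carries somewhat more numeric error than the other gradients.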
def test_check_grad_ingore_input(self):
N = len(self.lod[0]) - 1
self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
@@ -270,8 +283,8 @@ class TestLstmOpHasInitial(TestLstmOp):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Weight', 'Bias'], ['Projection'],
-max_relative_error=5e-3,
+['Weight', 'ProjWeight', 'Bias'], ['Projection'],
+max_relative_error=1e-2,
no_grad_set=set('Input'))
def test_check_grad_ingore_h0(self):
@@ -282,8 +295,8 @@ class TestLstmOpHasInitial(TestLstmOp):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Input', 'Weight', 'Bias', 'C0'], ['Projection'],
-max_relative_error=5e-3,
+['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'],
+max_relative_error=1e-2,
no_grad_set=set('H0'))
def test_check_grad_ingore_c0(self):
@@ -294,8 +307,8 @@ class TestLstmOpHasInitial(TestLstmOp):
self.outputs['BatchCellPreAct'] = np.zeros(
(N, self.D)).astype('float64')
self.check_grad(
-['Input', 'Weight', 'Bias', 'H0'], ['Projection'],
-max_relative_error=5e-3,
+['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'],
+max_relative_error=1e-2,
no_grad_set=set('C0'))
......