Commit 7a5b8ffa authored by Yibing Liu

Pass grad checking for projection weight

Parent 552c9012
@@ -217,7 +217,7 @@ class LSTMPOpMaker : public framework::OpProtoAndCheckerMaker {
   AddComment(R"DOC(
Long-Short Term Memory with Recurrent Projection (LSTMP) Operator.
-LATMP is stand LSTM appended by a recurrent projection layer to reduce the
+LSTMP is the standard LSTM appended with a recurrent projection layer to reduce the
number of parameters, especially when the output size is relatively large.
The formula is as follows:
@@ -232,7 +232,7 @@ o_t = \sigma(W_{ox}x_{t} + W_{oh}r_{t-1} + W_{oc}c_t + b_o) \\
h_t = o_t \odot act_h(c_t)
-r_t = act_h'(W_{rh}h_t)
+r_t = act_{h'}(W_{rh}h_t)
$$
where the W terms denote weight matrices (e.g. $W_{xi}$ is the matrix
......
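For readers skimming the diff, a minimal NumPy sketch of one LSTMP step described by the DOC comment above: a standard LSTM cell whose hidden state h_t is passed through the recurrent projection r_t = act_{h'}(W_{rh} h_t), so the recurrence feeds back the smaller r_t instead of h_t. The function and variable names are illustrative only, and the peephole terms (W_{ic}, W_{fc}, W_{oc}) from the full formula are omitted for brevity.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstmp_step(x_t, r_prev, c_prev, W_x, W_h, b, W_rh, act=np.tanh):
    """One LSTMP step (peepholes omitted).
    Shapes: x_t (X,), r_prev (P,), c_prev (D,),
            W_x (4D, X), W_h (4D, P), b (4D,), W_rh (P, D)."""
    gates = W_x @ x_t + W_h @ r_prev + b      # input/forget/cell/output pre-activations
    i, f, g, o = np.split(gates, 4)
    c_t = sigmoid(f) * c_prev + sigmoid(i) * act(g)  # cell state update
    h_t = sigmoid(o) * act(c_t)                      # hidden state
    r_t = act(W_rh @ h_t)                            # recurrent projection, dim D -> P
    return r_t, c_t, h_t
```

Because the recurrence carries r_t (size P) rather than h_t (size D), the recurrent weight shrinks from 4D x D to 4D x P, which is the parameter reduction the comment refers to.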
@@ -365,10 +365,18 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
        ActGradCompute(cell_act, place, cur_proj_dev, cur_proj_dev, proj_g_dev,
                       proj_g_dev);
      }
+     /* hidden state backward */
      Tensor out_g = batch_hidden_g.Slice(bstart, bend);
      math::matmul<DeviceContext, T>(device_ctx, proj_g, false, *proj_weight,
                                     true, static_cast<T>(1.0), &out_g,
                                     static_cast<T>(0.0));
+     /* projection weight backward */
+     if (proj_weight_g) {
+       Tensor hidden_t = batch_hidden->Slice(bstart, bend);
+       math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
+                                      false, static_cast<T>(1.0),
+                                      proj_weight_g, static_cast<T>(1.0));
+     }
      Tensor gate = batch_gate->Slice(bstart, bend);
      Tensor cell = batch_cell.Slice(bstart, bend);
@@ -407,19 +415,12 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
                                       static_cast<T>(1.0), &pre_proj_g,
                                       static_cast<T>(1.0));
        if (weight_g) {
-         /* backward weight */
+         /* weight backward */
          auto pre_proj = batch_proj.Slice(pre_h_start, pre_h_end);
          math::matmul<DeviceContext, T>(device_ctx, pre_proj, true, gate_g,
                                         false, static_cast<T>(1.0), weight_g,
                                         static_cast<T>(1.0));
        }
-       if (proj_weight_g) {
-         /* backward proj weigh */
-         Tensor hidden_t = batch_hidden->Slice(bstart, bend);
-         math::matmul<DeviceContext, T>(device_ctx, hidden_t, true, proj_g,
-                                        false, static_cast<T>(1.0),
-                                        proj_weight_g, static_cast<T>(1.0));
-       }
      } else {
        if (h0 && weight_g) {
          ReorderInitState<DeviceContext, T>(device_ctx, *h0, order,
@@ -444,7 +445,6 @@ class LSTMPGradKernel : public framework::OpKernel<T> {
        ActGradCompute(cell_act, place, proj0_dev, proj0_dev, proj0_g_dev,
                       proj0_g_dev);
      }
-     // Tensor proj0_g = proj_g.Slice(bstart, bend);
      if (h0_g) {
        math::matmul<DeviceContext, T>(
            device_ctx, proj0_g, false, *proj_weight, true,
......
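The block added above (and removed from its old place further down) is what makes the projection-weight gradient check pass: for every batch slice it accumulates the gradient of ProjWeight as hidden_t^T * proj_g, with beta = 1.0 so contributions keep accumulating across steps; proj_g is the gradient already propagated back to the projection pre-activation by the preceding ActGradCompute. A rough NumPy equivalent, with illustrative names and assuming the op's row-major [batch, D] x [D, P] projection layout:

```python
import numpy as np

def accumulate_proj_weight_grad(hidden_t, proj_g, proj_weight_g):
    """Sketch of the added matmul: given the forward projection
    pre-activation p_t = hidden_t @ W_rh (hidden_t: [batch, D],
    proj_g: [batch, P] is d(loss)/d(p_t)), accumulate
    d(loss)/d(W_rh) += hidden_t^T @ proj_g across time steps."""
    proj_weight_g += hidden_t.T @ proj_g   # [D, batch] @ [batch, P] -> [D, P]
    return proj_weight_g
```

Computing this inside the same per-step slice handling as the hidden-state backward keeps hidden_t and proj_g aligned on the same batch slice, rather than recomputing the slices in the later weight_g branch.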
@@ -207,8 +207,8 @@ class TestLstmOp(OpTest):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Input', 'Weight', 'Bias'], ['Projection'],
-           max_relative_error=5e-3)
+           ['Input', 'Weight', 'ProjWeight', 'Bias'], ['Projection'],
+           max_relative_error=1e-2)

class TestLstmOpHasInitial(TestLstmOp):
@@ -235,8 +235,9 @@ class TestLstmOpHasInitial(TestLstmOp):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Input', 'Weight', 'Bias', 'H0', 'C0'], ['Projection'],
-           max_relative_error=5e-3)
+           ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0', 'C0'],
+           ['Projection'],
+           max_relative_error=1e-2)

    def test_check_grad_ingore_bias(self):
        N = len(self.lod[0]) - 1
@@ -246,8 +247,8 @@ class TestLstmOpHasInitial(TestLstmOp):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Input', 'Weight'], ['Projection'],
-           max_relative_error=5e-3,
+           ['Input', 'ProjWeight', 'Weight'], ['Projection'],
+           max_relative_error=1e-2,
            no_grad_set=set('Bias'))

    def test_check_grad_ingore_weight(self):
@@ -258,10 +259,22 @@ class TestLstmOpHasInitial(TestLstmOp):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Input', 'Bias'], ['Projection'],
-           max_relative_error=5e-3,
+           ['Input', 'ProjWeight', 'Bias'], ['Projection'],
+           max_relative_error=1e-2,
            no_grad_set=set('Weight'))

+   def test_check_grad_ingore_proj_weight(self):
+       N = len(self.lod[0]) - 1
+       self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
+       self.outputs['BatchGate'] = np.zeros((N, 4 * self.D)).astype('float64')
+       self.outputs['BatchHidden'] = np.zeros((N, self.D)).astype('float64')
+       self.outputs['BatchCellPreAct'] = np.zeros(
+           (N, self.D)).astype('float64')
+       self.check_grad(
+           ['Input', 'Weight', 'Bias'], ['Projection'],
+           max_relative_error=1e-2,
+           no_grad_set=set('ProjWeight'))
+
    def test_check_grad_ingore_input(self):
        N = len(self.lod[0]) - 1
        self.outputs['OrderedP0'] = np.zeros((N, self.P)).astype('float64')
@@ -270,8 +283,8 @@ class TestLstmOpHasInitial(TestLstmOp):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Weight', 'Bias'], ['Projection'],
-           max_relative_error=5e-3,
+           ['Weight', 'ProjWeight', 'Bias'], ['Projection'],
+           max_relative_error=1e-2,
            no_grad_set=set('Input'))

    def test_check_grad_ingore_h0(self):
@@ -282,8 +295,8 @@ class TestLstmOpHasInitial(TestLstmOp):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Input', 'Weight', 'Bias', 'C0'], ['Projection'],
-           max_relative_error=5e-3,
+           ['Input', 'Weight', 'ProjWeight', 'Bias', 'C0'], ['Projection'],
+           max_relative_error=1e-2,
            no_grad_set=set('H0'))

    def test_check_grad_ingore_c0(self):
@@ -294,8 +307,8 @@ class TestLstmOpHasInitial(TestLstmOp):
        self.outputs['BatchCellPreAct'] = np.zeros(
            (N, self.D)).astype('float64')
        self.check_grad(
-           ['Input', 'Weight', 'Bias', 'H0'], ['Projection'],
-           max_relative_error=5e-3,
+           ['Input', 'Weight', 'ProjWeight', 'Bias', 'H0'], ['Projection'],
+           max_relative_error=1e-2,
            no_grad_set=set('C0'))
......
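The tests above call OpTest.check_grad, which perturbs each checked input numerically and compares the result against the operator's analytic gradient. As a rough illustration of the numeric side of that comparison for ProjWeight, here is a minimal central-difference sketch; numeric_grad and loss_fn are stand-ins for this explanation, not the framework's actual helpers.

```python
import numpy as np

def numeric_grad(loss_fn, param, eps=1e-4):
    """Central-difference gradient of a scalar loss w.r.t. one parameter
    array -- the same idea check_grad uses to validate the analytic
    gradient of inputs such as ProjWeight."""
    grad = np.zeros_like(param)
    it = np.nditer(param, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        orig = param[idx]
        param[idx] = orig + eps
        loss_plus = loss_fn()           # loss_fn closes over param
        param[idx] = orig - eps
        loss_minus = loss_fn()
        param[idx] = orig               # restore the original value
        grad[idx] = (loss_plus - loss_minus) / (2 * eps)
        it.iternext()
    return grad
```

check_grad then requires the relative error between the analytic and numeric gradients to stay below max_relative_error, which this commit relaxes from 5e-3 to 1e-2 now that ProjWeight is among the checked inputs.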