未验证 提交 778b981e 编写于 作者: G Guo Sheng 提交者: GitHub

Merge pull request #5804 from guoshengCS/fix-GRUUnitOp-dev

Fix calculations in gru_unit_op to consistent with gru_op 
...@@ -114,18 +114,19 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -114,18 +114,19 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(sigmoid) .SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu}); .InEnum({identity, sigmoid, tanh, relu});
AddComment(R"DOC( AddComment(R"DOC(
GRUUnit Operator. GRUUnit Operator implements partial calculations of the GRU unit as following:
This operator implements partial calculations of the GRU unit as follows:
$$ $$
update \ gate: u_t = actGate(xu_t + W_u * hidden_{prev} + bias_u) \\ update \ gate: u_t = actGate(xu_t + W_u * h_{t-1} + b_u) \\
reset \ gate: r_t = actGate(xr_t + W_r * hidden_{prev} + bias_r) \\ reset \ gate: r_t = actGate(xr_t + W_r * h_{t-1} + b_r) \\
output \ candidate: {h}_t = actNode({xc}_t + W_c * dot(r_t, hidden_{prev}) + bias_c) \\ output \ candidate: {h}_t = actNode(xc_t + W_c * dot(r_t, h_{t-1}) + b_c) \\
output: h_t = dot((1-u_t), {h}_t) + dot(u_t, hidden_{prev}) output: h_t = dot((1 - u_t), h_{t-1}) + dot(u_t, {h}_t)
$$ $$
The rest of GRU unit can be completed by using FCOp's output as the input of GRUUnitOp. which is same as one time step of GRU Operator.
@note To implement the complete GRU unit, fully-connected operator must be
used before to feed xu, xr and xc as the Input of GRUUnit operator.
)DOC"); )DOC");
} }
...@@ -150,12 +151,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel { ...@@ -150,12 +151,6 @@ class GRUUnitGradOp : public framework::OperatorWithKernel {
"ResetHiddenPrev"); "ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput("Hidden"), PADDLE_ENFORCE(ctx->HasInput("Hidden"),
"Input(%s) of GRUUnitGradOp should not be null.", "Hidden"); "Input(%s) of GRUUnitGradOp should not be null.", "Hidden");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Gate")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Gate");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("ResetHiddenPrev")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"ResetHiddenPrev");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")), PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Hidden")),
"Input(%s@GRAD) of GRUUnitGradOp should not be null.", "Input(%s@GRAD) of GRUUnitGradOp should not be null.",
"Hidden"); "Hidden");
......
...@@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -110,7 +110,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output // calculate final output
h.device(place) = u * (h_p - c) + c; h.device(place) = u * (c - h_p) + h_p;
} }
}; };
...@@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -146,35 +146,27 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto* weight_grad = auto* weight_grad =
context.Output<Tensor>(framework::GradVarName("Weight")); context.Output<Tensor>(framework::GradVarName("Weight"));
auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias")); auto* bias_grad = context.Output<Tensor>(framework::GradVarName("Bias"));
input_grad->mutable_data<T>(context.GetPlace());
hidden_prev_grad->mutable_data<T>(context.GetPlace());
weight_grad->mutable_data<T>(context.GetPlace());
Tensor gate_grad; Tensor gate_grad;
gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
Tensor reset_hidden_prev_grad; Tensor reset_hidden_prev_grad;
reset_hidden_prev_grad.mutable_data<T>(reset_hidden_prev->dims(),
context.GetPlace());
int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];
const T* hidden_prev_data = hidden_prev->data<T>(); const T* hidden_prev_data = hidden_prev->data<T>();
T* hidden_prev_grad_data = hidden_prev_grad->data<T>();
const T* weight_data = weight->data<T>(); const T* weight_data = weight->data<T>();
T* weight_grad_data = weight_grad->data<T>(); T* gate_grad_data =
T* gate_grad_data = gate_grad.data<T>(); gate_grad.mutable_data<T>(input->dims(), context.GetPlace());
const T* reset_hidden_prev_data = reset_hidden_prev->data<T>(); const T* reset_hidden_prev_data = reset_hidden_prev->data<T>();
T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.data<T>(); T* reset_hidden_prev_grad_data = reset_hidden_prev_grad.mutable_data<T>(
reset_hidden_prev->dims(), context.GetPlace());
auto h_p = EigenMatrix<T>::From(*hidden_prev); auto h_p = EigenMatrix<T>::From(*hidden_prev);
auto g = EigenMatrix<T>::From(*gate); auto g = EigenMatrix<T>::From(*gate);
auto d_h = EigenMatrix<T>::From(*hidden_grad); auto d_h = EigenMatrix<T>::From(*hidden_grad);
auto d_x = EigenMatrix<T>::From(*input_grad);
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
auto d_g = EigenMatrix<T>::From(gate_grad); auto d_g = EigenMatrix<T>::From(gate_grad);
auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad); auto d_r_h_p = EigenMatrix<T>::From(reset_hidden_prev_grad);
auto place = context.GetEigenDevice<Place>(); auto place = context.GetEigenDevice<Place>();
int batch_size = input->dims()[0];
int frame_size = hidden_prev->dims()[1];
Eigen::array<int, 2> extents({{batch_size, frame_size}}); Eigen::array<int, 2> extents({{batch_size, frame_size}});
Eigen::array<int, 2> u_offsets({{0, 0}}); Eigen::array<int, 2> u_offsets({{0, 0}});
auto u = g.slice(u_offsets, extents); // update gate auto u = g.slice(u_offsets, extents); // update gate
...@@ -185,38 +177,52 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -185,38 +177,52 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
// backward for unactivated update gate // backward for unactivated update gate
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u, ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c)); d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate // backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c, ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (u.constant(T(1)) - u)); d_g.slice(c_offsets, extents), d_h * u);
// backward for reset_hidden_prev // backward for reset_hidden_prev
math::gemm<Place, T>(context.device_context(), false, true, batch_size, math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size, 1, frame_size, frame_size, 1,
gate_grad_data + frame_size * 2, frame_size * 3, gate_grad_data + frame_size * 2, frame_size * 3,
weight_data + frame_size * frame_size * 2, frame_size, weight_data + frame_size * frame_size * 2, frame_size,
0, reset_hidden_prev_grad_data, frame_size); 0, reset_hidden_prev_grad_data, frame_size);
// backward for state_weight
math::gemm<Place, T>(
context.device_context(), true, false, frame_size, frame_size,
batch_size, 1, reset_hidden_prev_data, frame_size,
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
// backward for unactivated reset gate // backward for unactivated reset gate
ActGradCompute(context.Attr<int>("gate_activation"), place, r, r, ActGradCompute(context.Attr<int>("gate_activation"), place, r, r,
d_g.slice(r_offsets, extents), d_r_h_p * h_p); d_g.slice(r_offsets, extents), d_r_h_p * h_p);
// backward for update_gate_weight and reset_gate_weight // backward for weight
math::gemm<Place, T>(context.device_context(), true, false, frame_size, if (weight_grad) {
frame_size * 2, batch_size, 1, hidden_prev_data, T* weight_grad_data = weight_grad->mutable_data<T>(context.GetPlace());
frame_size, gate_grad_data, frame_size * 3, 0, // backward for state_weight
weight_grad_data, frame_size * 2); math::gemm<Place, T>(
context.device_context(), true, false, frame_size, frame_size,
batch_size, 1, reset_hidden_prev_data, frame_size,
gate_grad_data + frame_size * 2, frame_size * 3, 0,
weight_grad_data + frame_size * frame_size * 2, frame_size);
// backward for update_gate_weight and reset_gate_weight
math::gemm<Place, T>(context.device_context(), true, false, frame_size,
frame_size * 2, batch_size, 1, hidden_prev_data,
frame_size, gate_grad_data, frame_size * 3, 0,
weight_grad_data, frame_size * 2);
}
// backward for hidden_prev // backward for hidden_prev
d_h_p.device(place) = d_r_h_p * r + d_h * u; if (hidden_prev_grad) {
math::gemm<Place, T>(context.device_context(), false, true, batch_size, T* hidden_prev_grad_data =
frame_size, frame_size * 2, 1, gate_grad_data, hidden_prev_grad->mutable_data<T>(context.GetPlace());
frame_size * 3, weight_data, frame_size * 2, 1, auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
hidden_prev_grad_data, frame_size); d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
math::gemm<Place, T>(context.device_context(), false, true, batch_size,
frame_size, frame_size * 2, 1, gate_grad_data,
frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
}
// backward for input // backward for input
d_x.device(place) = d_g; if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
auto d_x = EigenMatrix<T>::From(*input_grad);
d_x.device(place) = d_g;
}
// backward for bias // backward for bias
if (bias_grad) { if (bias_grad) {
bias_grad->mutable_data<T>(context.GetPlace()); bias_grad->mutable_data<T>(context.GetPlace());
......
...@@ -28,8 +28,8 @@ def relu(x): ...@@ -28,8 +28,8 @@ def relu(x):
class TestGRUUnitOp(OpTest): class TestGRUUnitOp(OpTest):
batch_size = 3 batch_size = 5
frame_size = 5 frame_size = 10
activate = { activate = {
GRUActivationType.identity: identity, GRUActivationType.identity: identity,
GRUActivationType.sigmoid: sigmoid, GRUActivationType.sigmoid: sigmoid,
...@@ -77,7 +77,7 @@ class TestGRUUnitOp(OpTest): ...@@ -77,7 +77,7 @@ class TestGRUUnitOp(OpTest):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) + c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:]) g[:, frame_size * 2:])
g = np.hstack((u_r, c)) g = np.hstack((u_r, c))
h = u * h_p + (1 - u) * c h = u * c + (1 - u) * h_p
self.outputs = { self.outputs = {
'Gate': g.astype('float64'), 'Gate': g.astype('float64'),
'ResetHiddenPrev': r_h_p.astype('float64'), 'ResetHiddenPrev': r_h_p.astype('float64'),
...@@ -92,10 +92,7 @@ class TestGRUUnitOp(OpTest): ...@@ -92,10 +92,7 @@ class TestGRUUnitOp(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad( self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
['Input', 'HiddenPrev', 'Weight'],
['Hidden', 'ResetHiddenPrev', 'Gate'],
max_relative_error=0.007)
class TestGRUUnitOpWithBias(TestGRUUnitOp): class TestGRUUnitOpWithBias(TestGRUUnitOp):
...@@ -104,18 +101,20 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp): ...@@ -104,18 +101,20 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
frame_size = self.frame_size frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs() super(TestGRUUnitOpWithBias, self).set_inputs()
self.inputs['Bias'] = np.random.uniform( self.inputs['Bias'] = np.random.uniform(
-0.1, 0.1, (1, frame_size * 3)).astype('float32') -0.1, 0.1, (1, frame_size * 3)).astype('float64')
self.attrs = { self.attrs = {
'activation': GRUActivationType.identity, 'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid 'gate_activation': GRUActivationType.sigmoid
} }
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'])
def test_check_grad_ingore_input(self):
self.check_grad( self.check_grad(
['Input', 'HiddenPrev', 'Weight', 'Bias'], ['Hidden'], ['HiddenPrev', 'Weight', 'Bias'], ['Hidden'],
max_relative_error=0.007) no_grad_set=set('Input'))
if __name__ == '__main__': if __name__ == '__main__':
exit(0) # FIXME(yuyang18): This unittest is not pass. Fix it later
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册