提交 e10af895 编写于 作者: Q Qiao Longfei

update gru grad op

test=develop
上级 78ec7c0f
...@@ -113,8 +113,7 @@ class GRUUnitKernel : public framework::OpKernel<T> { ...@@ -113,8 +113,7 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output // calculate final output
bool origin_mode = context.Attr<bool>("origin_mode"); if (context.Attr<bool>("origin_mode")) {
if (origin_mode) {
h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p
} else { } else {
h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p
...@@ -218,7 +217,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> { ...@@ -218,7 +217,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
T* hidden_prev_grad_data = T* hidden_prev_grad_data =
hidden_prev_grad->mutable_data<T>(context.GetPlace()); hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad); auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
if (context.Attr<bool>("origin_mode")) {
d_h_p.device(place) = d_r_h_p * (u.constant(T(1)) - u) + d_h * r;
} else {
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
}
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size); hidden_prev_grad_data, frame_size);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册