提交 038c16ee 编写于 作者: T tensor-tang

save intermediate data to out buffer

上级 2d0ddf8c
...@@ -266,25 +266,24 @@ class FusionGRUKernel : public framework::OpKernel<T> { ...@@ -266,25 +266,24 @@ class FusionGRUKernel : public framework::OpKernel<T> {
batched_input_data, D3); batched_input_data, D3);
T* cur_batched_data = batched_input_data; T* cur_batched_data = batched_input_data;
T* cur_out_data = batched_out_data;
T* cur_prev_hidden_data = prev_hidden_data; T* cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
act_gate(D2, cur_batched_data, cur_batched_data); act_gate(D2, cur_batched_data, cur_batched_data);
// rt = rt*ht_1 inplace result // rt = rt*ht_1 inplace result
// TODO(TJ): try to save to cur out data blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D, cur_out_data);
// may benefit from avoiding cache misses in the next GEMM
blas.VMUL(D, cur_prev_hidden_data, cur_batched_data + D,
cur_batched_data + D);
cur_batched_data += D3; cur_batched_data += D3;
cur_prev_hidden_data += D; cur_prev_hidden_data += D;
cur_out_data += D;
} }
cur_batched_data = batched_input_data; cur_batched_data = batched_input_data;
cur_out_data = batched_out_data;
blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1), blas.GEMM(CblasNoTrans, CblasNoTrans, cur_bs, D, D, static_cast<T>(1),
cur_batched_data + D, D3, wh_state_data, D, static_cast<T>(1), cur_out_data, D, wh_state_data, D, static_cast<T>(1),
cur_batched_data + D2, D3); cur_batched_data + D2, D3);
T* cur_out_data = batched_out_data;
cur_prev_hidden_data = prev_hidden_data; cur_prev_hidden_data = prev_hidden_data;
for (int i = 0; i < cur_bs; ++i) { for (int i = 0; i < cur_bs; ++i) {
// ht~ = act_state(...) // ht~ = act_state(...)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册