From 87086b1386083a2f7479585b966f3ca82d7a9012 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Thu, 20 Sep 2018 09:51:20 +0800 Subject: [PATCH] Refine activation for GRU operator (#13275) * Optimize GRU with AVX instruction * Clean code * Add the Unitest and fix the align issue * Remove the remanent part of the unitest part * Code clean * Fix the parameters length issue for fusion_gru to pass CI * Change the default type as float32 --- .../operators/math/detail/gru_cpu_kernel.h | 114 ++++++++++++++---- .../fluid/tests/unittests/test_gru_op.py | 32 +++-- 2 files changed, 109 insertions(+), 37 deletions(-) diff --git a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h index b6f4ab93777..47c771f7c5c 100644 --- a/paddle/fluid/operators/math/detail/gru_cpu_kernel.h +++ b/paddle/fluid/operators/math/detail/gru_cpu_kernel.h @@ -85,26 +85,59 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, T *prev_output_value, int frame_size, ActivationType active_gate) { #ifdef __AVX__ - __m256 r_value_update_gate; - __m256 r_value_reset_gate; + __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); + __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f); __m256 r_value_reset_output; - __m256 r_prev_out = _mm256_set1_ps(0.0f); - __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value); - __m256 *reset_gate = reinterpret_cast<__m256 *>(gate_value + frame_size); + __m256 r_prev_out = _mm256_set1_ps(0.0f), + r_prev_out_last = _mm256_set1_ps(0.0f); + T *update_gate = gate_value; + T *reset_gate = gate_value + frame_size; + int block = 8; + const int n = frame_size; + const int rest = n % block; + const int end = n - rest; + int i = 0; + + if (rest > 0) { + i = n - block; + r_value_update_gate_last = + _mm256_loadu_ps((const float *)(update_gate + i)); + r_value_reset_gate_last = _mm256_loadu_ps((const float *)(reset_gate + i)); + if (prev_output_value) { + r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i)); + } + } - for (int i = 0; i < frame_size / 8; i++) { - r_value_update_gate = update_gate[i]; - r_value_reset_gate = reset_gate[i]; + for (i = 0; i < end; i += block) { + r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i)); + r_value_reset_gate = _mm256_loadu_ps((const float *)(reset_gate + i)); if (prev_output_value) { - r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i]; + r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i)); } op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out, &r_value_reset_output, active_gate); - update_gate[i] = r_value_update_gate; - reset_gate[i] = r_value_reset_gate; - (reinterpret_cast<__m256 *>(reset_output_value))[i] = r_value_reset_output; + _mm256_storeu_ps(reinterpret_cast(update_gate + i), + r_value_update_gate); + _mm256_storeu_ps(reinterpret_cast(reset_gate + i), + r_value_reset_gate); + _mm256_storeu_ps(reinterpret_cast(reset_output_value + i), + r_value_reset_output); + } + + if (rest > 0) { + i = n - block; + + op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last, + &r_prev_out_last, &r_value_reset_output, active_gate); + + _mm256_storeu_ps(reinterpret_cast(update_gate + i), + r_value_update_gate_last); + _mm256_storeu_ps(reinterpret_cast(reset_gate + i), + r_value_reset_gate_last); + _mm256_storeu_ps(reinterpret_cast(reset_output_value + i), + r_value_reset_output); } #endif } @@ -115,26 +148,55 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, T *output_value, int frame_size, ActivationType active_node) { #ifdef __AVX__ - __m256 r_value_update_gate; - __m256 r_value_frame_state; - __m256 r_prev_out = _mm256_set1_ps(0.0f); + __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f); + __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f); + __m256 r_prev_out = _mm256_set1_ps(0.0f), + r_prev_out_last = _mm256_set1_ps(0.0f); __m256 r_output; - __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value); - __m256 *frame_state = reinterpret_cast<__m256 *>(gate_value + frame_size * 2); + T *update_gate = gate_value; + T *frame_state = gate_value + frame_size * 2; + int block = 8; + const int n = frame_size; + const int rest = n % block; + const int end = n - rest; + int i = 0; + + if (rest > 0) { + i = n - block; + r_value_update_gate_last = + _mm256_loadu_ps((const float *)(update_gate + i)); + r_value_frame_state_last = + _mm256_loadu_ps((const float *)(frame_state + i)); + if (prev_output_value) { + r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i)); + } + } - for (int i = 0; i < frame_size / 8; i++) { - r_value_update_gate = update_gate[i]; - r_value_frame_state = frame_state[i]; + for (i = 0; i < end; i += block) { + r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i)); + r_value_frame_state = _mm256_loadu_ps((const float *)(frame_state + i)); if (prev_output_value) { - r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i]; + r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i)); } op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out, &r_output, active_node); - frame_state[i] = r_value_frame_state; - (reinterpret_cast<__m256 *>(output_value))[i] = r_output; + _mm256_storeu_ps(reinterpret_cast(frame_state + i), + r_value_frame_state); + _mm256_storeu_ps(reinterpret_cast(output_value + i), r_output); + } + + if (rest > 0) { + i = n - block; + op_final_output(&r_value_update_gate_last, &r_value_frame_state_last, + &r_prev_out_last, &r_output, active_node); + + _mm256_storeu_ps(reinterpret_cast(frame_state + i), + r_value_frame_state_last); + _mm256_storeu_ps(reinterpret_cast(output_value + i), r_output); } + #endif } @@ -143,7 +205,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output, GRUMetaValue value, int frame_size, int batch_size, ActivationType active_gate) { for (int b = 0; b < batch_size; b++) { - if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { + if (OpResetOutput::avx && (frame_size > static_cast(8 - 1)) && + (sizeof(T) == 4)) { hl_avx_gru_forward_reset_output( op_reset_output, value.gate_value, value.reset_output_value, value.prev_out_value, frame_size, active_gate); @@ -166,7 +229,8 @@ inline void forward_final_output(OpFinalOutput op_final_output, GRUMetaValue value, int frame_size, int batch_size, ActivationType active_node) { for (int b = 0; b < batch_size; b++) { - if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) { + if (OpFinalOutput::avx && (frame_size > static_cast(8 - 1)) && + (sizeof(T) == 4)) { hl_avx_gru_forward_final_output(op_final_output, value.gate_value, value.prev_out_value, value.output_value, frame_size, active_node); diff --git a/python/paddle/fluid/tests/unittests/test_gru_op.py b/python/paddle/fluid/tests/unittests/test_gru_op.py index 9f6f03f9cfe..f61a447fd77 100644 --- a/python/paddle/fluid/tests/unittests/test_gru_op.py +++ b/python/paddle/fluid/tests/unittests/test_gru_op.py @@ -30,7 +30,8 @@ def gru( bias, # 1 x 3D is_reverse, act_state, - act_gate): + act_gate, + dtype='float32'): def _seq_to_batch(lod, is_reverse): idx_in_seq_list = [] seq_lens = lod[0] @@ -71,10 +72,10 @@ def gru( T = sum(lod[0]) N = len(lod[0]) D = weight.shape[0] - batch_gate = np.zeros((T, 3 * D), dtype='float64') - batch_reset_hidden_prev = np.zeros((T, D), dtype='float64') - batch_hidden = np.zeros((T, D), dtype='float64') - hidden = np.zeros((T, D), dtype='float64') + batch_gate = np.zeros((T, 3 * D), dtype=dtype) + batch_reset_hidden_prev = np.zeros((T, D), dtype=dtype) + batch_hidden = np.zeros((T, D), dtype=dtype) + hidden = np.zeros((T, D), dtype=dtype) idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse) h_p = h0[sorted_seqs] @@ -108,23 +109,24 @@ class TestGRUOp(OpTest): self.with_bias = True self.act_state = 'tanh' self.act_gate = 'sigmoid' + self.dtype = 'float64' self.set_confs() T = sum(self.lod[0]) N = len(self.lod[0]) - input = np.random.rand(T, 3 * self.D).astype('float64') - weight = np.random.rand(self.D, 3 * self.D).astype('float64') + input = np.random.rand(T, 3 * self.D).astype(self.dtype) + weight = np.random.rand(self.D, 3 * self.D).astype(self.dtype) bias = np.random.rand( - 1, 3 * self.D).astype('float64') if self.with_bias else np.zeros( - (1, 3 * self.D), dtype='float64') + 1, 3 * self.D).astype(self.dtype) if self.with_bias else np.zeros( + (1, 3 * self.D), dtype=self.dtype) h0 = np.random.rand( - N, self.D).astype('float64') if self.with_h0 else np.zeros( - (N, self.D), dtype='float64') + N, self.D).astype(self.dtype) if self.with_h0 else np.zeros( + (N, self.D), dtype=self.dtype) batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru( input, self.lod, h0, weight, bias, self.is_reverse, - ACTIVATION[self.act_state], ACTIVATION[self.act_gate]) + ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype) self.inputs = {'Input': (input, self.lod), 'Weight': weight} if self.with_bias: @@ -153,6 +155,12 @@ class TestGRUOp(OpTest): self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden']) +class TestGRUOp2(TestGRUOp): + def set_confs(self): + self.D = 19 + self.dtype = 'float32' + + class TestGRUOpNoInitial(TestGRUOp): def set_confs(self): self.with_h0 = False -- GitLab