Commit 87086b13 authored by Yihua Xu, committed by tensor-tang

Refine activation for GRU operator (#13275)

* Optimize GRU with AVX instructions (see the sketch below)

* Clean code

* Add the unit test and fix the alignment issue

* Remove the remnant part of the unit test

* Code cleanup

* Fix the parameter length issue for fusion_gru to pass CI

* Change the default type to float32
Parent d402234b
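The heart of this change is a tail-handling pattern that lets the AVX kernels cover any `frame_size >= 8` instead of only exact multiples of 8: each row is processed in full 8-float blocks, and the remainder is covered by one extra vector loaded from the row's last 8 elements, deliberately overlapping the final full block. Because the gate arrays are updated in place, that overlapping tail vector has to be loaded before the main loop runs. Below is a minimal standalone sketch of the pattern, assuming only AVX (`-mavx`) and `n >= 8`; `scale_row` and the elementwise scale are hypothetical stand-ins for the GRU gate math, not the kernel itself:

```cpp
// Sketch of the overlapped-tail AVX pattern used in this commit.
// Hypothetical example: an in-place elementwise scale stands in
// for the GRU gate computation.
#include <immintrin.h>
#include <cstdio>

void scale_row(float *data, int n, float factor) {
  const int block = 8;         // floats per __m256 register
  const int rest = n % block;  // elements beyond the last full block
  const int end = n - rest;    // bound for the full-block loop
  const __m256 f = _mm256_set1_ps(factor);

  // Load the tail vector BEFORE the main loop, as the kernel does:
  // it spans the last 8 elements and overlaps the last full block,
  // which the in-place loop below is about to overwrite.
  __m256 tail = _mm256_set1_ps(0.0f);
  if (rest > 0) tail = _mm256_loadu_ps(data + n - block);

  // Main loop over full blocks; loadu/storeu tolerate unaligned rows.
  for (int i = 0; i < end; i += block) {
    __m256 v = _mm256_loadu_ps(data + i);
    _mm256_storeu_ps(data + i, _mm256_mul_ps(v, f));
  }

  // Tail: recompute the last 8 elements from the preloaded vector.
  // The overlapped lanes are simply written again with the same
  // result, since the operation is elementwise.
  if (rest > 0) {
    _mm256_storeu_ps(data + n - block, _mm256_mul_ps(tail, f));
  }
}

int main() {
  float x[11] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
  scale_row(x, 11, 2.0f);  // 11 = one full block + a 3-element tail
  for (float v : x) std::printf("%.0f ", v);  // prints 0 2 4 ... 20
  std::printf("\n");
}
```

In the kernels below, the inputs (`update_gate`, `reset_gate`/`frame_state`, `prev_output_value`) get preloaded `*_last` copies, while the write-only outputs (`reset_output_value`, `output_value`) need no preload.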
@@ -85,26 +85,59 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                      T *prev_output_value, int frame_size,
                                      ActivationType active_gate) {
 #ifdef __AVX__
-  __m256 r_value_update_gate;
-  __m256 r_value_reset_gate;
+  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
+  __m256 r_value_reset_gate, r_value_reset_gate_last = _mm256_set1_ps(0.0f);
   __m256 r_value_reset_output;
-  __m256 r_prev_out = _mm256_set1_ps(0.0f);
-  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
-  __m256 *reset_gate = reinterpret_cast<__m256 *>(gate_value + frame_size);
+  __m256 r_prev_out = _mm256_set1_ps(0.0f),
+         r_prev_out_last = _mm256_set1_ps(0.0f);
+  T *update_gate = gate_value;
+  T *reset_gate = gate_value + frame_size;
+  int block = 8;
+  const int n = frame_size;
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+  if (rest > 0) {
+    i = n - block;
+    r_value_update_gate_last =
+        _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_reset_gate_last = _mm256_loadu_ps((const float *)(reset_gate + i));
+    if (prev_output_value) {
+      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
+    }
+  }
-  for (int i = 0; i < frame_size / 8; i++) {
-    r_value_update_gate = update_gate[i];
-    r_value_reset_gate = reset_gate[i];
+  for (i = 0; i < end; i += block) {
+    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_reset_gate = _mm256_loadu_ps((const float *)(reset_gate + i));
     if (prev_output_value) {
-      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
+      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
     op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
                     &r_value_reset_output, active_gate);
-    update_gate[i] = r_value_update_gate;
-    reset_gate[i] = r_value_reset_gate;
-    (reinterpret_cast<__m256 *>(reset_output_value))[i] = r_value_reset_output;
+    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
+                     r_value_update_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
+                     r_value_reset_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
+                     r_value_reset_output);
   }
+  if (rest > 0) {
+    i = n - block;
+    op_reset_output(&r_value_update_gate_last, &r_value_reset_gate_last,
+                    &r_prev_out_last, &r_value_reset_output, active_gate);
+    _mm256_storeu_ps(reinterpret_cast<float *>(update_gate + i),
+                     r_value_update_gate_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_gate + i),
+                     r_value_reset_gate_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(reset_output_value + i),
+                     r_value_reset_output);
+  }
 #endif
 }
@@ -115,26 +148,55 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *output_value, int frame_size,
                                      ActivationType active_node) {
 #ifdef __AVX__
-  __m256 r_value_update_gate;
-  __m256 r_value_frame_state;
-  __m256 r_prev_out = _mm256_set1_ps(0.0f);
+  __m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
+  __m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
+  __m256 r_prev_out = _mm256_set1_ps(0.0f),
+         r_prev_out_last = _mm256_set1_ps(0.0f);
   __m256 r_output;
-  __m256 *update_gate = reinterpret_cast<__m256 *>(gate_value);
-  __m256 *frame_state = reinterpret_cast<__m256 *>(gate_value + frame_size * 2);
+  T *update_gate = gate_value;
+  T *frame_state = gate_value + frame_size * 2;
+  int block = 8;
+  const int n = frame_size;
+  const int rest = n % block;
+  const int end = n - rest;
+  int i = 0;
+  if (rest > 0) {
+    i = n - block;
+    r_value_update_gate_last =
+        _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_frame_state_last =
+        _mm256_loadu_ps((const float *)(frame_state + i));
+    if (prev_output_value) {
+      r_prev_out_last = _mm256_loadu_ps((const float *)(prev_output_value + i));
+    }
+  }
-  for (int i = 0; i < frame_size / 8; i++) {
-    r_value_update_gate = update_gate[i];
-    r_value_frame_state = frame_state[i];
+  for (i = 0; i < end; i += block) {
+    r_value_update_gate = _mm256_loadu_ps((const float *)(update_gate + i));
+    r_value_frame_state = _mm256_loadu_ps((const float *)(frame_state + i));
     if (prev_output_value) {
-      r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
+      r_prev_out = _mm256_loadu_ps((const float *)(prev_output_value + i));
     }
     op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
                     &r_output, active_node);
-    frame_state[i] = r_value_frame_state;
-    (reinterpret_cast<__m256 *>(output_value))[i] = r_output;
+    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
+                     r_value_frame_state);
+    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
   }
+  if (rest > 0) {
+    i = n - block;
+    op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
+                    &r_prev_out_last, &r_output, active_node);
+    _mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
+                     r_value_frame_state_last);
+    _mm256_storeu_ps(reinterpret_cast<float *>(output_value + i), r_output);
+  }
 #endif
 }
@@ -143,7 +205,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
                                  GRUMetaValue<T> value, int frame_size,
                                  int batch_size, ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+    if (OpResetOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+        (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
           op_reset_output, value.gate_value, value.reset_output_value,
           value.prev_out_value, frame_size, active_gate);
@@ -166,7 +229,8 @@ inline void forward_final_output(OpFinalOutput op_final_output,
                                 GRUMetaValue<T> value, int frame_size,
                                 int batch_size, ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
-    if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
+    if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
+        (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
                                       value.prev_out_value, value.output_value,
                                       frame_size, active_node);
......
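The two dispatch changes above are what expose the new kernels: the old guard `!(frame_size & (8 - 1))` took the AVX path only when `frame_size` was an exact multiple of 8, while the new guard `frame_size > 8 - 1` merely requires at least one full 8-float block, with the overlapped tail covering the rest. A tiny hypothetical comparison of the two predicates (only the two expressions come from the diff; the function names are made up):

```cpp
// Compare the old vs. new AVX dispatch predicates from the diff above.
#include <cstdio>
#include <initializer_list>

bool old_takes_avx(int frame_size) {
  return !(frame_size & (8 - 1));  // true only for multiples of 8
}
bool new_takes_avx(int frame_size) {
  return frame_size > 8 - 1;  // true for any frame_size >= 8
}

int main() {
  for (int fs : {4, 8, 16, 19, 24}) {
    std::printf("frame_size=%2d  old=%d  new=%d\n", fs, old_takes_avx(fs),
                new_takes_avx(fs));
  }
}
```

Under the old check, a 19-wide frame (the `TestGRUOp2` case added below) fell back to the scalar path; it now runs through AVX plus the tail block.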
@@ -30,7 +30,8 @@ def gru(
         bias,  # 1 x 3D
         is_reverse,
         act_state,
-        act_gate):
+        act_gate,
+        dtype='float32'):
     def _seq_to_batch(lod, is_reverse):
         idx_in_seq_list = []
         seq_lens = lod[0]
@@ -71,10 +72,10 @@ def gru(
     T = sum(lod[0])
     N = len(lod[0])
     D = weight.shape[0]
-    batch_gate = np.zeros((T, 3 * D), dtype='float64')
-    batch_reset_hidden_prev = np.zeros((T, D), dtype='float64')
-    batch_hidden = np.zeros((T, D), dtype='float64')
-    hidden = np.zeros((T, D), dtype='float64')
+    batch_gate = np.zeros((T, 3 * D), dtype=dtype)
+    batch_reset_hidden_prev = np.zeros((T, D), dtype=dtype)
+    batch_hidden = np.zeros((T, D), dtype=dtype)
+    hidden = np.zeros((T, D), dtype=dtype)
     idx_in_seq_list, sorted_seqs = _seq_to_batch(lod, is_reverse)
     h_p = h0[sorted_seqs]
@@ -108,23 +109,24 @@ class TestGRUOp(OpTest):
         self.with_bias = True
         self.act_state = 'tanh'
         self.act_gate = 'sigmoid'
+        self.dtype = 'float64'
         self.set_confs()
 
         T = sum(self.lod[0])
         N = len(self.lod[0])
-        input = np.random.rand(T, 3 * self.D).astype('float64')
-        weight = np.random.rand(self.D, 3 * self.D).astype('float64')
+        input = np.random.rand(T, 3 * self.D).astype(self.dtype)
+        weight = np.random.rand(self.D, 3 * self.D).astype(self.dtype)
         bias = np.random.rand(
-            1, 3 * self.D).astype('float64') if self.with_bias else np.zeros(
-                (1, 3 * self.D), dtype='float64')
+            1, 3 * self.D).astype(self.dtype) if self.with_bias else np.zeros(
+                (1, 3 * self.D), dtype=self.dtype)
         h0 = np.random.rand(
-            N, self.D).astype('float64') if self.with_h0 else np.zeros(
-                (N, self.D), dtype='float64')
+            N, self.D).astype(self.dtype) if self.with_h0 else np.zeros(
+                (N, self.D), dtype=self.dtype)
 
         batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
             input, self.lod, h0, weight, bias, self.is_reverse,
-            ACTIVATION[self.act_state], ACTIVATION[self.act_gate])
+            ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
         self.inputs = {'Input': (input, self.lod), 'Weight': weight}
 
         if self.with_bias:
@@ -153,6 +155,12 @@ class TestGRUOp(OpTest):
             self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
 
 
+class TestGRUOp2(TestGRUOp):
+    def set_confs(self):
+        self.D = 19
+        self.dtype = 'float32'
+
+
 class TestGRUOpNoInitial(TestGRUOp):
     def set_confs(self):
         self.with_h0 = False
......