diff --git a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h
index 19f6b213aa3bc06f7f5750fa42745fd8755c51b9..ccbd05c82ad6a880d21269092088be9656b35c99 100644
--- a/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h
+++ b/paddle/fluid/operators/math/detail/lstm_cpu_kernel.h
@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }

-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);

     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }

-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);

     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }

-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);

     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }

-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);

     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
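Note: the change in lstm_cpu_kernel.h is purely mechanical. Every in/out argument previously passed to the functor by non-const reference is now passed by address, matching the Google C++ style rule that mutable parameters be pointers so the mutation is visible at the call site. A minimal sketch of the pattern, using a hypothetical two-parameter functor (names below are illustrative, not from this patch):

#include <iostream>

// Hypothetical stand-in for the lstm functor: out-params are now pointers.
struct Functor {
  void operator()(float *gate, float *state) const {
    *gate += 1.0f;    // mutation happens explicitly through the pointer
    *state += *gate;
  }
};

int main() {
  Functor op;
  float r_gate = 0.5f, r_state = 1.0f;
  op(&r_gate, &r_state);  // was: op(r_gate, r_state) with T& parameters
  std::cout << r_gate << " " << r_state << "\n";  // prints "1.5 2.5"
  return 0;
}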
diff --git a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h
index d29c780dcfb1f1a3cbab25256238769d3a5ccd93..2aecb69237fdf344ebc0bfe72d9c7c147f06358d 100644
--- a/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h
+++ b/paddle/fluid/operators/math/detail/lstm_gpu_kernel.h
@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
     r_prev_state = value.prev_state_value[frame_idx];
   }

-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-     r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+     &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+     active_node, active_gate, active_state);

   value.gate_value[frame_idx] = r_value_in;
   value.gate_value[frame_idx + frame_size] = r_value_ig;
@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     r_prev_state = value.prev_state_value[frame_idx];
   }

-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-     r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-     r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-     r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
+     &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
+     &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
+     &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
+     active_gate, active_state);

   grad.gate_grad[frame_idx] = r_grad_in;
   grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
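Both sets of call sites above (the CPU sequence loops and the CUDA kernels) feed the same forward::lstm functor, which computes one step of a standard peephole LSTM cell. A self-contained scalar sketch of that arithmetic, assuming tanh for the node/state activations and the logistic sigmoid for the gates (the real code dispatches through activation(..., ActivationType); all names here are illustrative):

#include <cmath>
#include <cstdio>

static float sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

// Scalar peephole-LSTM forward step mirroring forward::lstm<T>::operator().
// a: candidate input; i/f/o: gate pre-activations; c_prev: previous cell
// state; wi/wf/wo: peephole weights (checkI/checkF/checkO in the patch);
// c: new cell state; h: output (value_og * state_atv folded together).
void lstm_forward(float *a, float *i, float *f, float *o, float c_prev,
                  float wi, float wf, float wo, float *c, float *h) {
  *a = std::tanh(*a);                // active_node
  *i = sigmoid(*i + c_prev * wi);    // input gate with peephole
  *f = sigmoid(*f + c_prev * wf);    // forget gate with peephole
  *c = (*a) * (*i) + c_prev * (*f);  // new cell state
  *o = sigmoid(*o + (*c) * wo);      // output gate peeks at the new state
  *h = (*o) * std::tanh(*c);         // hidden output = value_og * state_atv
}

int main() {
  float a = 0.3f, i = 0.1f, f = 0.9f, o = 0.2f, c, h;
  lstm_forward(&a, &i, &f, &o, /*c_prev=*/0.5f, 0.1f, 0.1f, 0.1f, &c, &h);
  std::printf("c=%f h=%f\n", c, h);
  return 0;
}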
diff --git a/paddle/fluid/operators/math/detail/lstm_kernel.h b/paddle/fluid/operators/math/detail/lstm_kernel.h
index 9080634f2b3fc122a420e049314f53abd50376e0..cbe73d62938d7c4c03a2c8731665260624417fd7 100644
--- a/paddle/fluid/operators/math/detail/lstm_kernel.h
+++ b/paddle/fluid/operators/math/detail/lstm_kernel.h
@@ -12,11 +12,11 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#pragma once
+#include <type_traits>
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/platform/hostdevice.h"

-#include <type_traits>
-
 namespace paddle {
 namespace operators {
 namespace math {
@@ -27,19 +27,19 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &prev_state, T &state, T &state_atv, T &output,
-                             T &checkI, T &checkF, T &checkO,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *prev_state, T *state, T *state_atv, T *output,
+                             T *checkI, T *checkF, T *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig = activation(value_ig + prev_state * checkI, active_gate);
-    value_fg = activation(value_fg + prev_state * checkF, active_gate);
-    state = value_in * value_ig + prev_state * value_fg;
-    value_og = activation(value_og + state * checkO, active_gate);
-    state_atv = activation(state, active_state);
-    output = value_og * state_atv;
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
+    *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
+    *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
+    *value_og = activation(*value_og + (*state) * (*checkO), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = (*value_og) * (*state_atv);
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -48,27 +48,27 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;

-  HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
-                             __m256 &value_fg, __m256 &value_og,
-                             __m256 &prev_state, __m256 &state,
-                             __m256 &state_atv, __m256 &output, __m256 &checkI,
-                             __m256 &checkF, __m256 &checkO,
+  HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig,
+                             __m256 *value_fg, __m256 *value_og,
+                             __m256 *prev_state, __m256 *state,
+                             __m256 *state_atv, __m256 *output, __m256 *checkI,
+                             __m256 *checkF, __m256 *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig =
-        activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
-                   active_gate);
-    value_fg =
-        activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
-                   active_gate);
-    state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
-                          _mm256_mul_ps(prev_state, value_fg));
-    value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
-                          active_gate);
-    state_atv = activation(state, active_state);
-    output = _mm256_mul_ps(value_og, state_atv);
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(
+        _mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)),
+        active_gate);
+    *value_fg = activation(
+        _mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)),
+        active_gate);
+    *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
+                           _mm256_mul_ps(*prev_state, *value_fg));
+    *value_og = activation(
+        _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = _mm256_mul_ps(*value_og, *state_atv);
   }
 #endif
 #endif
@@ -81,26 +81,29 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
-                             T &prev_state, T &prev_state_grad, T &state,
-                             T &state_grad, T &state_atv, T &output_grad,
-                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *grad_in, T *grad_ig, T *grad_fg, T *grad_og,
+                             T *prev_state, T *prev_state_grad, T *state,
+                             T *state_grad, T *state_atv, T *output_grad,
+                             T *checkI, T *checkF, T *checkO, T *checkIGrad,
+                             T *checkFGrad, T *checkOGrad,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    grad_og = activation(output_grad * state_atv, value_og, active_gate);
-    state_grad += activation(output_grad * value_og, state_atv, active_state) +
-                  grad_og * checkO;
-    grad_in = activation(state_grad * value_ig, value_in, active_node);
-    grad_ig = activation(state_grad * value_in, value_ig, active_gate);
-    grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
-    prev_state_grad =
-        grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
-    checkIGrad = grad_ig * prev_state;
-    checkFGrad = grad_fg * prev_state;
-    checkOGrad = grad_og * state;
+    *grad_og =
+        activation((*output_grad) * (*state_atv), *value_og, active_gate);
+    *state_grad +=
+        activation((*output_grad) * (*value_og), *state_atv, active_state) +
+        (*grad_og) * (*checkO);
+    *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
+    *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
+    *grad_fg =
+        activation((*state_grad) * (*prev_state), *value_fg, active_gate);
+    *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
+                       (*state_grad) * (*value_fg);
+    *checkIGrad = (*grad_ig) * (*prev_state);
+    *checkFGrad = (*grad_fg) * (*prev_state);
+    *checkOGrad = (*grad_og) * (*state);
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
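In the backward functor just above, the three-argument activation(a, b, act) is the derivative form: it appears to scale the incoming gradient a by the activation derivative expressed through the already-activated output b (a * b * (1 - b) for sigmoid, a * (1 - b * b) for tanh). Under that assumption, a scalar sketch of the first few gradient steps, with illustrative values:

#include <cstdio>

// Derivative-form helpers, assumed to match how the backward
// activation(a, b, act) hook behaves: gradient a scaled by act'(x)
// written in terms of the activated output b.
static float d_sigmoid(float a, float b) { return a * b * (1.0f - b); }
static float d_tanh(float a, float b) { return a * (1.0f - b * b); }

int main() {
  // Forward quantities (already activated) plus incoming gradients.
  float value_og = 0.6f, state_atv = 0.4f, state = 0.42f;
  float output_grad = 1.0f, state_grad = 0.0f, checkO = 0.1f;

  // Mirrors the first statements of backward::lstm<T>::operator():
  float grad_og = d_sigmoid(output_grad * state_atv, value_og);
  state_grad += d_tanh(output_grad * value_og, state_atv) + grad_og * checkO;
  float checkOGrad = grad_og * state;  // peephole weight gradient

  std::printf("grad_og=%f state_grad=%f checkOGrad=%f\n", grad_og, state_grad,
              checkOGrad);
  return 0;
}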
@@ -109,32 +112,33 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
   HOSTDEVICE void operator()(
-      __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
-      __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
-      __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
-      __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
-      __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
-      __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
+      __m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og,
+      __m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og,
+      __m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
+      __m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
+      __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
+      __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
       ActivationType active_gate, ActivationType active_state) {
-    grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
-                         active_gate);
-    state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
-                                          state_atv, active_state),
-                               state_grad);
-    state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
-    grad_in =
-        activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
-    grad_ig =
-        activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
-    grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
-                         active_gate);
-    prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
-                                    _mm256_mul_ps(grad_fg, checkF));
-    prev_state_grad =
-        _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
-    checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
-    checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
-    checkOGrad = _mm256_mul_ps(grad_og, state);
+    *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
+                          active_gate);
+    *state_grad =
+        _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
+                                 *state_atv, active_state),
+                      *state_grad);
+    *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
+    *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
+                          active_node);
+    *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
+                          active_gate);
+    *grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg,
+                          active_gate);
+    *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
+                                     _mm256_mul_ps(*grad_fg, *checkF));
+    *prev_state_grad =
+        _mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad);
+    *checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state);
+    *checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state);
+    *checkOGrad = _mm256_mul_ps(*grad_og, *state);
   }
 #endif
 #endif
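The AVX overloads apply the same formulas to eight packed floats per iteration; since the intrinsics return by value, only the writes through the pointer parameters change in this patch. A minimal sketch of the core update *state = value_in * value_ig + prev_state * value_fg using _mm256_mul_ps/_mm256_add_ps (compile with -mavx; values are illustrative):

#include <immintrin.h>
#include <cstdio>

int main() {
  // Broadcast scalars into 8-lane vectors for demonstration.
  __m256 value_in = _mm256_set1_ps(0.5f);
  __m256 value_ig = _mm256_set1_ps(0.8f);
  __m256 prev_state = _mm256_set1_ps(1.0f);
  __m256 value_fg = _mm256_set1_ps(0.9f);

  // state = value_in * value_ig + prev_state * value_fg, 8 floats at once.
  __m256 state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
                               _mm256_mul_ps(prev_state, value_fg));

  float out[8];
  _mm256_storeu_ps(out, state);
  std::printf("lane0 = %f\n", out[0]);  // 0.5*0.8 + 1.0*0.9 = 1.3
  return 0;
}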