Commit b6528216 authored by Siddharth Goyal, committed by Abhinav Arora

Fix cpplint errors in lstm kernel (#10394)

Parent bd66eed5
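The change below is mechanical: cpplint flags mutable (non-const) reference parameters (the check is commonly reported as runtime/references), so the output arguments of the LSTM gate functors are switched from `T &` to `T *`, the bodies dereference them, and the call sites pass addresses with `&`. A minimal before/after sketch of that pattern, using a hypothetical `scale_into` helper that is not part of this commit:

#include <cassert>

// Before (cpplint runtime/references warning): output written through a
// non-const reference parameter.
//   void scale_into(float value, float factor, float &out) { out = value * factor; }

// After: the output is passed by pointer, so the write is explicit at the call site.
void scale_into(float value, float factor, float *out) { *out = value * factor; }

int main() {
  float out = 0.0f;
  scale_into(2.0f, 3.0f, &out);  // caller takes the address, mirroring op(&r_value_in, ...)
  assert(out == 6.0f);
  return 0;
}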
@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);
 
     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);
 
     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);
 
     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);
 
     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
...
@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
     r_prev_state = value.prev_state_value[frame_idx];
   }
 
-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-     r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+     &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+     active_node, active_gate, active_state);
 
   value.gate_value[frame_idx] = r_value_in;
   value.gate_value[frame_idx + frame_size] = r_value_ig;
@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     r_prev_state = value.prev_state_value[frame_idx];
   }
 
-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-     r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-     r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-     r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
+     &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
+     &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
+     &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
+     active_gate, active_state);
 
   grad.gate_grad[frame_idx] = r_grad_in;
   grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
...
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#pragma once
+#include <type_traits>
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/platform/hostdevice.h"
 
-#include <type_traits>
-
 namespace paddle {
 namespace operators {
 namespace math {
@@ -27,19 +27,19 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &prev_state, T &state, T &state_atv, T &output,
-                             T &checkI, T &checkF, T &checkO,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *prev_state, T *state, T *state_atv, T *output,
+                             T *checkI, T *checkF, T *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig = activation(value_ig + prev_state * checkI, active_gate);
-    value_fg = activation(value_fg + prev_state * checkF, active_gate);
-    state = value_in * value_ig + prev_state * value_fg;
-    value_og = activation(value_og + state * checkO, active_gate);
-    state_atv = activation(state, active_state);
-    output = value_og * state_atv;
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
+    *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
+    *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
+    *value_og = activation(*value_og + (*state) * (*checkO), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = (*value_og) * (*state_atv);
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -48,27 +48,27 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
-  HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
-                             __m256 &value_fg, __m256 &value_og,
-                             __m256 &prev_state, __m256 &state,
-                             __m256 &state_atv, __m256 &output, __m256 &checkI,
-                             __m256 &checkF, __m256 &checkO,
+  HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig,
+                             __m256 *value_fg, __m256 *value_og,
+                             __m256 *prev_state, __m256 *state,
+                             __m256 *state_atv, __m256 *output, __m256 *checkI,
+                             __m256 *checkF, __m256 *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig =
-        activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
-                   active_gate);
-    value_fg =
-        activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
-                   active_gate);
-    state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
-                          _mm256_mul_ps(prev_state, value_fg));
-    value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
-                          active_gate);
-    state_atv = activation(state, active_state);
-    output = _mm256_mul_ps(value_og, state_atv);
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(
+        _mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)),
+        active_gate);
+    *value_fg = activation(
+        _mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)),
+        active_gate);
+    *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
+                           _mm256_mul_ps(*prev_state, *value_fg));
+    *value_og = activation(
+        _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = _mm256_mul_ps(*value_og, *state_atv);
   }
 #endif
 #endif
@@ -81,26 +81,29 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
-                             T &prev_state, T &prev_state_grad, T &state,
-                             T &state_grad, T &state_atv, T &output_grad,
-                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *grad_in, T *grad_ig, T *grad_fg, T *grad_og,
+                             T *prev_state, T *prev_state_grad, T *state,
+                             T *state_grad, T *state_atv, T *output_grad,
+                             T *checkI, T *checkF, T *checkO, T *checkIGrad,
+                             T *checkFGrad, T *checkOGrad,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    grad_og = activation(output_grad * state_atv, value_og, active_gate);
-    state_grad += activation(output_grad * value_og, state_atv, active_state) +
-                  grad_og * checkO;
-    grad_in = activation(state_grad * value_ig, value_in, active_node);
-    grad_ig = activation(state_grad * value_in, value_ig, active_gate);
-    grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
-    prev_state_grad =
-        grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
-    checkIGrad = grad_ig * prev_state;
-    checkFGrad = grad_fg * prev_state;
-    checkOGrad = grad_og * state;
+    *grad_og =
+        activation((*output_grad) * (*state_atv), *value_og, active_gate);
+    *state_grad +=
+        activation((*output_grad) * (*value_og), *state_atv, active_state) +
+        (*grad_og) * (*checkO);
+    *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
+    *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
+    *grad_fg =
+        activation((*state_grad) * (*prev_state), *value_fg, active_gate);
+    *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
+                       (*state_grad) * (*value_fg);
+    *checkIGrad = (*grad_ig) * (*prev_state);
+    *checkFGrad = (*grad_fg) * (*prev_state);
+    *checkOGrad = (*grad_og) * (*state);
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -109,32 +112,33 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
   HOSTDEVICE void operator()(
-      __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
-      __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
-      __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
-      __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
-      __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
-      __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
+      __m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og,
+      __m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og,
+      __m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
+      __m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
+      __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
+      __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
      ActivationType active_gate, ActivationType active_state) {
-    grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
-                         active_gate);
-    state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
-                                          state_atv, active_state),
-                               state_grad);
-    state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
-    grad_in =
-        activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
-    grad_ig =
-        activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
-    grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
-                         active_gate);
-    prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
-                                    _mm256_mul_ps(grad_fg, checkF));
-    prev_state_grad =
-        _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
-    checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
-    checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
-    checkOGrad = _mm256_mul_ps(grad_og, state);
+    *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
+                          active_gate);
+    *state_grad =
+        _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
+                                 *state_atv, active_state),
+                      *state_grad);
+    *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
+    *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
+                          active_node);
+    *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
+                          active_gate);
+    *grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg,
+                          active_gate);
+    *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
+                                     _mm256_mul_ps(*grad_fg, *checkF));
+    *prev_state_grad =
+        _mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad);
+    *checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state);
+    *checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state);
+    *checkOGrad = _mm256_mul_ps(*grad_og, *state);
   }
 #endif
 #endif
...