Commit b6528216, authored by Siddharth Goyal, committed by Abhinav Arora

Fix cpplint errors in lstm kernel (#10394)

Parent bd66eed5
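The cpplint complaint addressed by this patch is the Google C++ style rule against non-const reference parameters (`runtime/references`): output arguments should be passed by pointer so that mutation is visible at the call site. A minimal sketch of the before/after pattern, using hypothetical names rather than code from this patch:

// Hypothetical illustration of the cpplint runtime/references fix.
// Before: the caller cannot see that `x` may be modified.
void scale_by_ref(float &x, float factor) { x *= factor; }

// After: taking a pointer makes the write explicit at the call site.
void scale_by_ptr(float *x, float factor) { *x *= factor; }

int main() {
  float v = 2.0f;
  scale_by_ref(v, 3.0f);   // nothing here hints that v changes
  scale_by_ptr(&v, 3.0f);  // the &v makes the mutation obvious
  return 0;
}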
@@ -59,9 +59,9 @@ void naive_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);
 
     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -125,11 +125,11 @@ void naive_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = value.prev_state_value[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);
 
     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
@@ -186,9 +186,9 @@ void avx_lstm_forward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-       r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node,
-       active_gate, active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+       &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+       active_node, active_gate, active_state);
 
     value_in[i] = r_value_in;
     value_ig[i] = r_value_ig;
@@ -258,11 +258,11 @@ void avx_lstm_backward_one_sequence(Op op, LstmMetaValue<T> value,
       r_prev_state = (reinterpret_cast<__m256 *>(value.prev_state_value))[i];
     }
 
-    op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-       r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-       r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-       r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-       active_state);
+    op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in,
+       &r_grad_ig, &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad,
+       &r_state, &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI,
+       &r_checkF, &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad,
+       active_node, active_gate, active_state);
 
     grad_in[i] = r_grad_in;
     grad_ig[i] = r_grad_ig;
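The CPU kernels above all follow the same stage/compute/store shape: frame values are copied into local `r_*` variables, the functor mutates them through pointers (the new calling convention), and the results are written back to the arrays. A simplified, hedged sketch of that loop; `Op` and the two-array layout here are stand-ins, not the real kernel:

// Hedged sketch of the load/compute/store pattern in the CPU kernels above.
template <class T, class Op>
void run_one_sequence(Op op, T *value_in, T *state, int frame_size) {
  for (int i = 0; i < frame_size; i++) {
    T r_value_in = value_in[i];  // stage frame values in locals
    T r_state = state[i];
    op(&r_value_in, &r_state);   // functor writes results through pointers
    value_in[i] = r_value_in;    // store the updated values back
    state[i] = r_state;
  }
}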
@@ -70,9 +70,9 @@ __global__ void KeLstmForward(Op op, LstmMetaValue<T> value, int frame_size,
     r_prev_state = value.prev_state_value[frame_idx];
   }
 
-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_prev_state, r_state,
-     r_state_atv, r_out, r_checkI, r_checkF, r_checkO, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_prev_state,
+     &r_state, &r_state_atv, &r_out, &r_checkI, &r_checkF, &r_checkO,
+     active_node, active_gate, active_state);
 
   value.gate_value[frame_idx] = r_value_in;
   value.gate_value[frame_idx + frame_size] = r_value_ig;
@@ -145,11 +145,11 @@ __global__ void KeLstmBackward(Op op, LstmMetaValue<T> value,
     r_prev_state = value.prev_state_value[frame_idx];
   }
 
-  op(r_value_in, r_value_ig, r_value_fg, r_value_og, r_grad_in, r_grad_ig,
-     r_grad_fg, r_grad_og, r_prev_state, r_prev_state_grad, r_state,
-     r_state_grad, r_state_atv, r_output_grad, r_checkI, r_checkF, r_checkO,
-     r_checkIGrad, r_checkFGrad, r_checkOGrad, active_node, active_gate,
-     active_state);
+  op(&r_value_in, &r_value_ig, &r_value_fg, &r_value_og, &r_grad_in, &r_grad_ig,
+     &r_grad_fg, &r_grad_og, &r_prev_state, &r_prev_state_grad, &r_state,
+     &r_state_grad, &r_state_atv, &r_output_grad, &r_checkI, &r_checkF,
+     &r_checkO, &r_checkIGrad, &r_checkFGrad, &r_checkOGrad, active_node,
+     active_gate, active_state);
 
   grad.gate_grad[frame_idx] = r_grad_in;
   grad.gate_grad[frame_idx + frame_size] = r_grad_ig;
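The GPU kernels use the same pointer convention, except each thread handles one frame index instead of looping. A minimal hedged CUDA sketch of that per-thread shape; the kernel name, `Op`, and the data layout are simplified stand-ins, not the actual PaddlePaddle kernel:

// Hedged sketch of the per-thread pattern in KeLstmForward/KeLstmBackward.
template <class T, class Op>
__global__ void KeFrameUpdate(Op op, T *gate_value, T *state, int frame_size) {
  const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (frame_idx >= frame_size) return;
  T r_value = gate_value[frame_idx];  // stage this thread's frame in registers
  T r_state = state[frame_idx];
  op(&r_value, &r_state);             // pointer-based functor call, as above
  gate_value[frame_idx] = r_value;    // write results back to global memory
  state[frame_idx] = r_state;
}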
@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <type_traits>
 #include "paddle/fluid/operators/math/detail/activation_functions.h"
 #include "paddle/fluid/platform/hostdevice.h"
-#include <type_traits>
 
 namespace paddle {
 namespace operators {
 namespace math {
@@ -27,19 +27,19 @@ namespace forward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &prev_state, T &state, T &state_atv, T &output,
-                             T &checkI, T &checkF, T &checkO,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *prev_state, T *state, T *state_atv, T *output,
+                             T *checkI, T *checkF, T *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig = activation(value_ig + prev_state * checkI, active_gate);
-    value_fg = activation(value_fg + prev_state * checkF, active_gate);
-    state = value_in * value_ig + prev_state * value_fg;
-    value_og = activation(value_og + state * checkO, active_gate);
-    state_atv = activation(state, active_state);
-    output = value_og * state_atv;
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(*value_ig + (*prev_state) * (*checkI), active_gate);
+    *value_fg = activation(*value_fg + (*prev_state) * (*checkF), active_gate);
+    *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
+    *value_og = activation(*value_og + (*state) * (*checkO), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = (*value_og) * (*state_atv);
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -48,27 +48,27 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
 
-  HOSTDEVICE void operator()(__m256 &value_in, __m256 &value_ig,
-                             __m256 &value_fg, __m256 &value_og,
-                             __m256 &prev_state, __m256 &state,
-                             __m256 &state_atv, __m256 &output, __m256 &checkI,
-                             __m256 &checkF, __m256 &checkO,
+  HOSTDEVICE void operator()(__m256 *value_in, __m256 *value_ig,
+                             __m256 *value_fg, __m256 *value_og,
+                             __m256 *prev_state, __m256 *state,
+                             __m256 *state_atv, __m256 *output, __m256 *checkI,
+                             __m256 *checkF, __m256 *checkO,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    value_in = activation(value_in, active_node);
-    value_ig =
-        activation(_mm256_add_ps(value_ig, _mm256_mul_ps(prev_state, checkI)),
-                   active_gate);
-    value_fg =
-        activation(_mm256_add_ps(value_fg, _mm256_mul_ps(prev_state, checkF)),
-                   active_gate);
-    state = _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
-                          _mm256_mul_ps(prev_state, value_fg));
-    value_og = activation(_mm256_add_ps(value_og, _mm256_mul_ps(state, checkO)),
-                          active_gate);
-    state_atv = activation(state, active_state);
-    output = _mm256_mul_ps(value_og, state_atv);
+    *value_in = activation(*value_in, active_node);
+    *value_ig = activation(
+        _mm256_add_ps(*value_ig, _mm256_mul_ps(*prev_state, *checkI)),
+        active_gate);
+    *value_fg = activation(
+        _mm256_add_ps(*value_fg, _mm256_mul_ps(*prev_state, *checkF)),
+        active_gate);
+    *state = _mm256_add_ps(_mm256_mul_ps(*value_in, *value_ig),
+                           _mm256_mul_ps(*prev_state, *value_fg));
+    *value_og = activation(
+        _mm256_add_ps(*value_og, _mm256_mul_ps(*state, *checkO)), active_gate);
+    *state_atv = activation(*state, active_state);
+    *output = _mm256_mul_ps(*value_og, *state_atv);
   }
 #endif
 #endif
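For orientation, the `__m256` overload above is the eight-lane packed version of the scalar functor: each `_mm256_add_ps`/`_mm256_mul_ps` pair replaces one scalar `a + b * c` expression. A small hedged helper showing that correspondence for the cell-state line (the helper name is hypothetical, not part of the patch):

#include <immintrin.h>

// Scalar: *state = (*value_in) * (*value_ig) + (*prev_state) * (*value_fg);
// Packed equivalent on 8 floats at once:
static inline __m256 cell_state(__m256 value_in, __m256 value_ig,
                                __m256 prev_state, __m256 value_fg) {
  return _mm256_add_ps(_mm256_mul_ps(value_in, value_ig),
                       _mm256_mul_ps(prev_state, value_fg));
}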
@@ -81,26 +81,29 @@ namespace backward {
 template <class T>
 class lstm {
  public:
-  HOSTDEVICE void operator()(T &value_in, T &value_ig, T &value_fg, T &value_og,
-                             T &grad_in, T &grad_ig, T &grad_fg, T &grad_og,
-                             T &prev_state, T &prev_state_grad, T &state,
-                             T &state_grad, T &state_atv, T &output_grad,
-                             T &checkI, T &checkF, T &checkO, T &checkIGrad,
-                             T &checkFGrad, T &checkOGrad,
+  HOSTDEVICE void operator()(T *value_in, T *value_ig, T *value_fg, T *value_og,
+                             T *grad_in, T *grad_ig, T *grad_fg, T *grad_og,
+                             T *prev_state, T *prev_state_grad, T *state,
+                             T *state_grad, T *state_atv, T *output_grad,
+                             T *checkI, T *checkF, T *checkO, T *checkIGrad,
+                             T *checkFGrad, T *checkOGrad,
                              ActivationType active_node,
                              ActivationType active_gate,
                              ActivationType active_state) {
-    grad_og = activation(output_grad * state_atv, value_og, active_gate);
-    state_grad += activation(output_grad * value_og, state_atv, active_state) +
-                  grad_og * checkO;
-    grad_in = activation(state_grad * value_ig, value_in, active_node);
-    grad_ig = activation(state_grad * value_in, value_ig, active_gate);
-    grad_fg = activation(state_grad * prev_state, value_fg, active_gate);
-    prev_state_grad =
-        grad_ig * checkI + grad_fg * checkF + state_grad * value_fg;
-    checkIGrad = grad_ig * prev_state;
-    checkFGrad = grad_fg * prev_state;
-    checkOGrad = grad_og * state;
+    *grad_og =
+        activation((*output_grad) * (*state_atv), *value_og, active_gate);
+    *state_grad +=
+        activation((*output_grad) * (*value_og), *state_atv, active_state) +
+        (*grad_og) * (*checkO);
+    *grad_in = activation((*state_grad) * (*value_ig), *value_in, active_node);
+    *grad_ig = activation((*state_grad) * (*value_in), *value_ig, active_gate);
+    *grad_fg =
+        activation((*state_grad) * (*prev_state), *value_fg, active_gate);
+    *prev_state_grad = (*grad_ig) * (*checkI) + (*grad_fg) * (*checkF) +
+                       (*state_grad) * (*value_fg);
+    *checkIGrad = (*grad_ig) * (*prev_state);
+    *checkFGrad = (*grad_fg) * (*prev_state);
+    *checkOGrad = (*grad_og) * (*state);
   }
 #ifndef __NVCC__
 #ifndef __AVX__  // If not compiled with AVX instructs. Disable AVX by default
@@ -109,32 +112,33 @@ class lstm {
   // Only float support AVX optimization
   static const bool avx = std::is_same<T, float>::value;
   HOSTDEVICE void operator()(
-      __m256 &value_in, __m256 &value_ig, __m256 &value_fg, __m256 &value_og,
-      __m256 &grad_in, __m256 &grad_ig, __m256 &grad_fg, __m256 &grad_og,
-      __m256 &prev_state, __m256 &prev_state_grad, __m256 &state,
-      __m256 &state_grad, __m256 &state_atv, __m256 &output_grad,
-      __m256 &checkI, __m256 &checkF, __m256 &checkO, __m256 &checkIGrad,
-      __m256 &checkFGrad, __m256 &checkOGrad, ActivationType active_node,
+      __m256 *value_in, __m256 *value_ig, __m256 *value_fg, __m256 *value_og,
+      __m256 *grad_in, __m256 *grad_ig, __m256 *grad_fg, __m256 *grad_og,
+      __m256 *prev_state, __m256 *prev_state_grad, __m256 *state,
+      __m256 *state_grad, __m256 *state_atv, __m256 *output_grad,
+      __m256 *checkI, __m256 *checkF, __m256 *checkO, __m256 *checkIGrad,
+      __m256 *checkFGrad, __m256 *checkOGrad, ActivationType active_node,
       ActivationType active_gate, ActivationType active_state) {
-    grad_og = activation(_mm256_mul_ps(output_grad, state_atv), value_og,
-                         active_gate);
-    state_grad = _mm256_add_ps(activation(_mm256_mul_ps(output_grad, value_og),
-                                          state_atv, active_state),
-                               state_grad);
-    state_grad = _mm256_add_ps(_mm256_mul_ps(grad_og, checkO), state_grad);
-    grad_in =
-        activation(_mm256_mul_ps(state_grad, value_ig), value_in, active_node);
-    grad_ig =
-        activation(_mm256_mul_ps(state_grad, value_in), value_ig, active_gate);
-    grad_fg = activation(_mm256_mul_ps(state_grad, prev_state), value_fg,
-                         active_gate);
-    prev_state_grad = _mm256_add_ps(_mm256_mul_ps(grad_ig, checkI),
-                                    _mm256_mul_ps(grad_fg, checkF));
-    prev_state_grad =
-        _mm256_add_ps(_mm256_mul_ps(state_grad, value_fg), prev_state_grad);
-    checkIGrad = _mm256_mul_ps(grad_ig, prev_state);
-    checkFGrad = _mm256_mul_ps(grad_fg, prev_state);
-    checkOGrad = _mm256_mul_ps(grad_og, state);
+    *grad_og = activation(_mm256_mul_ps(*output_grad, *state_atv), *value_og,
+                          active_gate);
+    *state_grad =
+        _mm256_add_ps(activation(_mm256_mul_ps(*output_grad, *value_og),
+                                 *state_atv, active_state),
+                      *state_grad);
+    *state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_og, *checkO), *state_grad);
+    *grad_in = activation(_mm256_mul_ps(*state_grad, *value_ig), *value_in,
+                          active_node);
+    *grad_ig = activation(_mm256_mul_ps(*state_grad, *value_in), *value_ig,
+                          active_gate);
+    *grad_fg = activation(_mm256_mul_ps(*state_grad, *prev_state), *value_fg,
+                          active_gate);
+    *prev_state_grad = _mm256_add_ps(_mm256_mul_ps(*grad_ig, *checkI),
+                                     _mm256_mul_ps(*grad_fg, *checkF));
+    *prev_state_grad =
+        _mm256_add_ps(_mm256_mul_ps(*state_grad, *value_fg), *prev_state_grad);
+    *checkIGrad = _mm256_mul_ps(*grad_ig, *prev_state);
+    *checkFGrad = _mm256_mul_ps(*grad_fg, *prev_state);
+    *checkOGrad = _mm256_mul_ps(*grad_og, *state);
   }
 #endif
 #endif
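Read as plain scalar code, the forward functor computes a standard peephole LSTM cell update. A hedged restatement with sigmoid gates and tanh activations assumed for concreteness (the real functor selects activations at runtime via `ActivationType`):

#include <cmath>

// Hedged scalar restatement of the forward functor's body.
inline float sigmoidf(float x) { return 1.0f / (1.0f + std::exp(-x)); }

void lstm_forward_scalar(float *value_in, float *value_ig, float *value_fg,
                         float *value_og, float *prev_state, float *state,
                         float *state_atv, float *output, float checkI,
                         float checkF, float checkO) {
  *value_in = std::tanh(*value_in);                          // candidate input
  *value_ig = sigmoidf(*value_ig + *prev_state * checkI);    // input gate
  *value_fg = sigmoidf(*value_fg + *prev_state * checkF);    // forget gate
  *state = *value_in * *value_ig + *prev_state * *value_fg;  // new cell state
  *value_og = sigmoidf(*value_og + *state * checkO);         // output gate
  *state_atv = std::tanh(*state);                            // activated state
  *output = *value_og * *state_atv;                          // hidden output
}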