未验证 提交 c9f55dfa 编写于 作者: A Abhinav Arora 提交者: GitHub

Fix CPPLint issues in /math/detail/gru_kernel.h (#10390)

* Fix CPPLint issyes in gru_kernel.h

* Fix CPPLint issyes in gru_kernel.h

* Fix Compile error
上级 20fa8480
...@@ -43,8 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -43,8 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
r_value_reset_output, active_gate); &r_value_reset_output, active_gate);
update_gate[i] = r_value_update_gate; update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate; reset_gate[i] = r_value_reset_gate;
...@@ -71,8 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -71,8 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
r_prev_out = prev_output_value[i]; r_prev_out = prev_output_value[i];
} }
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
r_output, active_node); &r_output, active_node);
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
output_value[i] = r_output; output_value[i] = r_output;
...@@ -99,8 +99,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output, ...@@ -99,8 +99,8 @@ void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i]; r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
} }
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
r_value_reset_output, active_gate); &r_value_reset_output, active_gate);
update_gate[i] = r_value_update_gate; update_gate[i] = r_value_update_gate;
reset_gate[i] = r_value_reset_gate; reset_gate[i] = r_value_reset_gate;
...@@ -129,8 +129,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output, ...@@ -129,8 +129,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i]; r_prev_out = (reinterpret_cast<__m256 *>(prev_output_value))[i];
} }
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
r_output, active_node); &r_output, active_node);
frame_state[i] = r_value_frame_state; frame_state[i] = r_value_frame_state;
(reinterpret_cast<__m256 *>(output_value))[i] = r_output; (reinterpret_cast<__m256 *>(output_value))[i] = r_output;
...@@ -213,9 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -213,9 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
r_prev_out_grad = prev_out_grad[i]; r_prev_out_grad = prev_out_grad[i];
} }
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
r_frame_state_grad, r_prev_out_value, r_prev_out_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
...@@ -258,9 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -258,9 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
r_prev_out_grad = prev_out_grad[i]; r_prev_out_grad = prev_out_grad[i];
} }
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
r_reset_output_grad, active_gate); &r_prev_out_grad, &r_reset_output_grad, active_gate);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad; reset_gate_grad[i] = r_reset_gate_grad;
...@@ -302,9 +302,9 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value, ...@@ -302,9 +302,9 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
} }
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad,
r_frame_state_grad, r_prev_out_value, r_prev_out_grad, &r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
r_out_grad, active_node); &r_prev_out_grad, &r_out_grad, active_node);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad; frame_state_grad[i] = r_frame_state_grad;
...@@ -350,9 +350,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -350,9 +350,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i]; r_prev_out_grad = (reinterpret_cast<__m256 *>(prev_out_grad))[i];
} }
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad,
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, &r_reset_gate_value, &r_reset_gate_grad, &r_prev_out_value,
r_reset_output_grad, active_gate); &r_prev_out_grad, &r_reset_output_grad, active_gate);
update_gate_grad[i] = r_update_gate_grad; update_gate_grad[i] = r_update_gate_grad;
reset_gate_grad[i] = r_reset_gate_grad; reset_gate_grad[i] = r_reset_gate_grad;
......
...@@ -55,8 +55,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output, ...@@ -55,8 +55,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
r_prev_out = prev_output_value[frame_idx]; r_prev_out = prev_output_value[frame_idx];
} }
op_reset_output(r_value_update_gate, r_value_reset_gate, r_prev_out, op_reset_output(&r_value_update_gate, &r_value_reset_gate, &r_prev_out,
r_value_reset_output, active_gate); &r_value_reset_output, active_gate);
gate_value[frame_idx + frame_size * 0] = r_value_update_gate; gate_value[frame_idx + frame_size * 0] = r_value_update_gate;
gate_value[frame_idx + frame_size * 1] = r_value_reset_gate; gate_value[frame_idx + frame_size * 1] = r_value_reset_gate;
...@@ -93,8 +93,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output, ...@@ -93,8 +93,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
r_prev_out = prev_output_value[frame_idx]; r_prev_out = prev_output_value[frame_idx];
} }
op_final_output(r_value_update_gate, r_value_frame_state, r_prev_out, op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
r_output, active_node); &r_output, active_node);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state; gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output; output_value[frame_idx] = r_output;
...@@ -137,9 +137,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value, ...@@ -137,9 +137,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
r_prev_out_grad = prev_out_grad[frame_idx]; r_prev_out_grad = prev_out_grad[frame_idx];
} }
op_state_grad(r_update_gate_value, r_update_gate_grad, r_frame_state_value, op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
r_frame_state_grad, r_prev_out_value, r_prev_out_grad, &r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
r_out_grad, active_node); &r_out_grad, active_node);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad; gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
...@@ -185,9 +185,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value, ...@@ -185,9 +185,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
r_reset_output_grad = reset_output_grad[frame_idx]; r_reset_output_grad = reset_output_grad[frame_idx];
} }
op_reset_grad(r_update_gate_value, r_update_gate_grad, r_reset_gate_value, op_reset_grad(&r_update_gate_value, &r_update_gate_grad, &r_reset_gate_value,
r_reset_gate_grad, r_prev_out_value, r_prev_out_grad, &r_reset_gate_grad, &r_prev_out_value, &r_prev_out_grad,
r_reset_output_grad, active_gate); &r_reset_output_grad, active_gate);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad; gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad; gate_grad[frame_idx + frame_size * 1] = r_reset_gate_grad;
......
...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/fluid/operators/math/detail/activation_functions.h" #include "paddle/fluid/operators/math/detail/activation_functions.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
#include <type_traits>
// TODO(guosheng): refine code style in gru_kernel // TODO(guosheng): refine code style in gru_kernel
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,25 +28,25 @@ namespace forward { ...@@ -28,25 +28,25 @@ namespace forward {
template <typename T> template <typename T>
class gru_resetOutput { class gru_resetOutput {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate, HOSTDEVICE void operator()(T *value_update_gate, T *value_reset_gate,
T &prev_out, T &value_reset_output, T *prev_out, T *value_reset_output,
ActivationType act_gate) { ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate); *value_update_gate = activation(*value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate);
value_reset_output = prev_out * value_reset_gate; *value_reset_output = (*prev_out) * (*value_reset_gate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &value_reset_gate, __m256 &prev_out, __m256 *value_reset_gate, __m256 *prev_out,
__m256 &value_reset_output, __m256 *value_reset_output,
ActivationType act_gate) { ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate); *value_update_gate = activation(*value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate); *value_reset_gate = activation(*value_reset_gate, act_gate);
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate); *value_reset_output = _mm256_mul_ps(*prev_out, *value_reset_gate);
} }
#endif #endif
#endif #endif
...@@ -55,25 +55,25 @@ class gru_resetOutput { ...@@ -55,25 +55,25 @@ class gru_resetOutput {
template <typename T> template <typename T>
class gru_finalOutput { class gru_finalOutput {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state, HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T &prev_out, T &value_output, T *prev_out, T *value_output,
ActivationType act_input) { ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
value_output = prev_out - (value_update_gate * prev_out) + *value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
(value_update_gate * value_frame_state); ((*value_update_gate) * (*value_frame_state));
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &value_frame_state, __m256 &prev_out, __m256 *value_frame_state, __m256 *prev_out,
__m256 &value_output, ActivationType act_input) { __m256 *value_output, ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input); *value_frame_state = activation(*value_frame_state, act_input);
value_output = _mm256_add_ps( *value_output = _mm256_add_ps(
_mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)), _mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(value_update_gate, value_frame_state)); _mm256_mul_ps(*value_update_gate, *value_frame_state));
} }
#endif #endif
#endif #endif
...@@ -85,37 +85,38 @@ namespace backward { ...@@ -85,37 +85,38 @@ namespace backward {
template <typename T> template <typename T>
class gru_stateGrad { class gru_stateGrad {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T &value_frame_state, T &grad_frame_state, T *value_frame_state, T *grad_frame_state,
T &value_prev_out, T &grad_prev_out, T *value_prev_out, T *grad_prev_out,
T &grad_output, ActivationType act_input) { T *grad_output, ActivationType act_input) {
grad_update_gate = (grad_output * value_frame_state); *grad_update_gate = (*grad_output * (*value_frame_state));
grad_update_gate -= (grad_output * value_prev_out); *grad_update_gate -= (*grad_output * (*value_prev_out));
grad_prev_out -= (grad_output * value_update_gate); *grad_prev_out -= (*grad_output * (*value_update_gate));
grad_prev_out += grad_output; *grad_prev_out += *grad_output;
grad_frame_state = activation(grad_output * value_update_gate, *grad_frame_state = activation(*grad_output * (*value_update_gate),
value_frame_state, act_input); *value_frame_state, act_input);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &grad_update_gate, __m256 *grad_update_gate,
__m256 &value_frame_state, __m256 *value_frame_state,
__m256 &grad_frame_state, __m256 &value_prev_out, __m256 *grad_frame_state, __m256 *value_prev_out,
__m256 &grad_prev_out, __m256 &grad_output, __m256 *grad_prev_out, __m256 *grad_output,
ActivationType act_input) { ActivationType act_input) {
grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state); *grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
grad_update_gate = _mm256_sub_ps( *grad_update_gate = _mm256_sub_ps(
grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out)); *grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
grad_prev_out = _mm256_add_ps( *grad_prev_out = _mm256_add_ps(
_mm256_sub_ps(grad_prev_out, _mm256_sub_ps(*grad_prev_out,
_mm256_mul_ps(grad_output, value_update_gate)), _mm256_mul_ps(*grad_output, *value_update_gate)),
grad_output); *grad_output);
grad_frame_state = activation(_mm256_mul_ps(grad_output, value_update_gate), *grad_frame_state =
value_frame_state, act_input); activation(_mm256_mul_ps(*grad_output, *value_update_gate),
*value_frame_state, act_input);
} }
#endif #endif
#endif #endif
...@@ -124,32 +125,34 @@ class gru_stateGrad { ...@@ -124,32 +125,34 @@ class gru_stateGrad {
template <typename T> template <typename T>
class gru_resetGrad { class gru_resetGrad {
public: public:
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate, HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T &value_reset_gate, T &grad_reset_gate, T *value_reset_gate, T *grad_reset_gate,
T &value_prev_out, T &grad_prev_out, T *value_prev_out, T *grad_prev_out,
T &grad_reset_output, ActivationType act_gate) { T *grad_reset_output, ActivationType act_gate) {
grad_reset_gate = (grad_reset_output * value_prev_out); *grad_reset_gate = (*grad_reset_output * (*value_prev_out));
grad_prev_out += (grad_reset_output * value_reset_gate); *grad_prev_out += (*grad_reset_output * (*value_reset_gate));
grad_update_gate = *grad_update_gate =
activation(grad_update_gate, value_update_gate, act_gate); activation(*grad_update_gate, *value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate); *grad_reset_gate =
activation(*grad_reset_gate, *value_reset_gate, act_gate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
static const bool avx = false; static const bool avx = false;
#else #else
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate, HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 &grad_update_gate, __m256 &value_reset_gate, __m256 *grad_update_gate, __m256 *value_reset_gate,
__m256 &grad_reset_gate, __m256 &value_prev_out, __m256 *grad_reset_gate, __m256 *value_prev_out,
__m256 &grad_prev_out, __m256 &grad_reset_output, __m256 *grad_prev_out, __m256 *grad_reset_output,
ActivationType act_gate) { ActivationType act_gate) {
grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out); *grad_reset_gate = _mm256_mul_ps(*grad_reset_output, *value_prev_out);
grad_prev_out = _mm256_add_ps( *grad_prev_out = _mm256_add_ps(
grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate)); *grad_prev_out, _mm256_mul_ps(*grad_reset_output, *value_reset_gate));
grad_update_gate = *grad_update_gate =
activation(grad_update_gate, value_update_gate, act_gate); activation(*grad_update_gate, *value_update_gate, act_gate);
grad_reset_gate = activation(grad_reset_gate, value_reset_gate, act_gate); *grad_reset_gate =
activation(*grad_reset_gate, *value_reset_gate, act_gate);
} }
#endif #endif
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册