提交 1d85b2bd 编写于 作者: G guosheng

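Refine the GRU operator to use activation_functions.h: pass an activation_mode_t to the op functors instead of hppl forward/backward activation function objects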

上级 4b8bcf32
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <type_traits> #include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h" #include "paddle/operators/math/gru_compute.h"
namespace paddle { namespace paddle {
...@@ -43,9 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput, ...@@ -43,9 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
rPrevOut = prevOutputValue[i]; rPrevOut = prevOutputValue[i];
} }
hppl::cpu::ForwardAct<T> act;
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, act(active_gate)); rValueResetOutput, active_gate);
updateGate[i] = rValueUpdateGate; updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate; resetGate[i] = rValueResetGate;
...@@ -72,9 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput, ...@@ -72,9 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
rPrevOut = prevOutputValue[i]; rPrevOut = prevOutputValue[i];
} }
hppl::cpu::ForwardAct<T> act;
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
act(active_node)); active_node);
frameState[i] = rValueFrameState; frameState[i] = rValueFrameState;
outputValue[i] = rOutput; outputValue[i] = rOutput;
...@@ -102,7 +100,7 @@ void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue, ...@@ -102,7 +100,7 @@ void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
} }
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, hppl::avx::forward[active_gate]); rValueResetOutput, active_gate);
updateGate[i] = rValueUpdateGate; updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate; resetGate[i] = rValueResetGate;
...@@ -132,7 +130,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue, ...@@ -132,7 +130,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
} }
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
hppl::avx::forward[active_node]); active_node);
frameState[i] = rValueFrameState; frameState[i] = rValueFrameState;
((__m256 *)outputValue)[i] = rOutput; ((__m256 *)outputValue)[i] = rOutput;
...@@ -215,10 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, ...@@ -215,10 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
rPrevOutGrad = prevOutGrad[i]; rPrevOutGrad = prevOutGrad[i];
} }
hppl::cpu::BackwardAct<T> act;
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
act(active_node)); active_node);
updateGateGrad[i] = rUpdateGateGrad; updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad; frameStateGrad[i] = rFrameStateGrad;
...@@ -261,10 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, ...@@ -261,10 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
rPrevOutGrad = prevOutGrad[i]; rPrevOutGrad = prevOutGrad[i];
} }
hppl::cpu::BackwardAct<T> act;
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
act(active_gate)); active_gate);
updateGateGrad[i] = rUpdateGateGrad; updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad; resetGateGrad[i] = rResetGateGrad;
...@@ -306,7 +302,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue, ...@@ -306,7 +302,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
hppl::avx::backward[active_node]); active_node);
updateGateGrad[i] = rUpdateGateGrad; updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad; frameStateGrad[i] = rFrameStateGrad;
...@@ -353,7 +349,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue, ...@@ -353,7 +349,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
hppl::avx::backward[active_gate]); active_gate);
updateGateGrad[i] = rUpdateGateGrad; updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad; resetGateGrad[i] = rResetGateGrad;
......
...@@ -57,9 +57,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput, ...@@ -57,9 +57,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
rPrevOut = prevOutputValue[frameIdx]; rPrevOut = prevOutputValue[frameIdx];
} }
hppl::gpu::ForwardAct<T> act;
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput, opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput,
act(active_gate)); active_gate);
gateValue[frameIdx + frameSize * 0] = rValueUpdateGate; gateValue[frameIdx + frameSize * 0] = rValueUpdateGate;
gateValue[frameIdx + frameSize * 1] = rValueResetGate; gateValue[frameIdx + frameSize * 1] = rValueResetGate;
...@@ -96,9 +95,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput, ...@@ -96,9 +95,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
rPrevOut = prevOutputValue[frameIdx]; rPrevOut = prevOutputValue[frameIdx];
} }
hppl::gpu::ForwardAct<T> act;
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput, opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
act(active_node)); active_node);
gateValue[frameIdx + frameSize * 2] = rValueFrameState; gateValue[frameIdx + frameSize * 2] = rValueFrameState;
outputValue[frameIdx] = rOutput; outputValue[frameIdx] = rOutput;
...@@ -141,10 +139,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue, ...@@ -141,10 +139,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
rPrevOutGrad = prevOutGrad[frameIdx]; rPrevOutGrad = prevOutGrad[frameIdx];
} }
hppl::gpu::BackwardAct<T> act;
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue, opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad, rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
act(active_node)); active_node);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad; gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad;
...@@ -190,10 +187,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue, ...@@ -190,10 +187,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
rResetOutputGrad = resetOutputGrad[frameIdx]; rResetOutputGrad = resetOutputGrad[frameIdx];
} }
hppl::gpu::BackwardAct<T> act;
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue, opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad, rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
act(active_gate)); active_gate);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad; gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 1] = rResetGateGrad; gateGrad[frameIdx + frameSize * 1] = rResetGateGrad;
......
...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h" #include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h" #include "paddle/platform/hostdevice.h"
#include <type_traits> #include <type_traits>
...@@ -27,18 +27,10 @@ namespace forward { ...@@ -27,18 +27,10 @@ namespace forward {
template <typename T> template <typename T>
class gru_resetOutput { class gru_resetOutput {
public: public:
/**
* @param[in,out] valueUpdateGate update gate
* @param[in,out] valueResetGate reset gate
* @param[in] prevOut previous output
* @param[out] valueResetOutput intermediate value for frame state
* @param[in] actGate forward function of gate
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut, HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut,
T &valueResetOutput, T &valueResetOutput, activation_mode_t actGate) {
typename hppl::Active<T>::forward actGate) { valueUpdateGate = activation(valueUpdateGate, actGate);
valueUpdateGate = actGate(valueUpdateGate); valueResetGate = activation(valueResetGate, actGate);
valueResetGate = actGate(valueResetGate);
valueResetOutput = prevOut * valueResetGate; valueResetOutput = prevOut * valueResetGate;
} }
#ifndef __NVCC__ #ifndef __NVCC__
...@@ -48,9 +40,9 @@ class gru_resetOutput { ...@@ -48,9 +40,9 @@ class gru_resetOutput {
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate, HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate,
__m256 &prevOut, __m256 &valueResetOutput, __m256 &prevOut, __m256 &valueResetOutput,
typename hppl::Active<__m256>::forward actGate) { activation_mode_t actGate) {
valueUpdateGate = actGate(valueUpdateGate); valueUpdateGate = activation(valueUpdateGate, actGate);
valueResetGate = actGate(valueResetGate); valueResetGate = activation(valueResetGate, actGate);
valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate); valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate);
} }
#endif #endif
...@@ -60,17 +52,9 @@ class gru_resetOutput { ...@@ -60,17 +52,9 @@ class gru_resetOutput {
template <typename T> template <typename T>
class gru_finalOutput { class gru_finalOutput {
public: public:
/**
* @param[in] valueUpdateGate update gate
* @param[in,out] valueFrameState frame state ({\tilde{h}_t})
* @param[in] prevOut previous output
* @param[out] valueOutput output
* @param[in] actInput forward function of node
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut, HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut,
T &valueOutput, T &valueOutput, activation_mode_t actInput) {
typename hppl::Active<T>::forward actInput) { valueFrameState = activation(valueFrameState, actInput);
valueFrameState = actInput(valueFrameState);
valueOutput = prevOut - (valueUpdateGate * prevOut) + valueOutput = prevOut - (valueUpdateGate * prevOut) +
(valueUpdateGate * valueFrameState); (valueUpdateGate * valueFrameState);
} }
...@@ -81,8 +65,8 @@ class gru_finalOutput { ...@@ -81,8 +65,8 @@ class gru_finalOutput {
static const bool avx = true; static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState, HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState,
__m256 &prevOut, __m256 &valueOutput, __m256 &prevOut, __m256 &valueOutput,
typename hppl::Active<__m256>::forward actInput) { activation_mode_t actInput) {
valueFrameState = actInput(valueFrameState); valueFrameState = activation(valueFrameState, actInput);
valueOutput = _mm256_add_ps( valueOutput = _mm256_add_ps(
_mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)), _mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)),
_mm256_mul_ps(valueUpdateGate, valueFrameState)); _mm256_mul_ps(valueUpdateGate, valueFrameState));
...@@ -97,25 +81,16 @@ namespace backward { ...@@ -97,25 +81,16 @@ namespace backward {
template <typename T> template <typename T>
class gru_stateGrad { class gru_stateGrad {
public: public:
/**
* @param[in] valueUpdateGate update gate value
* @param[out] gradUpdateGate update gate grad
* @param[in] valueFrameState frame state value
* @param[out] gradFrameState frame state grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradOutput output grad
* @param[in] actInput backward function of frame state
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueFrameState, T &gradFrameState, T &valueFrameState, T &gradFrameState,
T &valuePrevOut, T &gradPrevOut, T &gradOutput, T &valuePrevOut, T &gradPrevOut, T &gradOutput,
typename hppl::Active<T>::backward actInput) { activation_mode_t actInput) {
gradUpdateGate = (gradOutput * valueFrameState); gradUpdateGate = (gradOutput * valueFrameState);
gradUpdateGate -= (gradOutput * valuePrevOut); gradUpdateGate -= (gradOutput * valuePrevOut);
gradPrevOut -= (gradOutput * valueUpdateGate); gradPrevOut -= (gradOutput * valueUpdateGate);
gradPrevOut += gradOutput; gradPrevOut += gradOutput;
gradFrameState = actInput(gradOutput * valueUpdateGate, valueFrameState); gradFrameState =
activation(gradOutput * valueUpdateGate, valueFrameState, actInput);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
...@@ -125,16 +100,15 @@ class gru_stateGrad { ...@@ -125,16 +100,15 @@ class gru_stateGrad {
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate, HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
__m256 &valueFrameState, __m256 &gradFrameState, __m256 &valueFrameState, __m256 &gradFrameState,
__m256 &valuePrevOut, __m256 &gradPrevOut, __m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradOutput, __m256 &gradOutput, activation_mode_t actInput) {
typename hppl::Active<__m256>::backward actInput) {
gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState); gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState);
gradUpdateGate = gradUpdateGate =
_mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut)); _mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut));
gradPrevOut = _mm256_add_ps( gradPrevOut = _mm256_add_ps(
_mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)), _mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)),
gradOutput); gradOutput);
gradFrameState = gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate),
actInput(_mm256_mul_ps(gradOutput, valueUpdateGate), valueFrameState); valueFrameState, actInput);
} }
#endif #endif
#endif #endif
...@@ -143,25 +117,14 @@ class gru_stateGrad { ...@@ -143,25 +117,14 @@ class gru_stateGrad {
template <typename T> template <typename T>
class gru_resetGrad { class gru_resetGrad {
public: public:
/**
* @param[in] valueUpdateGate update gate value
* @param[in,out] gradUpdateGate update gate grad
* @param[in] valueResetGate reset gate value
* @param[out] gradResetGate reset gate grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradResetOutput reset output grad (temp val)
* @param[in] actGate backward function of gate
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate, HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueResetGate, T &gradResetGate, T &valueResetGate, T &gradResetGate,
T &valuePrevOut, T &gradPrevOut, T &valuePrevOut, T &gradPrevOut,
T &gradResetOutput, T &gradResetOutput, activation_mode_t actGate) {
typename hppl::Active<T>::backward actGate) {
gradResetGate = (gradResetOutput * valuePrevOut); gradResetGate = (gradResetOutput * valuePrevOut);
gradPrevOut += (gradResetOutput * valueResetGate); gradPrevOut += (gradResetOutput * valueResetGate);
gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate); gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
gradResetGate = actGate(gradResetGate, valueResetGate); gradResetGate = activation(gradResetGate, valueResetGate, actGate);
} }
#ifndef __NVCC__ #ifndef __NVCC__
#ifndef __AVX__ #ifndef __AVX__
...@@ -172,12 +135,12 @@ class gru_resetGrad { ...@@ -172,12 +135,12 @@ class gru_resetGrad {
__m256 &valueResetGate, __m256 &gradResetGate, __m256 &valueResetGate, __m256 &gradResetGate,
__m256 &valuePrevOut, __m256 &gradPrevOut, __m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradResetOutput, __m256 &gradResetOutput,
typename hppl::Active<__m256>::backward actGate) { activation_mode_t actGate) {
gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut); gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut);
gradPrevOut = _mm256_add_ps(gradPrevOut, gradPrevOut = _mm256_add_ps(gradPrevOut,
_mm256_mul_ps(gradResetOutput, valueResetGate)); _mm256_mul_ps(gradResetOutput, valueResetGate));
gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate); gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
gradResetGate = actGate(gradResetGate, valueResetGate); gradResetGate = activation(gradResetGate, valueResetGate, actGate);
} }
#endif #endif
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册