提交 1d85b2bd 编写于 作者: G guosheng

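Refine the GRU operator to select activations via activation_mode_t and the activation() helpers in activation_functions.h (replacing the hppl activation functors in hl_activation_functions.h)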

上级 4b8bcf32
......@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
#include <type_traits>
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
namespace paddle {
......@@ -43,9 +43,8 @@ void hl_naive_gru_forward_reset_output(OpResetOutput opResetOutput,
rPrevOut = prevOutputValue[i];
}
hppl::cpu::ForwardAct<T> act;
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, act(active_gate));
rValueResetOutput, active_gate);
updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate;
......@@ -72,9 +71,8 @@ void hl_naive_gru_forward_final_output(OpFinalOutput opFinalOutput,
rPrevOut = prevOutputValue[i];
}
hppl::cpu::ForwardAct<T> act;
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
act(active_node));
active_node);
frameState[i] = rValueFrameState;
outputValue[i] = rOutput;
......@@ -102,7 +100,7 @@ void hl_avx_gru_forward_reset_output(OpResetOutput opResetOutput, T *gateValue,
}
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut,
rValueResetOutput, hppl::avx::forward[active_gate]);
rValueResetOutput, active_gate);
updateGate[i] = rValueUpdateGate;
resetGate[i] = rValueResetGate;
......@@ -132,7 +130,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput opFinalOutput, T *gateValue,
}
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
hppl::avx::forward[active_node]);
active_node);
frameState[i] = rValueFrameState;
((__m256 *)outputValue)[i] = rOutput;
......@@ -215,10 +213,9 @@ void hl_naive_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
rPrevOutGrad = prevOutGrad[i];
}
hppl::cpu::BackwardAct<T> act;
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
act(active_node));
active_node);
updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad;
......@@ -261,10 +258,9 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
rPrevOutGrad = prevOutGrad[i];
}
hppl::cpu::BackwardAct<T> act;
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
act(active_gate));
active_gate);
updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad;
......@@ -306,7 +302,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad opStateGrad, T *gateValue,
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
hppl::avx::backward[active_node]);
active_node);
updateGateGrad[i] = rUpdateGateGrad;
frameStateGrad[i] = rFrameStateGrad;
......@@ -353,7 +349,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad opResetGrad, T *gateValue,
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
hppl::avx::backward[active_gate]);
active_gate);
updateGateGrad[i] = rUpdateGateGrad;
resetGateGrad[i] = rResetGateGrad;
......
......@@ -57,9 +57,8 @@ __global__ void KeGruForwardResetOutput(OpResetOutput opResetOutput,
rPrevOut = prevOutputValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
opResetOutput(rValueUpdateGate, rValueResetGate, rPrevOut, rValueResetOutput,
act(active_gate));
active_gate);
gateValue[frameIdx + frameSize * 0] = rValueUpdateGate;
gateValue[frameIdx + frameSize * 1] = rValueResetGate;
......@@ -96,9 +95,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput opFinalOutput,
rPrevOut = prevOutputValue[frameIdx];
}
hppl::gpu::ForwardAct<T> act;
opFinalOutput(rValueUpdateGate, rValueFrameState, rPrevOut, rOutput,
act(active_node));
active_node);
gateValue[frameIdx + frameSize * 2] = rValueFrameState;
outputValue[frameIdx] = rOutput;
......@@ -141,10 +139,9 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad opStateGrad, T *gateValue,
rPrevOutGrad = prevOutGrad[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
opStateGrad(rUpdateGateValue, rUpdateGateGrad, rFrameStateValue,
rFrameStateGrad, rPrevOutValue, rPrevOutGrad, rOutGrad,
act(active_node));
active_node);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 2] = rFrameStateGrad;
......@@ -190,10 +187,9 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad opResetGrad, T *gateValue,
rResetOutputGrad = resetOutputGrad[frameIdx];
}
hppl::gpu::BackwardAct<T> act;
opResetGrad(rUpdateGateValue, rUpdateGateGrad, rResetGateValue,
rResetGateGrad, rPrevOutValue, rPrevOutGrad, rResetOutputGrad,
act(active_gate));
active_gate);
gateGrad[frameIdx + frameSize * 0] = rUpdateGateGrad;
gateGrad[frameIdx + frameSize * 1] = rResetGateGrad;
......
......@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/detail/hl_activation_functions.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/hostdevice.h"
#include <type_traits>
......@@ -27,18 +27,10 @@ namespace forward {
template <typename T>
class gru_resetOutput {
public:
/**
* @param[in,out] valueUpdateGate update gate
* @param[in,out] valueResetGate reset gate
* @param[in] prevOut previous output
* @param[out] valueResetOutput intermediate value for frame state
* @param[in] actGate forward function of gate
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueResetGate, T &prevOut,
T &valueResetOutput,
typename hppl::Active<T>::forward actGate) {
valueUpdateGate = actGate(valueUpdateGate);
valueResetGate = actGate(valueResetGate);
T &valueResetOutput, activation_mode_t actGate) {
valueUpdateGate = activation(valueUpdateGate, actGate);
valueResetGate = activation(valueResetGate, actGate);
valueResetOutput = prevOut * valueResetGate;
}
#ifndef __NVCC__
......@@ -48,9 +40,9 @@ class gru_resetOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueResetGate,
__m256 &prevOut, __m256 &valueResetOutput,
typename hppl::Active<__m256>::forward actGate) {
valueUpdateGate = actGate(valueUpdateGate);
valueResetGate = actGate(valueResetGate);
activation_mode_t actGate) {
valueUpdateGate = activation(valueUpdateGate, actGate);
valueResetGate = activation(valueResetGate, actGate);
valueResetOutput = _mm256_mul_ps(prevOut, valueResetGate);
}
#endif
......@@ -60,17 +52,9 @@ class gru_resetOutput {
template <typename T>
class gru_finalOutput {
public:
/**
* @param[in] valueUpdateGate update gate
* @param[in,out] valueFrameState frame state ({\tilde{h}_t})
* @param[in] prevOut previous output
* @param[out] valueOutput output
* @param[in] actInput forward function of node
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &valueFrameState, T &prevOut,
T &valueOutput,
typename hppl::Active<T>::forward actInput) {
valueFrameState = actInput(valueFrameState);
T &valueOutput, activation_mode_t actInput) {
valueFrameState = activation(valueFrameState, actInput);
valueOutput = prevOut - (valueUpdateGate * prevOut) +
(valueUpdateGate * valueFrameState);
}
......@@ -81,8 +65,8 @@ class gru_finalOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &valueFrameState,
__m256 &prevOut, __m256 &valueOutput,
typename hppl::Active<__m256>::forward actInput) {
valueFrameState = actInput(valueFrameState);
activation_mode_t actInput) {
valueFrameState = activation(valueFrameState, actInput);
valueOutput = _mm256_add_ps(
_mm256_sub_ps(prevOut, _mm256_mul_ps(valueUpdateGate, prevOut)),
_mm256_mul_ps(valueUpdateGate, valueFrameState));
......@@ -97,25 +81,16 @@ namespace backward {
template <typename T>
class gru_stateGrad {
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[out] gradUpdateGate update gate grad
* @param[in] valueFrameState frame state value
* @param[out] gradFrameState frame state grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradOutput output grad
* @param[in] actInput backward function of frame state
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueFrameState, T &gradFrameState,
T &valuePrevOut, T &gradPrevOut, T &gradOutput,
typename hppl::Active<T>::backward actInput) {
activation_mode_t actInput) {
gradUpdateGate = (gradOutput * valueFrameState);
gradUpdateGate -= (gradOutput * valuePrevOut);
gradPrevOut -= (gradOutput * valueUpdateGate);
gradPrevOut += gradOutput;
gradFrameState = actInput(gradOutput * valueUpdateGate, valueFrameState);
gradFrameState =
activation(gradOutput * valueUpdateGate, valueFrameState, actInput);
}
#ifndef __NVCC__
#ifndef __AVX__
......@@ -125,16 +100,15 @@ class gru_stateGrad {
HOSTDEVICE void operator()(__m256 &valueUpdateGate, __m256 &gradUpdateGate,
__m256 &valueFrameState, __m256 &gradFrameState,
__m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradOutput,
typename hppl::Active<__m256>::backward actInput) {
__m256 &gradOutput, activation_mode_t actInput) {
gradUpdateGate = _mm256_mul_ps(gradOutput, valueFrameState);
gradUpdateGate =
_mm256_sub_ps(gradUpdateGate, _mm256_mul_ps(gradOutput, valuePrevOut));
gradPrevOut = _mm256_add_ps(
_mm256_sub_ps(gradPrevOut, _mm256_mul_ps(gradOutput, valueUpdateGate)),
gradOutput);
gradFrameState =
actInput(_mm256_mul_ps(gradOutput, valueUpdateGate), valueFrameState);
gradFrameState = activation(_mm256_mul_ps(gradOutput, valueUpdateGate),
valueFrameState, actInput);
}
#endif
#endif
......@@ -143,25 +117,14 @@ class gru_stateGrad {
template <typename T>
class gru_resetGrad {
public:
/**
* @param[in] valueUpdateGate update gate value
* @param[in,out] gradUpdateGate update gate grad
* @param[in] valueResetGate reset gate value
* @param[out] gradResetGate reset gate grad
* @param[in] valuePrevOut previous output value
* @param[in,out] gradPrevOut previous output grad
* @param[in] gradResetOutput reset output grad (temp val)
* @param[in] actGate backward function of gate
*/
HOSTDEVICE void operator()(T &valueUpdateGate, T &gradUpdateGate,
T &valueResetGate, T &gradResetGate,
T &valuePrevOut, T &gradPrevOut,
T &gradResetOutput,
typename hppl::Active<T>::backward actGate) {
T &gradResetOutput, activation_mode_t actGate) {
gradResetGate = (gradResetOutput * valuePrevOut);
gradPrevOut += (gradResetOutput * valueResetGate);
gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate);
gradResetGate = actGate(gradResetGate, valueResetGate);
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
gradResetGate = activation(gradResetGate, valueResetGate, actGate);
}
#ifndef __NVCC__
#ifndef __AVX__
......@@ -172,12 +135,12 @@ class gru_resetGrad {
__m256 &valueResetGate, __m256 &gradResetGate,
__m256 &valuePrevOut, __m256 &gradPrevOut,
__m256 &gradResetOutput,
typename hppl::Active<__m256>::backward actGate) {
activation_mode_t actGate) {
gradResetGate = _mm256_mul_ps(gradResetOutput, valuePrevOut);
gradPrevOut = _mm256_add_ps(gradPrevOut,
_mm256_mul_ps(gradResetOutput, valueResetGate));
gradUpdateGate = actGate(gradUpdateGate, valueUpdateGate);
gradResetGate = actGate(gradResetGate, valueResetGate);
gradUpdateGate = activation(gradUpdateGate, valueUpdateGate, actGate);
gradResetGate = activation(gradResetGate, valueResetGate, actGate);
}
#endif
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册