Commit f74dff97 authored by guosheng

Refine the activation type in the GRU operator and related code

Parent 18311767
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/operators/math/gru_compute.h"
 #include "paddle/operators/math/math_function.h"
 #include "paddle/operators/math/sequence2batch.h"
@@ -70,7 +71,7 @@ class GRUKernel : public framework::OpKernel<T> {
     }
     int frame_size = hidden_dims[1];
-    math::hl_gru_value<T> gru_value;
+    math::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
@@ -102,8 +103,10 @@ class GRUKernel : public framework::OpKernel<T> {
       gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
       math::GRUUnitFunctor<DeviceContext, T>::compute(
           dev_ctx, gru_value, frame_size, cur_batch_size,
-          math::ActiveType(context.Attr<std::string>("activation")),
-          math::ActiveType(context.Attr<std::string>("gate_activation")));
+          math::detail::GetActivationType(
+              context.Attr<std::string>("activation")),
+          math::detail::GetActivationType(
+              context.Attr<std::string>("gate_activation")));
       gru_value.prev_out_value = gru_value.output_value;
     }
@@ -170,12 +173,12 @@ class GRUGradKernel : public framework::OpKernel<T> {
     batch_hidden_grad.set_lod(batch_hidden->lod());
     to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
-    math::hl_gru_value<T> gru_value;
+    math::GRUMetaValue<T> gru_value;
     gru_value.gate_weight = const_cast<T*>(weight_data);
     gru_value.state_weight =
         const_cast<T*>(weight_data + 2 * frame_size * frame_size);
-    math::hl_gru_grad<T> gru_grad;
+    math::GRUMetaGrad<T> gru_grad;
     if (weight_grad) {
       gru_grad.gate_weight_grad =
           weight_grad->mutable_data<T>(context.GetPlace());
@@ -220,8 +223,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
       math::GRUUnitGradFunctor<DeviceContext, T>::compute(
           dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
-          math::ActiveType(context.Attr<std::string>("activation")),
-          math::ActiveType(context.Attr<std::string>("gate_activation")));
+          math::detail::GetActivationType(
+              context.Attr<std::string>("activation")),
+          math::detail::GetActivationType(
+              context.Attr<std::string>("gate_activation")));
     }
     if (input_grad) {
       input_grad->mutable_data<T>(context.GetPlace());
...
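The substantive change in this kernel is the pair of activation arguments: the string attributes "activation" and "gate_activation" are now resolved through math::detail::GetActivationType, which returns the ActivationType enum declared next to the activation functors in the newly included detail/activation_functions.h, replacing the old math::ActiveType / activation_mode_t pair. A minimal sketch of what such a string-to-enum helper looks like (the enumerator names and the error path are illustrative assumptions, not code from this commit):

#include <stdexcept>
#include <string>

// Hypothetical sketch of a GetActivationType-style resolver; the
// enumerators and the error handling are assumed for illustration.
enum ActivationType { kSigmoid = 0, kReLU = 1, kTanh = 2, kIdentity = 3 };

inline ActivationType GetActivationType(const std::string &type) {
  if (type == "sigmoid") return kSigmoid;
  if (type == "relu") return kReLU;
  if (type == "tanh") return kTanh;
  if (type == "identity") return kIdentity;
  throw std::invalid_argument("unsupported activation: " + type);
}

Resolving the string once at the operator boundary means every device kernel below can take a plain enum instead of re-parsing attribute strings.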
@@ -28,7 +28,7 @@ template <class OpResetOutput, typename T>
 void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
                                        T *gate_value, T *reset_output_value,
                                        T *prev_output_value, int frame_size,
-                                       activation_mode_t active_gate) {
+                                       ActivationType active_gate) {
   T r_value_update_gate;
   T r_value_reset_gate;
   T r_value_reset_output;
@@ -56,7 +56,7 @@ template <class OpFinalOutput, typename T>
 void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
                                        T *gate_value, T *prev_output_value,
                                        T *output_value, int frame_size,
-                                       activation_mode_t active_node) {
+                                       ActivationType active_node) {
   T r_value_update_gate;
   T r_value_frame_state;
   T r_prev_out = 0;
@@ -83,7 +83,7 @@ template <class OpResetOutput, typename T>
 void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
                                      T *gate_value, T *reset_output_value,
                                      T *prev_output_value, int frame_size,
-                                     activation_mode_t active_gate) {
+                                     ActivationType active_gate) {
 #ifdef __AVX__
   __m256 r_value_update_gate;
   __m256 r_value_reset_gate;
@@ -113,7 +113,7 @@ template <class OpFinalOutput, typename T>
 void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
                                      T *gate_value, T *prev_output_value,
                                      T *output_value, int frame_size,
-                                     activation_mode_t active_node) {
+                                     ActivationType active_node) {
 #ifdef __AVX__
   __m256 r_value_update_gate;
   __m256 r_value_frame_state;
@@ -140,9 +140,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
 template <class OpResetOutput, typename T>
 inline void forward_reset_output(OpResetOutput op_reset_output,
-                                 hl_gru_value<T> value, int frame_size,
-                                 int batch_size,
-                                 activation_mode_t active_gate) {
+                                 GRUMetaValue<T> value, int frame_size,
+                                 int batch_size, ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
     if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_forward_reset_output(
@@ -164,9 +163,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
 template <class OpFinalOutput, typename T>
 inline void forward_final_output(OpFinalOutput op_final_output,
-                                 hl_gru_value<T> value, int frame_size,
-                                 int batch_size,
-                                 activation_mode_t active_node) {
+                                 GRUMetaValue<T> value, int frame_size,
+                                 int batch_size, ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
     if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
@@ -191,7 +189,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *output_grad,
                                       int frame_size,
-                                      activation_mode_t active_node) {
+                                      ActivationType active_node) {
   T r_update_gate_value;
   T r_update_gate_grad;
   T r_frame_state_value;
@@ -232,7 +230,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                       T *gate_grad, T *prev_out_value,
                                       T *prev_out_grad, T *reset_output_grad,
                                       int frame_size,
-                                      activation_mode_t active_gate) {
+                                      ActivationType active_gate) {
   T r_update_gate_value;
   T r_update_gate_grad;
   T r_reset_gate_value;
@@ -277,7 +275,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
                                     T *gate_grad, T *prev_out_value,
                                     T *prev_out_grad, T *output_grad,
                                     int frame_size,
-                                    activation_mode_t active_node) {
+                                    ActivationType active_node) {
 #ifdef __AVX__
   __m256 r_update_gate_value;
   __m256 r_update_gate_grad;
@@ -320,7 +318,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
                                     T *gate_grad, T *prev_out_value,
                                     T *prev_out_grad, T *reset_output_grad,
                                     int frame_size,
-                                    activation_mode_t active_gate) {
+                                    ActivationType active_gate) {
 #ifdef __AVX__
   __m256 r_update_gate_value;
   __m256 r_update_gate_grad;
@@ -364,9 +362,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
 template <class OpStateGrad, typename T>
 inline void backward_state_grad(OpStateGrad op_state_grad,
-                                hl_gru_value<T> value, hl_gru_grad<T> grad,
+                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                 int frame_size, int batch_size,
-                                activation_mode_t active_node) {
+                                ActivationType active_node) {
   for (int b = 0; b < batch_size; b++) {
     if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_backward_state_grad(
@@ -393,9 +391,9 @@ inline void backward_state_grad(OpStateGrad op_state_grad,
 template <class OpResetGrad, typename T>
 inline void backward_reset_grad(OpResetGrad op_reset_grad,
-                                hl_gru_value<T> value, hl_gru_grad<T> grad,
+                                GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                                 int frame_size, int batch_size,
-                                activation_mode_t active_gate) {
+                                ActivationType active_gate) {
   for (int b = 0; b < batch_size; b++) {
     if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
       hl_avx_gru_backward_reset_grad(
...
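Every forward_*/backward_* wrapper in this header repeats the same dispatch test before taking the AVX path: OpFoo::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4). Since 8 is a power of two, frame_size & (8 - 1) is the remainder of division by 8, so the test reads: the functor has an AVX overload, the row width is a multiple of the 8 packed floats in one __m256 register, and the element type is a 4-byte float. A standalone restatement (the helper name is mine, for illustration):

#include <cassert>
#include <cstddef>

// Sketch of the AVX dispatch condition used by the wrappers above.
// (frame_size & 7) == 0 is equivalent to frame_size % 8 == 0 for
// non-negative values.
inline bool CanUseAvx(int frame_size, std::size_t elem_size) {
  return (frame_size & (8 - 1)) == 0 && elem_size == 4;
}

int main() {
  assert(CanUseAvx(64, sizeof(float)));    // 64 floats per row: AVX path
  assert(!CanUseAvx(60, sizeof(float)));   // not a multiple of 8: naive path
  assert(!CanUseAvx(64, sizeof(double)));  // doubles: naive path
  return 0;
}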
@@ -19,8 +19,6 @@ limitations under the License. */
 #include "paddle/platform/cuda_helper.h"
 #include "paddle/platform/device_context.h"
-#include <glog/logging.h>
 namespace paddle {
 namespace operators {
 namespace math {
@@ -35,7 +33,7 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
                                         T *gate_value, T *reset_output_value,
                                         T *prev_output_value, int frame_size,
                                         int batch_size,
-                                        activation_mode_t active_gate) {
+                                        ActivationType active_gate) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
@@ -74,7 +72,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
                                         T *gate_value, T *prev_output_value,
                                         T *output_value, int frame_size,
                                         int batch_size,
-                                        activation_mode_t active_node) {
+                                        ActivationType active_node) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -111,7 +109,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *output_grad,
                                        int frame_size, int batch_size,
-                                       activation_mode_t active_node) {
+                                       ActivationType active_node) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
@@ -159,7 +157,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
                                        T *gate_grad, T *prev_out_value,
                                        T *prev_out_grad, T *reset_output_grad,
                                        int frame_size, int batch_size,
-                                       activation_mode_t active_gate) {
+                                       ActivationType active_gate) {
   const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (frame_idx >= frame_size) return;
   int batch_idx = 0;
...
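These CUDA kernels change only in their signatures; the launch logic is untouched. For context, each kernel assigns one thread per gate column and lets the if (frame_idx >= frame_size) return; guard absorb the rounded-up grid. A sketch of that shape calculation (the helper, the 1024-thread cap, and the 2-D grid split are illustrative assumptions, not taken from this diff):

#include <algorithm>

// Illustrative launch-shape helper: enough threads to tile frame_size,
// one grid row per sequence in the batch; the in-kernel bounds guard
// handles the round-up slack in the last block.
struct LaunchShape {
  int threads_per_block;
  int blocks_x;  // tiles frame_size
  int blocks_y;  // one per batch row
};

inline LaunchShape GruLaunchShape(int frame_size, int batch_size) {
  const int threads = std::min(1024, frame_size);
  const int blocks_x = (frame_size + threads - 1) / threads;  // ceil division
  return {threads, blocks_x, batch_size};
}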
@@ -30,7 +30,7 @@ class gru_resetOutput {
  public:
   HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
                              T &prev_out, T &value_reset_output,
-                             activation_mode_t act_gate) {
+                             ActivationType act_gate) {
     value_update_gate = activation(value_update_gate, act_gate);
     value_reset_gate = activation(value_reset_gate, act_gate);
     value_reset_output = prev_out * value_reset_gate;
@@ -43,7 +43,7 @@ class gru_resetOutput {
   HOSTDEVICE void operator()(__m256 &value_update_gate,
                              __m256 &value_reset_gate, __m256 &prev_out,
                              __m256 &value_reset_output,
-                             activation_mode_t act_gate) {
+                             ActivationType act_gate) {
     value_update_gate = activation(value_update_gate, act_gate);
     value_reset_gate = activation(value_reset_gate, act_gate);
     value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
@@ -57,7 +57,7 @@ class gru_finalOutput {
  public:
   HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
                              T &prev_out, T &value_output,
-                             activation_mode_t act_input) {
+                             ActivationType act_input) {
     value_frame_state = activation(value_frame_state, act_input);
     value_output = prev_out - (value_update_gate * prev_out) +
                    (value_update_gate * value_frame_state);
@@ -69,8 +69,7 @@ class gru_finalOutput {
   static const bool avx = true;
   HOSTDEVICE void operator()(__m256 &value_update_gate,
                              __m256 &value_frame_state, __m256 &prev_out,
-                             __m256 &value_output,
-                             activation_mode_t act_input) {
+                             __m256 &value_output, ActivationType act_input) {
     value_frame_state = activation(value_frame_state, act_input);
     value_output = _mm256_add_ps(
         _mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
@@ -89,7 +88,7 @@ class gru_stateGrad {
   HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
                              T &value_frame_state, T &grad_frame_state,
                              T &value_prev_out, T &grad_prev_out,
-                             T &grad_output, activation_mode_t act_input) {
+                             T &grad_output, ActivationType act_input) {
     grad_update_gate = (grad_output * value_frame_state);
     grad_update_gate -= (grad_output * value_prev_out);
     grad_prev_out -= (grad_output * value_update_gate);
@@ -107,7 +106,7 @@ class gru_stateGrad {
                              __m256 &value_frame_state,
                              __m256 &grad_frame_state, __m256 &value_prev_out,
                              __m256 &grad_prev_out, __m256 &grad_output,
-                             activation_mode_t act_input) {
+                             ActivationType act_input) {
     grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
     grad_update_gate = _mm256_sub_ps(
         grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
@@ -128,7 +127,7 @@ class gru_resetGrad {
   HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
                              T &value_reset_gate, T &grad_reset_gate,
                              T &value_prev_out, T &grad_prev_out,
-                             T &grad_reset_output, activation_mode_t act_gate) {
+                             T &grad_reset_output, ActivationType act_gate) {
     grad_reset_gate = (grad_reset_output * value_prev_out);
     grad_prev_out += (grad_reset_output * value_reset_gate);
     grad_update_gate =
@@ -144,7 +143,7 @@ class gru_resetGrad {
                              __m256 &grad_update_gate, __m256 &value_reset_gate,
                              __m256 &grad_reset_gate, __m256 &value_prev_out,
                              __m256 &grad_prev_out, __m256 &grad_reset_output,
-                             activation_mode_t act_gate) {
+                             ActivationType act_gate) {
     grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
     grad_prev_out = _mm256_add_ps(
         grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
...
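Read together, gru_resetOutput and gru_finalOutput are the standard GRU cell split into two elementwise passes: gate the previous state with the reset gate, run the candidate projection on that gated value (a gemm that sits between the two passes in GRUUnitFunctor::compute, not shown in these hunks), then blend previous state and candidate with the update gate. A scalar restatement of the per-element math, with sigmoid and tanh standing in for act_gate and act_node (a sketch, not code from this commit):

#include <cmath>

// Scalar sketch of what gru_resetOutput / gru_finalOutput compute per
// element. c_in is assumed to already contain the candidate projection
// of (reset_gate * h_prev), which the functors themselves do not compute.
inline double Sigmoid(double x) { return 1.0 / (1.0 + std::exp(-x)); }

inline double GruScalarStep(double u_in, double r_in, double c_in,
                            double h_prev, double *reset_output) {
  const double u = Sigmoid(u_in);    // act_gate on the update gate
  const double r = Sigmoid(r_in);    // act_gate on the reset gate
  *reset_output = h_prev * r;        // gru_resetOutput
  const double c = std::tanh(c_in);  // act_node on the candidate
  // gru_finalOutput: h = h_prev - u * h_prev + u * c == (1 - u) * h_prev + u * c
  return h_prev - u * h_prev + u * c;
}

The rearranged form h_prev - u*h_prev + u*c is exactly what the AVX overload evaluates with two _mm256_mul_ps, one _mm256_sub_ps, and one _mm256_add_ps.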
@@ -21,9 +21,9 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext &context,
-                      hl_gru_value<T> value, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      GRUMetaValue<T> value, int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
 #ifndef __NVCC__
     if (value.prev_out_value) {
       math::gemm<platform::CPUDeviceContext, T>(
@@ -51,10 +51,10 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
   static void compute(const platform::CPUDeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate) {
 #ifndef __NVCC__
     detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
                                 grad, frame_size, batch_size, active_node);
...
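On the CPU path the functor first projects the previous output through the weights with gemm (skipped on the first step, when value.prev_out_value is null) and then applies the elementwise detail ops above. The pointer arithmetic visible in gru_op.h (weight_data + 2 * frame_size * frame_size) implies one contiguous weight blob of 3 * frame_size * frame_size elements: the update and reset gate projections first, the candidate (frame state) projection after them. A sketch of that split (GruWeightView is an illustrative name, not a Paddle type):

#include <vector>

// Weight layout implied by the kernel code in this diff:
//   [ gate_weight  : frame_size x (2 * frame_size) ]
//   [ state_weight : frame_size x frame_size       ]
struct GruWeightView {
  const float *gate_weight;   // update + reset gate projections
  const float *state_weight;  // candidate (frame state) projection
};

inline GruWeightView SplitGruWeight(const std::vector<float> &blob,
                                    int frame_size) {
  return {blob.data(), blob.data() + 2 * frame_size * frame_size};
}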
@@ -21,9 +21,8 @@ namespace math {
 template <typename T>
 struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
-                      hl_gru_value<T> value, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      GRUMetaValue<T> value, int frame_size, int batch_size,
+                      ActivationType active_node, ActivationType active_gate) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
@@ -88,10 +87,9 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
 template <typename T>
 struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
   static void compute(const platform::CUDADeviceContext &context,
-                      hl_gru_value<T> value, hl_gru_grad<T> grad,
+                      GRUMetaValue<T> value, GRUMetaGrad<T> grad,
                       int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate) {
+                      ActivationType active_node, ActivationType active_gate) {
     auto stream = context.stream();
     dim3 threads;
     dim3 grid;
...
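Both device versions implement one functor declared in the header shown next; each body lives in its own translation unit (.cc for CPU, .cu for CUDA), so callers such as GRUKernel never name a device directly. The pattern in miniature (illustrative names, not Paddle's):

#include <cstdio>

// Miniature of the specialization pattern: the primary template is only
// declared, and each device supplies its own explicit specialization.
struct CPUPlace {};
struct GPUPlace {};

template <typename Place>
struct UnitFunctor;  // no generic definition: unsupported places fail to compile

template <>
struct UnitFunctor<CPUPlace> {
  static void compute() { std::puts("CPU path: loops, optionally AVX"); }
};

template <>
struct UnitFunctor<GPUPlace> {
  static void compute() { std::puts("GPU path: kernels launched on a stream"); }
};

int main() {
  UnitFunctor<CPUPlace>::compute();  // device chosen by one template argument
  return 0;
}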
@@ -11,7 +11,7 @@ limitations under the License. */
 #pragma once
-#include "paddle/operators/math/lstm_compute.h"
+#include "paddle/operators/math/detail/activation_functions.h"
 #include "paddle/platform/device_context.h"
 #include "paddle/platform/enforce.h"
@@ -19,9 +19,8 @@ namespace paddle {
 namespace operators {
 namespace math {
-// TODO(guosheng): refine code style in gru_compute
 template <typename T>
-struct hl_gru_value {
+struct GRUMetaValue {
   T *gate_weight;
   T *state_weight;
   T *gate_value;
@@ -31,7 +30,7 @@ struct hl_gru_value {
 };
 template <typename T>
-struct hl_gru_grad {
+struct GRUMetaGrad {
   T *gate_weight_grad;
   T *state_weight_grad;
   T *gate_grad;
@@ -42,18 +41,18 @@ struct hl_gru_grad {
 template <typename DeviceContext, typename T>
 struct GRUUnitFunctor {
-  static void compute(const DeviceContext &context, hl_gru_value<T> value,
+  static void compute(const DeviceContext &context, GRUMetaValue<T> value,
                       int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate);
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate);
 };
 template <typename DeviceContext, typename T>
 struct GRUUnitGradFunctor {
-  static void compute(const DeviceContext &context, hl_gru_value<T> value,
-                      hl_gru_grad<T> grad, int frame_size, int batch_size,
-                      activation_mode_t active_node,
-                      activation_mode_t active_gate);
+  static void compute(const DeviceContext &context, GRUMetaValue<T> value,
+                      GRUMetaGrad<T> grad, int frame_size, int batch_size,
+                      const detail::ActivationType active_node,
+                      const detail::ActivationType active_gate);
 };
 }  // namespace math
...
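Condensed from the GRUKernel code at the top of this diff, this is how a caller looks after the rename (tensor plumbing elided; an excerpt, not a self-contained program):

math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
    const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::GRUUnitFunctor<DeviceContext, T>::compute(
    dev_ctx, gru_value, frame_size, cur_batch_size,
    math::detail::GetActivationType(context.Attr<std::string>("activation")),
    math::detail::GetActivationType(
        context.Attr<std::string>("gate_activation")));

The only API-visible differences from before the commit are the struct names (GRUMetaValue / GRUMetaGrad replacing hl_gru_value / hl_gru_grad) and the enum-typed activation arguments.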