提交 f74dff97 编写于 作者: G guosheng

Refine the activation type in the GRU operator related

上级 18311767
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/operators/math/gru_compute.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence2batch.h"
......@@ -70,7 +71,7 @@ class GRUKernel : public framework::OpKernel<T> {
}
int frame_size = hidden_dims[1];
math::hl_gru_value<T> gru_value;
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
......@@ -102,8 +103,10 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
math::detail::GetActivationType(
context.Attr<std::string>("activation")),
math::detail::GetActivationType(
context.Attr<std::string>("gate_activation")));
gru_value.prev_out_value = gru_value.output_value;
}
......@@ -170,12 +173,12 @@ class GRUGradKernel : public framework::OpKernel<T> {
batch_hidden_grad.set_lod(batch_hidden->lod());
to_batch(dev_ctx, *hidden_grad, batch_hidden_grad, false, is_reverse);
math::hl_gru_value<T> gru_value;
math::GRUMetaValue<T> gru_value;
gru_value.gate_weight = const_cast<T*>(weight_data);
gru_value.state_weight =
const_cast<T*>(weight_data + 2 * frame_size * frame_size);
math::hl_gru_grad<T> gru_grad;
math::GRUMetaGrad<T> gru_grad;
if (weight_grad) {
gru_grad.gate_weight_grad =
weight_grad->mutable_data<T>(context.GetPlace());
......@@ -220,8 +223,10 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size,
math::ActiveType(context.Attr<std::string>("activation")),
math::ActiveType(context.Attr<std::string>("gate_activation")));
math::detail::GetActivationType(
context.Attr<std::string>("activation")),
math::detail::GetActivationType(
context.Attr<std::string>("gate_activation")));
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......
......@@ -28,7 +28,7 @@ template <class OpResetOutput, typename T>
void hl_naive_gru_forward_reset_output(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
T r_value_update_gate;
T r_value_reset_gate;
T r_value_reset_output;
......@@ -56,7 +56,7 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
T r_value_update_gate;
T r_value_frame_state;
T r_prev_out = 0;
......@@ -83,7 +83,7 @@ template <class OpResetOutput, typename T>
void hl_avx_gru_forward_reset_output(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
#ifdef __AVX__
__m256 r_value_update_gate;
__m256 r_value_reset_gate;
......@@ -113,7 +113,7 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
#ifdef __AVX__
__m256 r_value_update_gate;
__m256 r_value_frame_state;
......@@ -140,9 +140,8 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
template <class OpResetOutput, typename T>
inline void forward_reset_output(OpResetOutput op_reset_output,
hl_gru_value<T> value, int frame_size,
int batch_size,
activation_mode_t active_gate) {
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_gate) {
for (int b = 0; b < batch_size; b++) {
if (OpResetOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_reset_output(
......@@ -164,9 +163,8 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
hl_gru_value<T> value, int frame_size,
int batch_size,
activation_mode_t active_node) {
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
......@@ -191,7 +189,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
T r_update_gate_value;
T r_update_gate_grad;
T r_frame_state_value;
......@@ -232,7 +230,7 @@ void hl_naive_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
T r_update_gate_value;
T r_update_gate_grad;
T r_reset_gate_value;
......@@ -277,7 +275,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
activation_mode_t active_node) {
ActivationType active_node) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
......@@ -320,7 +318,7 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
......@@ -364,9 +362,9 @@ void hl_avx_gru_backward_reset_grad(OpResetGrad op_reset_grad, T *gate_value,
template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node) {
ActivationType active_node) {
for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad(
......@@ -393,9 +391,9 @@ inline void backward_state_grad(OpStateGrad op_state_grad,
template <class OpResetGrad, typename T>
inline void backward_reset_grad(OpResetGrad op_reset_grad,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
for (int b = 0; b < batch_size; b++) {
if (OpResetGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_reset_grad(
......
......@@ -19,8 +19,6 @@ limitations under the License. */
#include "paddle/platform/cuda_helper.h"
#include "paddle/platform/device_context.h"
#include <glog/logging.h>
namespace paddle {
namespace operators {
namespace math {
......@@ -35,7 +33,7 @@ __global__ void KeGruForwardResetOutput(OpResetOutput op_reset_output,
T *gate_value, T *reset_output_value,
T *prev_output_value, int frame_size,
int batch_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
......@@ -74,7 +72,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
int batch_size,
activation_mode_t active_node) {
ActivationType active_node) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -111,7 +109,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size, int batch_size,
activation_mode_t active_node) {
ActivationType active_node) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -159,7 +157,7 @@ __global__ void KeGruBackwardResetGrad(OpResetGrad op_reset_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *reset_output_grad,
int frame_size, int batch_size,
activation_mode_t active_gate) {
ActivationType active_gate) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......
......@@ -30,7 +30,7 @@ class gru_resetOutput {
public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_reset_gate,
T &prev_out, T &value_reset_output,
activation_mode_t act_gate) {
ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = prev_out * value_reset_gate;
......@@ -43,7 +43,7 @@ class gru_resetOutput {
HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &value_reset_gate, __m256 &prev_out,
__m256 &value_reset_output,
activation_mode_t act_gate) {
ActivationType act_gate) {
value_update_gate = activation(value_update_gate, act_gate);
value_reset_gate = activation(value_reset_gate, act_gate);
value_reset_output = _mm256_mul_ps(prev_out, value_reset_gate);
......@@ -57,7 +57,7 @@ class gru_finalOutput {
public:
HOSTDEVICE void operator()(T &value_update_gate, T &value_frame_state,
T &prev_out, T &value_output,
activation_mode_t act_input) {
ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input);
value_output = prev_out - (value_update_gate * prev_out) +
(value_update_gate * value_frame_state);
......@@ -69,8 +69,7 @@ class gru_finalOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 &value_update_gate,
__m256 &value_frame_state, __m256 &prev_out,
__m256 &value_output,
activation_mode_t act_input) {
__m256 &value_output, ActivationType act_input) {
value_frame_state = activation(value_frame_state, act_input);
value_output = _mm256_add_ps(
_mm256_sub_ps(prev_out, _mm256_mul_ps(value_update_gate, prev_out)),
......@@ -89,7 +88,7 @@ class gru_stateGrad {
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &value_frame_state, T &grad_frame_state,
T &value_prev_out, T &grad_prev_out,
T &grad_output, activation_mode_t act_input) {
T &grad_output, ActivationType act_input) {
grad_update_gate = (grad_output * value_frame_state);
grad_update_gate -= (grad_output * value_prev_out);
grad_prev_out -= (grad_output * value_update_gate);
......@@ -107,7 +106,7 @@ class gru_stateGrad {
__m256 &value_frame_state,
__m256 &grad_frame_state, __m256 &value_prev_out,
__m256 &grad_prev_out, __m256 &grad_output,
activation_mode_t act_input) {
ActivationType act_input) {
grad_update_gate = _mm256_mul_ps(grad_output, value_frame_state);
grad_update_gate = _mm256_sub_ps(
grad_update_gate, _mm256_mul_ps(grad_output, value_prev_out));
......@@ -128,7 +127,7 @@ class gru_resetGrad {
HOSTDEVICE void operator()(T &value_update_gate, T &grad_update_gate,
T &value_reset_gate, T &grad_reset_gate,
T &value_prev_out, T &grad_prev_out,
T &grad_reset_output, activation_mode_t act_gate) {
T &grad_reset_output, ActivationType act_gate) {
grad_reset_gate = (grad_reset_output * value_prev_out);
grad_prev_out += (grad_reset_output * value_reset_gate);
grad_update_gate =
......@@ -144,7 +143,7 @@ class gru_resetGrad {
__m256 &grad_update_gate, __m256 &value_reset_gate,
__m256 &grad_reset_gate, __m256 &value_prev_out,
__m256 &grad_prev_out, __m256 &grad_reset_output,
activation_mode_t act_gate) {
ActivationType act_gate) {
grad_reset_gate = _mm256_mul_ps(grad_reset_output, value_prev_out);
grad_prev_out = _mm256_add_ps(
grad_prev_out, _mm256_mul_ps(grad_reset_output, value_reset_gate));
......
......@@ -21,9 +21,9 @@ namespace math {
template <typename T>
struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
#ifndef __NVCC__
if (value.prev_out_value) {
math::gemm<platform::CPUDeviceContext, T>(
......@@ -51,10 +51,10 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
template <typename T>
struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node);
......
......@@ -21,9 +21,8 @@ namespace math {
template <typename T>
struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
hl_gru_value<T> value, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
GRUMetaValue<T> value, int frame_size, int batch_size,
ActivationType active_node, ActivationType active_gate) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -88,10 +87,9 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
template <typename T>
struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
hl_gru_value<T> value, hl_gru_grad<T> grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate) {
ActivationType active_node, ActivationType active_gate) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......
......@@ -11,7 +11,7 @@ limitations under the License. */
#pragma once
#include "paddle/operators/math/lstm_compute.h"
#include "paddle/operators/math/detail/activation_functions.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
......@@ -19,9 +19,8 @@ namespace paddle {
namespace operators {
namespace math {
// TODO(guosheng): refine code style in gru_compute
template <typename T>
struct hl_gru_value {
struct GRUMetaValue {
T *gate_weight;
T *state_weight;
T *gate_value;
......@@ -31,7 +30,7 @@ struct hl_gru_value {
};
template <typename T>
struct hl_gru_grad {
struct GRUMetaGrad {
T *gate_weight_grad;
T *state_weight_grad;
T *gate_grad;
......@@ -42,18 +41,18 @@ struct hl_gru_grad {
template <typename DeviceContext, typename T>
struct GRUUnitFunctor {
static void compute(const DeviceContext &context, hl_gru_value<T> value,
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate);
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
template <typename DeviceContext, typename T>
struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, hl_gru_value<T> value,
hl_gru_grad<T> grad, int frame_size, int batch_size,
activation_mode_t active_node,
activation_mode_t active_gate);
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
};
} // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册