未验证 提交 adba4384 编写于 作者: 乔龙飞 Qiao Longfei 提交者: GitHub

Merge pull request #15161 from jacquesqiao/gru-add-mode

gru add origin mode
......@@ -70,8 +70,8 @@ paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid'))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
......
......@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC(
GRU Operator implements part calculations of the complete GRU as following:
......@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
......@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
cur_batch_size, active_node, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......
......@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
......@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......
......@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
......@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......
......@@ -111,6 +111,13 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"The activation type used in update gate and reset gate.")
.SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu});
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article <Learning Phrase Representations "
"using RNN Encoder–Decoder\n"
"for Statistical Machine "
"Translation>(https://arxiv.org/pdf/1406.1078.pdf)")
.SetDefault(false);
AddComment(R"DOC(
GRUUnit Operator implements partial calculations of the GRU unit as following:
......
......@@ -113,7 +113,11 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output
h.device(place) = u * (c - h_p) + h_p;
if (context.Attr<bool>("origin_mode")) {
h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p
} else {
h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p
}
}
};
......@@ -180,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate
// backward for unactivated update gate
if (context.Attr<bool>("origin_mode")) {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (1 - u));
} else {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * u);
}
// backward for reset_hidden_prev
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
......@@ -213,7 +225,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
T* hidden_prev_grad_data =
hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
if (context.Attr<bool>("origin_mode")) {
d_h_p.device(place) = d_r_h_p * r + d_h * u;
} else {
d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u);
}
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
......
......@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
T r_value_update_gate;
T r_value_frame_state;
T r_prev_out = 0;
......@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
frame_state[i] = r_value_frame_state;
output_value[i] = r_output;
......@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
......@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state);
......@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
if (rest > 0) {
i = n - block;
op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
&r_prev_out_last, &r_output, active_node);
&r_prev_out_last, &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state_last);
......@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
int batch_size, ActivationType active_node,
bool origin_mode) {
for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value,
frame_size, active_node);
frame_size, active_node, origin_mode);
} else {
hl_naive_gru_forward_final_output(
op_final_output, value.gate_value, value.prev_out_value,
value.output_value, frame_size, active_node);
value.output_value, frame_size, active_node, origin_mode);
}
value.gate_value += frame_size * 3;
......@@ -253,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
T r_update_gate_value;
T r_update_gate_grad;
T r_frame_state_value;
......@@ -279,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node);
&r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad;
......@@ -338,8 +342,8 @@ template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
ActivationType active_node) {
int frame_size, ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
......@@ -368,7 +372,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node);
&r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad;
......@@ -431,16 +435,18 @@ template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
ActivationType active_node) {
ActivationType active_node, bool origin_mode) {
for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad(
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node);
hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value,
grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
} else {
hl_naive_gru_backward_state_grad(
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node);
hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value,
grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
}
value.gate_value += frame_size * 3;
......
......@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
int batch_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output;
......@@ -109,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size, int batch_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -139,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
&r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
&r_out_grad, active_node);
&r_out_grad, active_node, origin_mode);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
......
......@@ -57,11 +57,17 @@ class gru_finalOutput {
public:
HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T *prev_out, T *value_output,
ActivationType act_input) {
ActivationType act_input, bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input);
if (origin_mode) {
*value_output = ((*value_update_gate) * (*prev_out)) +
*value_frame_state -
((*value_update_gate) * (*value_frame_state));
} else {
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
}
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
......@@ -69,12 +75,21 @@ class gru_finalOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 *value_frame_state, __m256 *prev_out,
__m256 *value_output, ActivationType act_input) {
__m256 *value_output, ActivationType act_input,
bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input);
if (origin_mode) {
*value_output = _mm256_sub_ps(
_mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out),
*value_frame_state),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
} else {
*value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_sub_ps(*prev_out,
_mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
}
}
#endif
#endif
};
......@@ -88,14 +103,24 @@ class gru_stateGrad {
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T *value_frame_state, T *grad_frame_state,
T *value_prev_out, T *grad_prev_out,
T *grad_output, ActivationType act_input) {
*grad_update_gate = (*grad_output * (*value_frame_state));
*grad_update_gate -= (*grad_output * (*value_prev_out));
*grad_prev_out -= (*grad_output * (*value_update_gate));
*grad_prev_out += *grad_output;
T *grad_output, ActivationType act_input,
bool origin_mode) {
if (origin_mode) {
*grad_update_gate =
(*grad_output) * ((*value_prev_out) - (*value_frame_state));
*grad_prev_out += (*grad_output * (*value_update_gate));
*grad_frame_state = activation(
*grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate =
(*grad_output) * ((*value_frame_state) - (*value_prev_out));
*grad_prev_out +=
(*grad_output * (static_cast<T>(1.0) - *value_update_gate));
*grad_frame_state = activation(*grad_output * (*value_update_gate),
*value_frame_state, act_input);
}
}
#ifndef __NVCC__
#ifndef __AVX__
static const bool avx = false;
......@@ -106,18 +131,28 @@ class gru_stateGrad {
__m256 *value_frame_state,
__m256 *grad_frame_state, __m256 *value_prev_out,
__m256 *grad_prev_out, __m256 *grad_output,
ActivationType act_input) {
*grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
*grad_update_gate = _mm256_sub_ps(
*grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
ActivationType act_input, bool origin_mode) {
if (origin_mode) {
*grad_update_gate = _mm256_mul_ps(
*grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
*grad_prev_out = _mm256_add_ps(
_mm256_sub_ps(*grad_prev_out,
_mm256_mul_ps(*grad_output, *value_update_gate)),
*grad_output);
*grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
*grad_frame_state = activation(
_mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate = _mm256_mul_ps(
*grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
*grad_prev_out = _mm256_add_ps(
*grad_prev_out,
_mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
*value_update_gate)));
*grad_frame_state =
activation(_mm256_mul_ps(*grad_output, *value_update_gate),
*value_frame_state, act_input);
}
}
#endif
#endif
};
......
......@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) {
......@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
}
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node);
frame_size, batch_size, active_node,
origin_mode);
#endif
}
};
......@@ -54,10 +56,12 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node);
grad, frame_size, batch_size, active_node,
origin_mode);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) {
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
......
......@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size,
active_node);
active_node, origin_mode);
} else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* is_batch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size,
active_node);
active_node, origin_mode);
}
}
};
......@@ -91,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -111,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gate_value,
grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
grad.output_grad, frame_size, batch_size, active_node);
grad.output_grad, frame_size, batch_size, active_node, origin_mode);
} else {
detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gate_value,
grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
grad.output_grad, frame_size, batch_size, active_node);
grad.output_grad, frame_size, batch_size, active_node, origin_mode);
}
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
......
......@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
const detail::ActivationType active_gate,
bool origin_mode);
};
template <typename DeviceContext, typename T>
......@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
const detail::ActivationType active_gate,
bool origin_mode);
};
} // namespace math
......
......@@ -864,12 +864,14 @@ def dynamic_gru(input,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None):
h_0=None,
origin_mode=False):
"""
**Gated Recurrent Unit (GRU) Layer**
Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_ .
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_ .
The formula is as follows:
......@@ -883,6 +885,21 @@ def dynamic_gru(input,
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
if origin_mode is True then the equation is from paper
Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \\tilde{h_t}
The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
is the update gate and reset gate activation function and :math:`sigmoid`
is usually used for it. :math:`act_c` is the activation function for
......@@ -980,7 +997,8 @@ def dynamic_gru(input,
attrs={
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'activation': candidate_activation
'activation': candidate_activation,
'origin_mode': origin_mode
})
return hidden
......@@ -991,9 +1009,27 @@ def gru_unit(input,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid'):
gate_activation='sigmoid',
origin_mode=False):
"""
GRU unit layer. The equation of a gru step is:
**GRU unit layer**
if origin_mode is True, then the equation of a gru step is from paper
`Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
......@@ -1002,7 +1038,8 @@ def gru_unit(input,
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot((1-u_t), m_t) + dot(u_t, h_{t-1})
h_t & = dot((1-u_t), h_{t-1}) + dot(u_t, m_t)
The inputs of gru unit includes :math:`z_t`, :math:`h_{t-1}`. In terms
of the equation above, the :math:`z_t` is split into 3 parts -
......
......@@ -31,7 +31,8 @@ def gru(
is_reverse,
act_state,
act_gate,
dtype='float32'):
dtype='float32',
origin_mode=False):
def _seq_to_batch(lod, is_reverse):
idx_in_seq_list = []
seq_lens = lod[0]
......@@ -66,6 +67,9 @@ def gru(
w_c = w.flatten()[D * D * 2:].reshape((D, D))
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
g = np.hstack((u_r, c))
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
return g, r_h_p, h
......@@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.dtype = 'float64'
self.origin_mode = False
self.set_confs()
T = sum(self.lod[0])
......@@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
self.origin_mode)
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias:
......@@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode
}
def test_check_output(self):
......@@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOriginMode(TestGRUOp):
def set_confs(self):
self.origin_mode = True
class TestGRUOp2(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
class TestGRUOp2OriginMode(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
self.origin_mode = True
class TestGRUOpNoInitial(TestGRUOp):
def set_confs(self):
self.with_h0 = False
......@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
self.is_reverse = True
class TestGRUOpReverseOriginMode(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.origin_mode = True
if __name__ == "__main__":
unittest.main()
......@@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest):
GRUActivationType.relu: relu,
}
def set_inputs(self):
def set_inputs(self, origin_mode=False):
batch_size = self.batch_size
frame_size = self.frame_size
self.op_type = 'gru_unit'
......@@ -68,10 +68,11 @@ class TestGRUUnitOp(OpTest):
}
self.attrs = {
'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid
'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
}
def set_outputs(self):
def set_outputs(self, origin_mode=False):
# GRU calculations
batch_size = self.batch_size
frame_size = self.frame_size
......@@ -93,6 +94,9 @@ class TestGRUUnitOp(OpTest):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
self.outputs = {
'Gate': g.astype('float64'),
......@@ -111,8 +115,14 @@ class TestGRUUnitOp(OpTest):
self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
class TestGRUUnitOpOriginMode(TestGRUUnitOp):
def setUp(self):
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self):
def set_inputs(self, origin_mode=False):
batch_size = self.batch_size
frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs()
......@@ -120,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
-0.1, 0.1, (1, frame_size * 3)).astype('float64')
self.attrs = {
'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid
'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
}
def test_check_grad(self):
......@@ -132,5 +143,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
no_grad_set=set('Input'))
class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
def setUp(self):
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册