未验证 提交 adba4384 编写于 作者: 乔龙飞 Qiao Longfei 提交者: GitHub

Merge pull request #15161 from jacquesqiao/gru-add-mode

gru add origin mode
上级 7cd4dd7c 4d15515c
develop 2.0.1-rocm-post Ligoml-patch-1 OliverLPH-patch-1 OliverLPH-patch-2 PaddlePM-patch-1 PaddlePM-patch-2 ZHUI-patch-1 add_default_att add_model_benchmark_ci add_some_yaml_config addfile all_new_design_exec ascendrc ascendrelease cherry_undefined_var compile_windows delete_2.0.1-rocm-post delete_add_default_att delete_all_new_design_exec delete_ascendrc delete_compile_windows delete_delete_addfile delete_disable_iterable_dataset_unittest delete_fix_dataloader_memory_leak delete_fix_imperative_dygraph_error delete_fix_retry_ci delete_fix_undefined_var delete_improve_sccache delete_incubate/lite delete_paddle_tiny_install delete_paralleltest delete_prv-disable-more-cache delete_revert-31068-fix_conv3d_windows delete_revert-31562-mean delete_revert-33630-bug-fix delete_revert-34159-add_npu_bce_logical_dev delete_revert-34910-spinlocks_for_allocator delete_revert-35069-revert-34910-spinlocks_for_allocator delete_revert-36057-dev/read_flags_in_ut dingjiaweiww-patch-1 disable_iterable_dataset_unittest dy2static enable_eager_model_test final_state_gen_python_c final_state_intermediate fix-numpy-issue fix_concat_slice fix_dataloader_memory_leak fix_imperative_dygraph_error fix_npu_ci fix_op_flops fix_retry_ci fix_rnn_docs fix_tensor_type fix_undefined_var fixiscan fixiscan1 fixiscan2 fixiscan3 github/fork/123malin/netifaces github/fork/123malin/tdm_abacus github/fork/AshburnLee/dev_unique github/fork/ForFishes/fix_memory_matmul github/fork/ForFishes/rm_fluid github/fork/LielinJiang/move-2.0-api github/fork/LielinJiang/visual-dl-cb github/fork/LiuChiachi/add-transformer-generate-square-subsequent-mask-api github/fork/LiuChiachi/fix-example-code-for-hapi-Model github/fork/LiuChiachi/remove-input-requirment-in-dygraph-Model github/fork/MrChengmo/fix_ps_profiler github/fork/MrChengmo/update_ps_heter github/fork/PWhiddy/patch-1 github/fork/Shixiaowei02/dev/save_load_upgrade github/fork/TCChenlong/fix_hapi github/fork/TCChenlong/fix_inden github/fork/Thunderbrook/xpu_slice github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_2 github/fork/XieYunshen/disable_ut_test_parallel_executor_fetch_isolated_var_3 github/fork/XieYunshen/timeout_20S_ut github/fork/ZeyuChen/remove-nltk github/fork/arlesniak/arlesniak/selective__mkldnn_flags github/fork/baiyfbupt/code_doc_mig github/fork/chalsliu/set_timeout github/fork/chen-zhiyu/develop github/fork/chenwhql/ci/try_to_find_test_buffer_shared_memory_reuse_pass_error github/fork/chenwhql/dygraph/remove_scale_loss_and_apply_collective_grads github/fork/chenwhql/saveload/add_get_inference_program github/fork/chenwhql/saveload/remove_save_load_config github/fork/cryoco/pass-compatibility-trt github/fork/danleifeng/isempty_api2.0 github/fork/frankwhzhang/api_transfer github/fork/hbwx24/error_msg/cuda_kernel_error_msg github/fork/heavengate/cherry_yolo_box github/fork/heavengate/update_yolo_box github/fork/iclementine/rnn_fix github/fork/iducn/testestse github/fork/jczaja/prv-25537-fix github/fork/jeff41404/release/1.8 github/fork/jiweibo/api_2.0 github/fork/jiweibo/fix_lite_resnet50_test github/fork/juncaipeng/fix_doc_1 github/fork/lfchener/sample_code github/fork/littletomatodonkey/fix_reg_doc github/fork/liym27/dy2stat_update_assign_to_rc20 github/fork/luotao1/profiler_ut github/fork/mapingshuo/add_wait github/fork/mapingshuo/doc_2.0 github/fork/mapingshuo/zero-0.5 github/fork/miraiwk/dev github/fork/pangyoki/add-Categorical-class-branch github/fork/pangyoki/add-multinomial-op-branch github/fork/pangyoki/fix-test_distritbution-CI github/fork/qjing666/doublegrad github/fork/qjing666/fix_hdfs_download github/fork/sandyhouse/add_gather_etc github/fork/sandyhouse/add_send_recv_alltoall_etc github/fork/sandyhouse/pipeline_exe_run github/fork/seiriosPlus/feature/large_scale_kv_save_delta github/fork/seiriosPlus/fix/paddle_errors_fix github/fork/seiriosPlus/fix/paddle_op_errors github/fork/shangzhizhou/fix_test_activation_op_random_bug github/fork/smallv0221/yxp0924 github/fork/smallv0221/yxp0925 github/fork/swtkiwi/del-matplotlib github/fork/tianshuo78520a/kunlun_test github/fork/tianshuo78520a/update_dockerfile github/fork/wanghaoshuang/bert_fuse github/fork/wanghaoshuang/label_smooth github/fork/wanghuancoder/develop_CUDASynchronize github/fork/wanghuancoder/develop_Layer_doc github/fork/wanghuancoder/develop_ParameterList_doc github/fork/wanghuancoder/develop_Sequential_doc github/fork/wanghuancoder/develop_bilinear_tensor_product github/fork/wanghuancoder/develop_coverage_build_sh github/fork/wanghuancoder/develop_in_dynamic_mode_doc github/fork/wanghuancoder/develop_unique_name_doc github/fork/wangxicoding/fleet_meta_combine github/fork/wawltor/error_message_fix_5 github/fork/willthefrog/remove_l2_norm github/fork/windstamp/momentum_op github/fork/windstamp/mv_op_5 github/fork/windstamp/normal_api github/fork/wojtuss/wojtuss/fusion_gru_quantization github/fork/wojtuss/wojtuss/quantization-with-shift github/fork/wzzju/fix_err_info github/fork/wzzju/pure_fp16 github/fork/xiemoyuan/op_error_message github/fork/xiemoyuan/optimize_error_message github/fork/yaoxuefeng6/fix_doc github/fork/yaoxuefeng6/mod_dataset_v2 github/fork/yongqiangma/lod github/fork/ysh329/fix-clip-by-norm-error github/fork/ysh329/fix-error-clip-by-value github/fork/yukavio/error_info github/fork/zhangting2020/conv_filter_grad github/fork/zhangting2020/is_compile_with_cuda github/fork/zhangting2020/place_doc github/fork/zhangting2020/program github/fork/zhhsplendid/fix_any github/fork/zhhsplendid/refine_api2 github/fork/zhhsplendid/refine_api2_test github/fork/zhhsplendid/refine_api_test_ptb_lm github/fork/zhhsplendid/refine_api_test_resnet github/fork/zhhsplendid/refine_api_test_simnet github/fork/zhiqiu/dev/refine_initializer github/fork/zhiqiu/dev/remove_inplace_argument github/fork/zlsh80826/nvinfer_plugin_var_len_cuda11 improve_sccache incubate/infrt incubate/lite inplace_addto make_flag_adding_easier move_embedding_to_phi move_histogram_to_pten move_sgd_to_phi move_slice_to_pten move_temporal_shift_to_phi move_yolo_box_to_phi npu_fix_alloc numel paddle_tiny_install paralleltest preln_ernie prv-disable-more-cache prv-md-even-more prv-onednn-2.5 pten_tensor_refactor release/1.3 release/1.4 release/1.5 release/1.6 release/1.7 release/1.8 release/2.0 release/2.0-alpha release/2.0-beta release/2.0-rc release/2.0-rc1 release/2.1 release/2.2 release/2.3 release/2.3-fc-ernie-fix release/2.4 release/lite-0.1 revert-24981-add_device_attr_for_regulization revert-26856-strategy_example2 revert-27520-disable_pr revert-31068-fix_conv3d_windows revert-31562-mean revert-32290-develop-hardlabel revert-33037-forci revert-33475-fix_cifar_label_dimension revert-33630-bug-fix revert-34159-add_npu_bce_logical_dev revert-34406-add_copy_from_tensor revert-34910-spinlocks_for_allocator revert-35069-revert-34910-spinlocks_for_allocator revert-36057-dev/read_flags_in_ut revert-36201-refine_fast_threaded_ssa_graph_executor revert-36985-add_license revert-37318-refactor_dygraph_to_eager revert-37926-eager_coreops_500 revert-37956-revert-37727-pylayer_support_tuple revert-38100-mingdong revert-38301-allocation_rearrange_pr revert-38703-numpy_bf16_package_reupload revert-38732-remove_useless_header_in_elementwise_mul_grad revert-38959-Reduce_Grad revert-39143-adjust_empty revert-39227-move_trace_op_to_pten revert-39268-dev/remove_concat_fluid_kernel revert-40170-support_partial_grad revert-41056-revert-40727-move_some_activaion_to_phi revert-41065-revert-40993-mv_ele_floordiv_pow revert-41068-revert-40790-phi_new revert-41944-smaller_inference_api_test revert-42149-do-not-reset-default-stream-for-stream-safe-cuda-allocator revert-43155-fix_ut_tempfile revert-43882-revert-41944-smaller_inference_api_test revert-45808-phi/simplify_size_op revert-46827-deform_comment rocm_dev_0217 support_weight_transpose test_benchmark_ci test_feature_precision_test_c test_model_benchmark test_model_benchmark_ci zhiqiu-patch-1 v2.4.0-rc0 v2.3.2 v2.3.1 v2.3.0 v2.3.0-rc0 v2.2.2 v2.2.1 v2.2.0 v2.2.0-rc0 v2.2.0-bak0 v2.1.3 v2.1.2 v2.1.1 v2.1.0 v2.1.0-rc0 v2.0.2 v2.0.1 v2.0.0 v2.0.0-rc1 v2.0.0-rc0 v2.0.0-beta0 v2.0.0-alpha0 v1.8.5 v1.8.4 v1.8.3 v1.8.2 v1.8.1 v1.8.0 v1.7.2 v1.7.1 v1.7.0 v1.6.3 v1.6.2 v1.6.1 v1.6.0 v1.6.0-rc0 v1.5.2 v1.5.1 v1.5.0 v1.4.1 v1.4.0 v1.3.2 v1.3.1 v1.3.0 lite-v0.1
无相关合并请求
......@@ -70,8 +70,8 @@ paddle.fluid.layers.fc ArgSpec(args=['input', 'size', 'num_flatten_dims', 'param
paddle.fluid.layers.embedding ArgSpec(args=['input', 'size', 'is_sparse', 'is_distributed', 'padding_idx', 'param_attr', 'dtype'], varargs=None, keywords=None, defaults=(False, False, None, None, 'float32'))
paddle.fluid.layers.dynamic_lstm ArgSpec(args=['input', 'size', 'h_0', 'c_0', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_lstmp ArgSpec(args=['input', 'size', 'proj_size', 'param_attr', 'bias_attr', 'use_peepholes', 'is_reverse', 'gate_activation', 'cell_activation', 'candidate_activation', 'proj_activation', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, None, True, False, 'sigmoid', 'tanh', 'tanh', 'tanh', 'float32', None))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid'))
paddle.fluid.layers.dynamic_gru ArgSpec(args=['input', 'size', 'param_attr', 'bias_attr', 'is_reverse', 'gate_activation', 'candidate_activation', 'h_0', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, False, 'sigmoid', 'tanh', None, False))
paddle.fluid.layers.gru_unit ArgSpec(args=['input', 'hidden', 'size', 'param_attr', 'bias_attr', 'activation', 'gate_activation', 'origin_mode'], varargs=None, keywords=None, defaults=(None, None, 'tanh', 'sigmoid', False))
paddle.fluid.layers.linear_chain_crf ArgSpec(args=['input', 'label', 'param_attr'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.crf_decoding ArgSpec(args=['input', 'param_attr', 'label'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.cos_sim ArgSpec(args=['X', 'Y'], varargs=None, keywords=None, defaults=None)
......
......@@ -137,6 +137,10 @@ class GRUOpMaker : public framework::OpProtoAndCheckerMaker {
"(bool, defalut: False) "
"whether to compute reversed GRU.")
.SetDefault(false);
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article https://arxiv.org/abs/1412.3555")
.SetDefault(false);
AddComment(R"DOC(
GRU Operator implements part calculations of the complete GRU as following:
......@@ -221,6 +225,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
using DeviceContext = paddle::platform::CPUDeviceContext;
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
......@@ -327,7 +332,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::detail::forward_final_output(
math::detail::forward::gru_finalOutput<T>(), gru_value, frame_size,
cur_batch_size, active_node);
cur_batch_size, active_node, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......@@ -351,7 +356,7 @@ class GRUCPUKernel : public framework::OpKernel<T> {
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......
......@@ -21,6 +21,7 @@ template <typename DeviceContext, typename T>
class GRUKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* input = context.Input<LoDTensor>("Input");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
......@@ -87,7 +88,7 @@ class GRUKernel : public framework::OpKernel<T> {
gru_value.reset_output_value = reset_hidden_prev_t.data<T>();
math::GRUUnitFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
gru_value.prev_out_value = gru_value.output_value;
}
......
......@@ -41,6 +41,7 @@ template <typename DeviceContext, typename T>
class GRUGradKernel : public framework::OpKernel<T> {
public:
void BatchCompute(const framework::ExecutionContext& context) const {
bool origin_mode = context.Attr<bool>("origin_mode");
auto* h0 = context.Input<Tensor>("H0");
auto* weight = context.Input<Tensor>("Weight");
const T* weight_data = weight->data<T>();
......@@ -146,7 +147,7 @@ class GRUGradKernel : public framework::OpKernel<T> {
math::GRUUnitGradFunctor<DeviceContext, T>::compute(
dev_ctx, gru_value, gru_grad, frame_size, cur_batch_size, active_node,
active_gate);
active_gate, origin_mode);
}
if (input_grad) {
input_grad->mutable_data<T>(context.GetPlace());
......
......@@ -111,6 +111,13 @@ class GRUUnitOpMaker : public framework::OpProtoAndCheckerMaker {
"The activation type used in update gate and reset gate.")
.SetDefault(sigmoid)
.InEnum({identity, sigmoid, tanh, relu});
AddAttr<bool>("origin_mode",
"bool"
"use origin mode in article <Learning Phrase Representations "
"using RNN Encoder–Decoder\n"
"for Statistical Machine "
"Translation>(https://arxiv.org/pdf/1406.1078.pdf)")
.SetDefault(false);
AddComment(R"DOC(
GRUUnit Operator implements partial calculations of the GRU unit as following:
......
......@@ -113,7 +113,11 @@ class GRUUnitKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate
// calculate final output
h.device(place) = u * (c - h_p) + h_p;
if (context.Attr<bool>("origin_mode")) {
h.device(place) = c + u * (h_p - c); // (1 - u) * c + u * h_p
} else {
h.device(place) = u * (c - h_p) + h_p; // u * c + (1 - u) * h_p
}
}
};
......@@ -180,11 +184,19 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
auto c = g.slice(c_offsets, extents); // output candidate
// backward for unactivated update gate
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * u);
if (context.Attr<bool>("origin_mode")) {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (h_p - c));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * (1 - u));
} else {
ActGradCompute(context.Attr<int>("gate_activation"), place, u, u,
d_g.slice(u_offsets, extents), d_h * (c - h_p));
// backward for unactivated output candidate
ActGradCompute(context.Attr<int>("activation"), place, c, c,
d_g.slice(c_offsets, extents), d_h * u);
}
// backward for reset_hidden_prev
auto blas = math::GetBlas<DeviceContext, T>(context);
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
......@@ -213,7 +225,11 @@ class GRUUnitGradKernel : public framework::OpKernel<T> {
T* hidden_prev_grad_data =
hidden_prev_grad->mutable_data<T>(context.GetPlace());
auto d_h_p = EigenMatrix<T>::From(*hidden_prev_grad);
d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u);
if (context.Attr<bool>("origin_mode")) {
d_h_p.device(place) = d_r_h_p * r + d_h * u;
} else {
d_h_p.device(place) = d_r_h_p * r + d_h * (1 - u);
}
blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1,
gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1,
hidden_prev_grad_data, frame_size);
......
......@@ -56,7 +56,8 @@ template <class OpFinalOutput, typename T>
void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
T r_value_update_gate;
T r_value_frame_state;
T r_prev_out = 0;
......@@ -72,7 +73,7 @@ void hl_naive_gru_forward_final_output(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
frame_state[i] = r_value_frame_state;
output_value[i] = r_output;
......@@ -146,7 +147,8 @@ template <class OpFinalOutput, typename T>
void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__
__m256 r_value_update_gate, r_value_update_gate_last = _mm256_set1_ps(0.0f);
__m256 r_value_frame_state, r_value_frame_state_last = _mm256_set1_ps(0.0f);
......@@ -180,7 +182,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state);
......@@ -190,7 +192,7 @@ void hl_avx_gru_forward_final_output(OpFinalOutput op_final_output,
if (rest > 0) {
i = n - block;
op_final_output(&r_value_update_gate_last, &r_value_frame_state_last,
&r_prev_out_last, &r_output, active_node);
&r_prev_out_last, &r_output, active_node, origin_mode);
_mm256_storeu_ps(reinterpret_cast<float *>(frame_state + i),
r_value_frame_state_last);
......@@ -227,17 +229,18 @@ inline void forward_reset_output(OpResetOutput op_reset_output,
template <class OpFinalOutput, typename T>
inline void forward_final_output(OpFinalOutput op_final_output,
GRUMetaValue<T> value, int frame_size,
int batch_size, ActivationType active_node) {
int batch_size, ActivationType active_node,
bool origin_mode) {
for (int b = 0; b < batch_size; b++) {
if (OpFinalOutput::avx && (frame_size > static_cast<int>(8 - 1)) &&
(sizeof(T) == 4)) {
hl_avx_gru_forward_final_output(op_final_output, value.gate_value,
value.prev_out_value, value.output_value,
frame_size, active_node);
frame_size, active_node, origin_mode);
} else {
hl_naive_gru_forward_final_output(
op_final_output, value.gate_value, value.prev_out_value,
value.output_value, frame_size, active_node);
value.output_value, frame_size, active_node, origin_mode);
}
value.gate_value += frame_size * 3;
......@@ -253,7 +256,8 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
T r_update_gate_value;
T r_update_gate_grad;
T r_frame_state_value;
......@@ -279,7 +283,7 @@ void hl_naive_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node);
&r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad;
......@@ -338,8 +342,8 @@ template <class OpStateGrad, typename T>
void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size,
ActivationType active_node) {
int frame_size, ActivationType active_node,
bool origin_mode) {
#ifdef __AVX__
__m256 r_update_gate_value;
__m256 r_update_gate_grad;
......@@ -368,7 +372,7 @@ void hl_avx_gru_backward_state_grad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad,
&r_frame_state_value, &r_frame_state_grad, &r_prev_out_value,
&r_prev_out_grad, &r_out_grad, active_node);
&r_prev_out_grad, &r_out_grad, active_node, origin_mode);
update_gate_grad[i] = r_update_gate_grad;
frame_state_grad[i] = r_frame_state_grad;
......@@ -431,16 +435,18 @@ template <class OpStateGrad, typename T>
inline void backward_state_grad(OpStateGrad op_state_grad,
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
ActivationType active_node) {
ActivationType active_node, bool origin_mode) {
for (int b = 0; b < batch_size; b++) {
if (OpStateGrad::avx && !(frame_size & (8 - 1)) && (sizeof(T) == 4)) {
hl_avx_gru_backward_state_grad(
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node);
hl_avx_gru_backward_state_grad(op_state_grad, value.gate_value,
grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
} else {
hl_naive_gru_backward_state_grad(
op_state_grad, value.gate_value, grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad, frame_size, active_node);
hl_naive_gru_backward_state_grad(op_state_grad, value.gate_value,
grad.gate_grad, value.prev_out_value,
grad.prev_out_grad, grad.output_grad,
frame_size, active_node, origin_mode);
}
value.gate_value += frame_size * 3;
......
......@@ -72,7 +72,8 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
T *gate_value, T *prev_output_value,
T *output_value, int frame_size,
int batch_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -94,7 +95,7 @@ __global__ void KeGruForwardFinalOutput(OpFinalOutput op_final_output,
}
op_final_output(&r_value_update_gate, &r_value_frame_state, &r_prev_out,
&r_output, active_node);
&r_output, active_node, origin_mode);
gate_value[frame_idx + frame_size * 2] = r_value_frame_state;
output_value[frame_idx] = r_output;
......@@ -109,7 +110,8 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
T *gate_grad, T *prev_out_value,
T *prev_out_grad, T *output_grad,
int frame_size, int batch_size,
ActivationType active_node) {
ActivationType active_node,
bool origin_mode) {
const int frame_idx = blockIdx.x * blockDim.x + threadIdx.x;
if (frame_idx >= frame_size) return;
int batch_idx = 0;
......@@ -139,7 +141,7 @@ __global__ void KeGruBackwardStateGrad(OpStateGrad op_state_grad, T *gate_value,
op_state_grad(&r_update_gate_value, &r_update_gate_grad, &r_frame_state_value,
&r_frame_state_grad, &r_prev_out_value, &r_prev_out_grad,
&r_out_grad, active_node);
&r_out_grad, active_node, origin_mode);
gate_grad[frame_idx + frame_size * 0] = r_update_gate_grad;
gate_grad[frame_idx + frame_size * 2] = r_frame_state_grad;
......
......@@ -57,10 +57,16 @@ class gru_finalOutput {
public:
HOSTDEVICE void operator()(T *value_update_gate, T *value_frame_state,
T *prev_out, T *value_output,
ActivationType act_input) {
ActivationType act_input, bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input);
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
if (origin_mode) {
*value_output = ((*value_update_gate) * (*prev_out)) +
*value_frame_state -
((*value_update_gate) * (*value_frame_state));
} else {
*value_output = *prev_out - ((*value_update_gate) * (*prev_out)) +
((*value_update_gate) * (*value_frame_state));
}
}
#ifndef __NVCC__
#ifndef __AVX__
......@@ -69,11 +75,20 @@ class gru_finalOutput {
static const bool avx = true;
HOSTDEVICE void operator()(__m256 *value_update_gate,
__m256 *value_frame_state, __m256 *prev_out,
__m256 *value_output, ActivationType act_input) {
__m256 *value_output, ActivationType act_input,
bool origin_mode) {
*value_frame_state = activation(*value_frame_state, act_input);
*value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out, _mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
if (origin_mode) {
*value_output = _mm256_sub_ps(
_mm256_add_ps(_mm256_mul_ps(*value_update_gate, *prev_out),
*value_frame_state),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
} else {
*value_output = _mm256_add_ps(
_mm256_sub_ps(*prev_out,
_mm256_mul_ps(*value_update_gate, *prev_out)),
_mm256_mul_ps(*value_update_gate, *value_frame_state));
}
}
#endif
#endif
......@@ -88,13 +103,23 @@ class gru_stateGrad {
HOSTDEVICE void operator()(T *value_update_gate, T *grad_update_gate,
T *value_frame_state, T *grad_frame_state,
T *value_prev_out, T *grad_prev_out,
T *grad_output, ActivationType act_input) {
*grad_update_gate = (*grad_output * (*value_frame_state));
*grad_update_gate -= (*grad_output * (*value_prev_out));
*grad_prev_out -= (*grad_output * (*value_update_gate));
*grad_prev_out += *grad_output;
*grad_frame_state = activation(*grad_output * (*value_update_gate),
*value_frame_state, act_input);
T *grad_output, ActivationType act_input,
bool origin_mode) {
if (origin_mode) {
*grad_update_gate =
(*grad_output) * ((*value_prev_out) - (*value_frame_state));
*grad_prev_out += (*grad_output * (*value_update_gate));
*grad_frame_state = activation(
*grad_output * (static_cast<T>(1.0) - (*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate =
(*grad_output) * ((*value_frame_state) - (*value_prev_out));
*grad_prev_out +=
(*grad_output * (static_cast<T>(1.0) - *value_update_gate));
*grad_frame_state = activation(*grad_output * (*value_update_gate),
*value_frame_state, act_input);
}
}
#ifndef __NVCC__
#ifndef __AVX__
......@@ -106,17 +131,27 @@ class gru_stateGrad {
__m256 *value_frame_state,
__m256 *grad_frame_state, __m256 *value_prev_out,
__m256 *grad_prev_out, __m256 *grad_output,
ActivationType act_input) {
*grad_update_gate = _mm256_mul_ps(*grad_output, *value_frame_state);
*grad_update_gate = _mm256_sub_ps(
*grad_update_gate, _mm256_mul_ps(*grad_output, *value_prev_out));
*grad_prev_out = _mm256_add_ps(
_mm256_sub_ps(*grad_prev_out,
_mm256_mul_ps(*grad_output, *value_update_gate)),
*grad_output);
*grad_frame_state =
activation(_mm256_mul_ps(*grad_output, *value_update_gate),
*value_frame_state, act_input);
ActivationType act_input, bool origin_mode) {
if (origin_mode) {
*grad_update_gate = _mm256_mul_ps(
*grad_output, _mm256_sub_ps(*value_prev_out, *value_frame_state));
*grad_prev_out = _mm256_add_ps(
*grad_prev_out, _mm256_mul_ps(*grad_output, *value_update_gate));
*grad_frame_state = activation(
_mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
*value_update_gate)),
*value_frame_state, act_input);
} else {
*grad_update_gate = _mm256_mul_ps(
*grad_output, _mm256_sub_ps(*value_frame_state, *value_prev_out));
*grad_prev_out = _mm256_add_ps(
*grad_prev_out,
_mm256_mul_ps(*grad_output, _mm256_sub_ps(_mm256_set1_ps(1.0f),
*value_update_gate)));
*grad_frame_state =
activation(_mm256_mul_ps(*grad_output, *value_update_gate),
*value_frame_state, act_input);
}
}
#endif
#endif
......
......@@ -23,7 +23,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
static void compute(const platform::CPUDeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value) {
......@@ -43,7 +44,8 @@ struct GRUUnitFunctor<platform::CPUDeviceContext, T> {
}
detail::forward_final_output(detail::forward::gru_finalOutput<T>(), value,
frame_size, batch_size, active_node);
frame_size, batch_size, active_node,
origin_mode);
#endif
}
};
......@@ -54,10 +56,12 @@ struct GRUUnitGradFunctor<platform::CPUDeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
#ifndef __NVCC__
detail::backward_state_grad(detail::backward::gru_stateGrad<T>(), value,
grad, frame_size, batch_size, active_node);
grad, frame_size, batch_size, active_node,
origin_mode);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
if (value.prev_out_value && grad.prev_out_grad) {
blas.GEMM(false, true, batch_size, frame_size, frame_size, 1,
......
......@@ -24,7 +24,8 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
static void compute(const platform::CUDADeviceContext &context,
GRUMetaValue<T> value, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -73,14 +74,14 @@ struct GRUUnitFunctor<platform::CUDADeviceContext, T> {
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size,
active_node);
active_node, origin_mode);
} else {
detail::KeGruForwardFinalOutput<detail::forward::gru_finalOutput<T>,
/* is_batch= */ true,
T><<<grid, threads, 0, stream>>>(
detail::forward::gru_finalOutput<T>(), value.gate_value,
value.prev_out_value, value.output_value, frame_size, batch_size,
active_node);
active_node, origin_mode);
}
}
};
......@@ -91,7 +92,8 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
GRUMetaValue<T> value, GRUMetaGrad<T> grad,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate) {
const detail::ActivationType active_gate,
bool origin_mode) {
auto stream = context.stream();
dim3 threads;
dim3 grid;
......@@ -111,14 +113,14 @@ struct GRUUnitGradFunctor<platform::CUDADeviceContext, T> {
/* is_batch= */ false><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gate_value,
grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
grad.output_grad, frame_size, batch_size, active_node);
grad.output_grad, frame_size, batch_size, active_node, origin_mode);
} else {
detail::KeGruBackwardStateGrad<
detail::backward::gru_stateGrad<T>,
/* is_batch= */ true><<<grid, threads, 0, stream>>>(
detail::backward::gru_stateGrad<T>(), value.gate_value,
grad.gate_grad, value.prev_out_value, grad.prev_out_grad,
grad.output_grad, frame_size, batch_size, active_node);
grad.output_grad, frame_size, batch_size, active_node, origin_mode);
}
auto blas = math::GetBlas<platform::CUDADeviceContext, T>(context);
......
......@@ -44,7 +44,8 @@ struct GRUUnitFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
const detail::ActivationType active_gate,
bool origin_mode);
};
template <typename DeviceContext, typename T>
......@@ -52,7 +53,8 @@ struct GRUUnitGradFunctor {
static void compute(const DeviceContext &context, GRUMetaValue<T> value,
GRUMetaGrad<T> grad, int frame_size, int batch_size,
const detail::ActivationType active_node,
const detail::ActivationType active_gate);
const detail::ActivationType active_gate,
bool origin_mode);
};
} // namespace math
......
......@@ -864,12 +864,14 @@ def dynamic_gru(input,
is_reverse=False,
gate_activation='sigmoid',
candidate_activation='tanh',
h_0=None):
h_0=None,
origin_mode=False):
"""
**Gated Recurrent Unit (GRU) Layer**
Refer to `Empirical Evaluation of Gated Recurrent Neural Networks on
Sequence Modeling <https://arxiv.org/abs/1412.3555>`_ .
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_ .
The formula is as follows:
......@@ -883,6 +885,21 @@ def dynamic_gru(input,
h_t & = (1-u_t) \odot h_{t-1} + u_t \odot \\tilde{h_t}
if origin_mode is True then the equation is from paper
Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
u_t & = act_g(W_{ux}x_{t} + W_{uh}h_{t-1} + b_u)
r_t & = act_g(W_{rx}x_{t} + W_{rh}h_{t-1} + b_r)
\\tilde{h_t} & = act_c(W_{cx}x_{t} + W_{ch}(r_t \odot h_{t-1}) + b_c)
h_t & = u_t \odot h_{t-1} + (1-u_t) \odot \\tilde{h_t}
The :math:`\odot` is the element-wise product of the vectors. :math:`act_g`
is the update gate and reset gate activation function and :math:`sigmoid`
is usually used for it. :math:`act_c` is the activation function for
......@@ -980,7 +997,8 @@ def dynamic_gru(input,
attrs={
'is_reverse': is_reverse,
'gate_activation': gate_activation,
'activation': candidate_activation
'activation': candidate_activation,
'origin_mode': origin_mode
})
return hidden
......@@ -991,9 +1009,14 @@ def gru_unit(input,
param_attr=None,
bias_attr=None,
activation='tanh',
gate_activation='sigmoid'):
gate_activation='sigmoid',
origin_mode=False):
"""
GRU unit layer. The equation of a gru step is:
**GRU unit layer**
if origin_mode is True, then the equation of a gru step is from paper
`Learning Phrase Representations using RNN Encoder-Decoder for Statistical
Machine Translation <https://arxiv.org/pdf/1406.1078.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
......@@ -1002,7 +1025,21 @@ def gru_unit(input,
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot((1-u_t), m_t) + dot(u_t, h_{t-1})
h_t & = dot(u_t, h_{t-1}) + dot((1-u_t), m_t)
if origin_mode is False, then the equation of a gru step is from paper
`Empirical Evaluation of Gated Recurrent Neural Networks on Sequence
Modeling <https://arxiv.org/pdf/1412.3555.pdf>`_
.. math::
u_t & = actGate(xu_{t} + W_u h_{t-1} + b_u)
r_t & = actGate(xr_{t} + W_r h_{t-1} + b_r)
m_t & = actNode(xm_t + W_c dot(r_t, h_{t-1}) + b_m)
h_t & = dot((1-u_t), h_{t-1}) + dot(u_t, m_t)
The inputs of gru unit includes :math:`z_t`, :math:`h_{t-1}`. In terms
of the equation above, the :math:`z_t` is split into 3 parts -
......
......@@ -31,7 +31,8 @@ def gru(
is_reverse,
act_state,
act_gate,
dtype='float32'):
dtype='float32',
origin_mode=False):
def _seq_to_batch(lod, is_reverse):
idx_in_seq_list = []
seq_lens = lod[0]
......@@ -66,7 +67,10 @@ def gru(
w_c = w.flatten()[D * D * 2:].reshape((D, D))
c = act_state(np.dot(r_h_p, w_c) + g[:, D * 2:])
g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
return g, r_h_p, h
T = sum(lod[0])
......@@ -110,6 +114,7 @@ class TestGRUOp(OpTest):
self.act_state = 'tanh'
self.act_gate = 'sigmoid'
self.dtype = 'float64'
self.origin_mode = False
self.set_confs()
T = sum(self.lod[0])
......@@ -126,7 +131,8 @@ class TestGRUOp(OpTest):
batch_gate, batch_reset_hidden_prev, batch_hidden, hidden = gru(
input, self.lod, h0, weight, bias, self.is_reverse,
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype)
ACTIVATION[self.act_state], ACTIVATION[self.act_gate], self.dtype,
self.origin_mode)
self.inputs = {'Input': (input, self.lod), 'Weight': weight}
if self.with_bias:
......@@ -145,7 +151,8 @@ class TestGRUOp(OpTest):
self.attrs = {
'activation': self.act_state,
'gate_activation': self.act_gate,
'is_reverse': self.is_reverse
'is_reverse': self.is_reverse,
'origin_mode': self.origin_mode
}
def test_check_output(self):
......@@ -155,12 +162,24 @@ class TestGRUOp(OpTest):
self.check_grad(['Input', 'H0', 'Weight', 'Bias'], ['Hidden'])
class TestGRUOriginMode(TestGRUOp):
def set_confs(self):
self.origin_mode = True
class TestGRUOp2(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
class TestGRUOp2OriginMode(TestGRUOp):
def set_confs(self):
self.D = 19
self.dtype = 'float32'
self.origin_mode = True
class TestGRUOpNoInitial(TestGRUOp):
def set_confs(self):
self.with_h0 = False
......@@ -182,5 +201,11 @@ class TestGRUOpReverse(TestGRUOp):
self.is_reverse = True
class TestGRUOpReverseOriginMode(TestGRUOp):
def set_confs(self):
self.is_reverse = True
self.origin_mode = True
if __name__ == "__main__":
unittest.main()
......@@ -53,7 +53,7 @@ class TestGRUUnitOp(OpTest):
GRUActivationType.relu: relu,
}
def set_inputs(self):
def set_inputs(self, origin_mode=False):
batch_size = self.batch_size
frame_size = self.frame_size
self.op_type = 'gru_unit'
......@@ -68,10 +68,11 @@ class TestGRUUnitOp(OpTest):
}
self.attrs = {
'activation': GRUActivationType.tanh,
'gate_activation': GRUActivationType.sigmoid
'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
}
def set_outputs(self):
def set_outputs(self, origin_mode=False):
# GRU calculations
batch_size = self.batch_size
frame_size = self.frame_size
......@@ -93,7 +94,10 @@ class TestGRUUnitOp(OpTest):
c = self.activate[self.attrs['activation']](np.dot(r_h_p, w_c) +
g[:, frame_size * 2:])
g = np.hstack((u_r, c))
h = u * c + (1 - u) * h_p
if origin_mode:
h = (1 - u) * c + u * h_p
else:
h = u * c + (1 - u) * h_p
self.outputs = {
'Gate': g.astype('float64'),
'ResetHiddenPrev': r_h_p.astype('float64'),
......@@ -111,8 +115,14 @@ class TestGRUUnitOp(OpTest):
self.check_grad(['Input', 'HiddenPrev', 'Weight'], ['Hidden'])
class TestGRUUnitOpOriginMode(TestGRUUnitOp):
def setUp(self):
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
class TestGRUUnitOpWithBias(TestGRUUnitOp):
def set_inputs(self):
def set_inputs(self, origin_mode=False):
batch_size = self.batch_size
frame_size = self.frame_size
super(TestGRUUnitOpWithBias, self).set_inputs()
......@@ -120,7 +130,8 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
-0.1, 0.1, (1, frame_size * 3)).astype('float64')
self.attrs = {
'activation': GRUActivationType.identity,
'gate_activation': GRUActivationType.sigmoid
'gate_activation': GRUActivationType.sigmoid,
'origin_mode': origin_mode
}
def test_check_grad(self):
......@@ -132,5 +143,11 @@ class TestGRUUnitOpWithBias(TestGRUUnitOp):
no_grad_set=set('Input'))
class TestGRUUnitOpWithBiasOriginMode(TestGRUUnitOpWithBias):
def setUp(self):
self.set_inputs(origin_mode=True)
self.set_outputs(origin_mode=True)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部