From f9e8a775f39f91ed85f078e33f2d387fb6876785 Mon Sep 17 00:00:00 2001 From: Li Min <11663212+limin2021@users.noreply.github.com> Date: Tue, 28 Dec 2021 09:42:15 +0800 Subject: [PATCH] Add constructor for fused dropout param to ease use. (#38475) --- paddle/fluid/operators/fused/fused_dropout_helper.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/paddle/fluid/operators/fused/fused_dropout_helper.h b/paddle/fluid/operators/fused/fused_dropout_helper.h index 33fde64164..970b2d82e2 100644 --- a/paddle/fluid/operators/fused/fused_dropout_helper.h +++ b/paddle/fluid/operators/fused/fused_dropout_helper.h @@ -51,6 +51,18 @@ struct DropoutParam { seed_val = 0; } + DropoutParam(bool fix_seed_, uint64_t seed_, bool is_test_, + bool is_upscale_in_train_, float dropout_prob_, + const framework::Tensor* tensor_seed_, int seed_val_) { + fix_seed = fix_seed_; + seed = seed_; + is_test = is_test_; + is_upscale_in_train = is_upscale_in_train_; + dropout_prob = dropout_prob_; + tensor_seed = tensor_seed_; + seed_val = seed_val_; + } + /** * dropout_index: can be 0, 1, 2. 0 means there is only one dropout, * 1 and 2 represent two dropout, the parameter name of dropout -- GitLab