未验证 提交 f9e8a775 编写于 作者: L Li Min 提交者: GitHub

Add constructor for fused dropout param to ease use. (#38475)

上级 dba59db7
...@@ -51,6 +51,18 @@ struct DropoutParam { ...@@ -51,6 +51,18 @@ struct DropoutParam {
seed_val = 0; seed_val = 0;
} }
DropoutParam(bool fix_seed_, uint64_t seed_, bool is_test_,
bool is_upscale_in_train_, float dropout_prob_,
const framework::Tensor* tensor_seed_, int seed_val_) {
fix_seed = fix_seed_;
seed = seed_;
is_test = is_test_;
is_upscale_in_train = is_upscale_in_train_;
dropout_prob = dropout_prob_;
tensor_seed = tensor_seed_;
seed_val = seed_val_;
}
/** /**
* dropout_index: can be 0, 1, 2. 0 means there is only one dropout, * dropout_index: can be 0, 1, 2. 0 means there is only one dropout,
* 1 and 2 represent two dropout, the parameter name of dropout * 1 and 2 represent two dropout, the parameter name of dropout
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册