提交 12fc02c5 编写于 作者: Z zhupengyang

fix dropout arm kernel

test=develop
上级 cdd63eb4
...@@ -27,10 +27,11 @@ void DropoutCompute::Run() { ...@@ -27,10 +27,11 @@ void DropoutCompute::Run() {
float* out_data = param.output->mutable_data<float>(); float* out_data = param.output->mutable_data<float>();
int num = param.x->dims().production(); int num = param.x->dims().production();
const float prob_data = param.dropout_prob; const float prob_data = param.dropout_prob;
if (param.dropout_implementation.compare(std::string({"downgrade_in_infer"}))) if (param.dropout_implementation == "upscale_in_train") {
lite::arm::math::dropout_down(x_data, out_data, num, prob_data);
else
lite::arm::math::dropout_up(x_data, out_data, num); lite::arm::math::dropout_up(x_data, out_data, num);
} else {
lite::arm::math::dropout_down(x_data, out_data, num, prob_data);
}
} }
} // namespace arm } // namespace arm
...@@ -41,8 +42,5 @@ void DropoutCompute::Run() { ...@@ -41,8 +42,5 @@ void DropoutCompute::Run() {
REGISTER_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, REGISTER_LITE_KERNEL(dropout, kARM, kFloat, kNCHW,
paddle::lite::kernels::arm::DropoutCompute, def) paddle::lite::kernels::arm::DropoutCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("dropout_prob", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("dropout_implementation", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -44,7 +44,7 @@ void dropout_compute_ref(const operators::DropoutParam& param) { ...@@ -44,7 +44,7 @@ void dropout_compute_ref(const operators::DropoutParam& param) {
int num = param.x->dims().production(); int num = param.x->dims().production();
const float prob_data = param.dropout_prob; const float prob_data = param.dropout_prob;
if (param.dropout_implementation.compare( if (param.dropout_implementation.compare(
std::string({"downgrade_in_infer"}))) { std::string({"downgrade_in_infer"})) == 0) {
float scale = 1.0 - prob_data; float scale = 1.0 - prob_data;
for (int i = 0; i < num; i++) { for (int i = 0; i < num; i++) {
output_data[i] = x_data[i] * scale; output_data[i] = x_data[i] * scale;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册