From 12fc02c5655a73e2019197992ccab1d2aca3f5ec Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Mon, 24 Jun 2019 01:50:26 +0000 Subject: [PATCH] fix dropout arm kernel test=develop --- paddle/fluid/lite/kernels/arm/dropout_compute.cc | 10 ++++------ paddle/fluid/lite/kernels/arm/dropout_compute_test.cc | 2 +- 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/lite/kernels/arm/dropout_compute.cc b/paddle/fluid/lite/kernels/arm/dropout_compute.cc index d76b303f94..99b3e85974 100644 --- a/paddle/fluid/lite/kernels/arm/dropout_compute.cc +++ b/paddle/fluid/lite/kernels/arm/dropout_compute.cc @@ -27,10 +27,11 @@ void DropoutCompute::Run() { float* out_data = param.output->mutable_data(); int num = param.x->dims().production(); const float prob_data = param.dropout_prob; - if (param.dropout_implementation.compare(std::string({"downgrade_in_infer"}))) - lite::arm::math::dropout_down(x_data, out_data, num, prob_data); - else + if (param.dropout_implementation == "upscale_in_train") { lite::arm::math::dropout_up(x_data, out_data, num); + } else { + lite::arm::math::dropout_down(x_data, out_data, num, prob_data); + } } } // namespace arm @@ -41,8 +42,5 @@ void DropoutCompute::Run() { REGISTER_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::DropoutCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("dropout_prob", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("dropout_implementation", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc b/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc index 428901c159..960d47442b 100644 --- a/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc @@ -44,7 +44,7 @@ void dropout_compute_ref(const operators::DropoutParam& param) { int num = param.x->dims().production(); const float prob_data = param.dropout_prob; if (param.dropout_implementation.compare( - std::string({"downgrade_in_infer"}))) { + std::string({"downgrade_in_infer"})) == 0) { float scale = 1.0 - prob_data; for (int i = 0; i < num; i++) { output_data[i] = x_data[i] * scale; -- GitLab