diff --git a/paddle/fluid/lite/kernels/arm/dropout_compute.cc b/paddle/fluid/lite/kernels/arm/dropout_compute.cc index d76b303f9465c899c2eec542921ecdcffbc927e6..99b3e859747c9518821c70cf68a9170d949f33f3 100644 --- a/paddle/fluid/lite/kernels/arm/dropout_compute.cc +++ b/paddle/fluid/lite/kernels/arm/dropout_compute.cc @@ -27,10 +27,11 @@ void DropoutCompute::Run() { float* out_data = param.output->mutable_data(); int num = param.x->dims().production(); const float prob_data = param.dropout_prob; - if (param.dropout_implementation.compare(std::string({"downgrade_in_infer"}))) - lite::arm::math::dropout_down(x_data, out_data, num, prob_data); - else + if (param.dropout_implementation == "upscale_in_train") { lite::arm::math::dropout_up(x_data, out_data, num); + } else { + lite::arm::math::dropout_down(x_data, out_data, num, prob_data); + } } } // namespace arm @@ -41,8 +42,5 @@ void DropoutCompute::Run() { REGISTER_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::DropoutCompute, def) .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("dropout_prob", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindInput("dropout_implementation", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) - .BindOutput("Mask", {LiteType::GetTensorTy(TARGET(kARM))}) .Finalize(); diff --git a/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc b/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc index 428901c15973c1326f69282b10bc82f4b35688c2..960d47442b6f9d765fe17db60b90557e9625efcd 100644 --- a/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc +++ b/paddle/fluid/lite/kernels/arm/dropout_compute_test.cc @@ -44,7 +44,7 @@ void dropout_compute_ref(const operators::DropoutParam& param) { int num = param.x->dims().production(); const float prob_data = param.dropout_prob; if (param.dropout_implementation.compare( - std::string({"downgrade_in_infer"}))) { + std::string({"downgrade_in_infer"})) == 0) { float scale = 1.0 - prob_data; for (int i = 0; i < num; i++) { output_data[i] = x_data[i] * scale;