From e7ac74c85bbc0a1a023a90b9516114c1f458a2d1 Mon Sep 17 00:00:00 2001 From: wuhuanzhou Date: Tue, 22 Dec 2020 20:30:12 +0800 Subject: [PATCH] optimize compilation time of argmin/argmax op (#29595) * Using VisitDataTypeTiny and put CastOP after ReduceOP, test=develop * remove changes of reduce_op.h, test=develop --- paddle/fluid/operators/arg_min_max_op_base.cu.h | 9 +++++---- paddle/fluid/operators/arg_min_max_op_base.h | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h index 73581dac4e4..3e549428b04 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.cu.h +++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h @@ -175,12 +175,13 @@ class ArgMinMaxOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& dtype = ctx.Attr("dtype"); if (dtype < 0) { - framework::VisitDataType(static_cast( - framework::proto::VarType::INT64), - VisitDataCudaArgMinMaxFunctor(ctx)); + framework::VisitDataTypeTiny( + static_cast( + framework::proto::VarType::INT64), + VisitDataCudaArgMinMaxFunctor(ctx)); return; } - framework::VisitDataType( + framework::VisitDataTypeTiny( static_cast(dtype), VisitDataCudaArgMinMaxFunctor(ctx)); } diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 57e1c06f73c..77598c9a9eb 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -128,13 +128,13 @@ class ArgMinMaxKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& dtype = ctx.Attr("dtype"); if (dtype < 0) { - framework::VisitDataType( + framework::VisitDataTypeTiny( static_cast( framework::proto::VarType::INT64), VisitDataArgMinMaxFunctor(ctx)); return; } - framework::VisitDataType( + framework::VisitDataTypeTiny( static_cast(dtype), VisitDataArgMinMaxFunctor(ctx)); } -- GitLab