diff --git a/paddle/fluid/operators/arg_min_max_op_base.cu.h b/paddle/fluid/operators/arg_min_max_op_base.cu.h index 73581dac4e419ca9c970db4414ff54d4cbd3fd70..3e549428b04182e001b31ae138b2d63e37d95475 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.cu.h +++ b/paddle/fluid/operators/arg_min_max_op_base.cu.h @@ -175,12 +175,13 @@ class ArgMinMaxOpCUDAKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& dtype = ctx.Attr("dtype"); if (dtype < 0) { - framework::VisitDataType(static_cast( - framework::proto::VarType::INT64), - VisitDataCudaArgMinMaxFunctor(ctx)); + framework::VisitDataTypeTiny( + static_cast( + framework::proto::VarType::INT64), + VisitDataCudaArgMinMaxFunctor(ctx)); return; } - framework::VisitDataType( + framework::VisitDataTypeTiny( static_cast(dtype), VisitDataCudaArgMinMaxFunctor(ctx)); } diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 57e1c06f73c56334fc93dee7a16d6899f5a6f12a..77598c9a9ebbdc6b1486c5b71e7fc0995061050d 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -128,13 +128,13 @@ class ArgMinMaxKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& ctx) const override { auto& dtype = ctx.Attr("dtype"); if (dtype < 0) { - framework::VisitDataType( + framework::VisitDataTypeTiny( static_cast( framework::proto::VarType::INT64), VisitDataArgMinMaxFunctor(ctx)); return; } - framework::VisitDataType( + framework::VisitDataTypeTiny( static_cast(dtype), VisitDataArgMinMaxFunctor(ctx)); }