未验证 提交 e7ac74c8 编写于 作者: W wuhuanzhou 提交者: GitHub

optimize compilation time of argmin/argmax op (#29595)

* Use VisitDataTypeTiny and put CastOP after ReduceOP, test=develop

* remove changes to reduce_op.h, test=develop
上级 c4eb5d03
...@@ -175,12 +175,13 @@ class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> { ...@@ -175,12 +175,13 @@ class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype"); auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) { if (dtype < 0) {
framework::VisitDataType(static_cast<framework::proto::VarType::Type>( framework::VisitDataTypeTiny(
framework::proto::VarType::INT64), static_cast<framework::proto::VarType::Type>(
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx)); framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
return; return;
} }
framework::VisitDataType( framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype), static_cast<framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx)); VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
} }
......
...@@ -128,13 +128,13 @@ class ArgMinMaxKernel : public framework::OpKernel<T> { ...@@ -128,13 +128,13 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype"); auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) { if (dtype < 0) {
framework::VisitDataType( framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>( static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64), framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx)); VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
return; return;
} }
framework::VisitDataType( framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype), static_cast<framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx)); VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册