未验证 提交 e7ac74c8 编写于 作者: W wuhuanzhou 提交者: GitHub

optimize compilation time of argmin/argmax op (#29595)

* Using VisitDataTypeTiny and put CastOP after ReduceOP, test=develop

* remove changes of reduce_op.h, test=develop
上级 c4eb5d03
......@@ -175,12 +175,13 @@ class ArgMinMaxOpCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataType(static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
return;
}
framework::VisitDataType(
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataCudaArgMinMaxFunctor<T, Reducer>(ctx));
}
......
......@@ -128,13 +128,13 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dtype = ctx.Attr<int>("dtype");
if (dtype < 0) {
framework::VisitDataType(
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(
framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
return;
}
framework::VisitDataType(
framework::VisitDataTypeTiny(
static_cast<framework::proto::VarType::Type>(dtype),
VisitDataArgMinMaxFunctor<DeviceContext, T, EnumArgMinMaxValue>(ctx));
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册