未验证 提交 5859d0a6 编写于 作者: S sneaxiy 提交者: GitHub

add gather dtype err msg (#48002)

上级 fd550c1b
......@@ -65,6 +65,10 @@ void GatherGradKernel(const Context& dev_ctx,
phi::funcs::ScatterAssignAdd<T, int64_t>(
dev_ctx, out_grad, index, x_grad);
}
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Index) of gather_grad must be int32 or int64 "
"on CPU."));
}
}
......
......@@ -49,6 +49,10 @@ void GatherKernel(const Context& dev_ctx,
phi::funcs::CPUGather<T, int>(dev_ctx, x, index, out);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::CPUGather<T, int64_t>(dev_ctx, x, index, out);
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("The data type of Input(Index) of gather "
"must be int32 or int64 on CPU."));
}
}
......
......@@ -55,6 +55,10 @@ void GatherGradKernel(const Context& dev_ctx,
} else if (index_type == DataType::INT64) {
phi::funcs::GPUScatterAssign<T, int64_t>(
dev_ctx, out_grad, index, x_grad, overwrite);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"The data type of Input(Index) of gather_grad must be int32 or int64 "
"on GPU."));
}
}
......
......@@ -52,6 +52,10 @@ void GatherKernel(const Context& dev_ctx,
phi::funcs::GPUGather<T, int64_t>(dev_ctx, x, index, out);
} else if (index_type == phi::DataType::INT16) {
phi::funcs::GPUGather<T, int16_t>(dev_ctx, x, index, out);
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("The data type of Input(Index) of gather "
"must be int16, int32 or int64 on GPU."));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册