未验证 提交 95aa5342 编写于 作者: W wawltor 提交者: GitHub

update the code for the topk message optimize

update the code for the topk message optimize 
上级 4ba977c7
......@@ -23,20 +23,18 @@ class TopkV2Op : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Indices"),
"Output(Indices) of TopkOp should not be null.");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "topk_v2");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "topk_v2");
OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "topk_v2");
auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size();
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true,
"the axis of topk"
"must be [-%d, %d), but you set axis is %d",
dim_size, dim_size, axis);
PADDLE_ENFORCE_EQ(
(axis < dim_size) && (axis >= (-1 * dim_size)), true,
paddle::platform::errors::InvalidArgument(
"the axis of topk must be [-%d, %d), but you set axis is %d",
dim_size, dim_size, axis));
if (axis < 0) axis += dim_size;
......@@ -47,18 +45,22 @@ class TopkV2Op : public framework::OperatorWithKernel {
} else {
k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_EQ(k >= 1, true,
"the attribute of k in the topk must >= 1 or be a "
"Tensor, but received %d .",
k);
paddle::platform::errors::InvalidArgument(
"the attribute of k in the topk must >= 1 or be a "
"Tensor, but received %d .",
k));
}
PADDLE_ENFORCE_GE(input_dims.size(), 1,
"input of topk must have >= 1d shape");
paddle::platform::errors::InvalidArgument(
"input of topk must have >= 1d shape"));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
input_dims[axis], k,
"input of topk op must have >= %d columns in axis of %d", k, axis);
paddle::platform::errors::InvalidArgument(
"input of topk op must have >= %d columns in axis of %d", k,
axis));
}
framework::DDim dims = input_dims;
......
......@@ -38,8 +38,10 @@ template <typename DeviceContext, typename T>
class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use CUDAPlace.");
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
......@@ -194,7 +196,8 @@ class TopkV2OpGradCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace."));
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册