未验证 提交 95aa5342 编写于 作者: W wawltor 提交者: GitHub

update the code for the topk message optimize

update the code for the topk message optimize 
上级 4ba977c7
...@@ -23,20 +23,18 @@ class TopkV2Op : public framework::OperatorWithKernel { ...@@ -23,20 +23,18 @@ class TopkV2Op : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "topk_v2");
"Input(X) of TopkOp should not be null."); OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "topk_v2");
PADDLE_ENFORCE(ctx->HasOutput("Out"), OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "topk_v2");
"Output(Out) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Indices"),
"Output(Indices) of TopkOp should not be null.");
auto input_dims = ctx->GetInputDim("X"); auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size(); const int& dim_size = input_dims.size();
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis")); int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_EQ((axis < dim_size) && (axis >= (-1 * dim_size)), true, PADDLE_ENFORCE_EQ(
"the axis of topk" (axis < dim_size) && (axis >= (-1 * dim_size)), true,
"must be [-%d, %d), but you set axis is %d", paddle::platform::errors::InvalidArgument(
dim_size, dim_size, axis); "the axis of topk must be [-%d, %d), but you set axis is %d",
dim_size, dim_size, axis));
if (axis < 0) axis += dim_size; if (axis < 0) axis += dim_size;
...@@ -47,18 +45,22 @@ class TopkV2Op : public framework::OperatorWithKernel { ...@@ -47,18 +45,22 @@ class TopkV2Op : public framework::OperatorWithKernel {
} else { } else {
k = static_cast<int>(ctx->Attrs().Get<int>("k")); k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_EQ(k >= 1, true, PADDLE_ENFORCE_EQ(k >= 1, true,
"the attribute of k in the topk must >= 1 or be a " paddle::platform::errors::InvalidArgument(
"Tensor, but received %d .", "the attribute of k in the topk must >= 1 or be a "
k); "Tensor, but received %d .",
k));
} }
PADDLE_ENFORCE_GE(input_dims.size(), 1, PADDLE_ENFORCE_GE(input_dims.size(), 1,
"input of topk must have >= 1d shape"); paddle::platform::errors::InvalidArgument(
"input of topk must have >= 1d shape"));
if (ctx->IsRuntime()) { if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
input_dims[axis], k, input_dims[axis], k,
"input of topk op must have >= %d columns in axis of %d", k, axis); paddle::platform::errors::InvalidArgument(
"input of topk op must have >= %d columns in axis of %d", k,
axis));
} }
framework::DDim dims = input_dims; framework::DDim dims = input_dims;
......
...@@ -38,8 +38,10 @@ template <typename DeviceContext, typename T> ...@@ -38,8 +38,10 @@ template <typename DeviceContext, typename T>
class TopkV2OpCUDAKernel : public framework::OpKernel<T> { class TopkV2OpCUDAKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE_EQ(
"It must use CUDAPlace."); platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
...@@ -194,7 +196,8 @@ class TopkV2OpGradCUDAKernel : public framework::OpKernel<T> { ...@@ -194,7 +196,8 @@ class TopkV2OpGradCUDAKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true, platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument("It must use CUDAPlace.")); platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* x = context.Input<Tensor>("X"); auto* x = context.Input<Tensor>("X");
auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out")); auto* out_grad = context.Input<Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<Tensor>("Indices"); auto* indices = context.Input<Tensor>("Indices");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册