From 2314f2ebb3489d891b895a22a1495d5ba2a08381 Mon Sep 17 00:00:00 2001
From: whs <wanghaoshuang@baidu.com>
Date: Wed, 26 Dec 2018 12:00:23 +0800
Subject: [PATCH] Make topk op support variable k. (#15044)

* Make topk op support variable k.
test=develop

* Fix tensor type.
test=develop
---
 paddle/fluid/operators/top_k_op.cc                | 15 ++++++++++++++-
 paddle/fluid/operators/top_k_op.cu                | 11 +++++++++++
 paddle/fluid/operators/top_k_op.h                 | 12 ++++++++++--
 python/paddle/fluid/layers/nn.py                  | 12 +++++++++---
 .../paddle/fluid/tests/unittests/test_top_k_op.py | 15 +++++++++++++--
 5 files changed, 57 insertions(+), 8 deletions(-)

diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc
index c17d1afc30..9e77f7252d 100644
--- a/paddle/fluid/operators/top_k_op.cc
+++ b/paddle/fluid/operators/top_k_op.cc
@@ -21,7 +21,7 @@ class TopkOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext *ctx) const override {
+  void InferShape(framework::InferShapeContext* ctx) const override {
     PADDLE_ENFORCE(ctx->HasInput("X"),
                    "Input(X) of TopkOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
@@ -44,12 +44,25 @@ class TopkOp : public framework::OperatorWithKernel {
     ctx->ShareLoD("X", "Out");
     ctx->ShareLoD("X", "Indices");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    framework::LibraryType library_{framework::LibraryType::kPlain};
+    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
+    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+                                   ctx.device_context(), layout_, library_);
+  }
 };
 
 class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor) The input of Topk op");
+    AddInput("K",
+             "(Tensor)  Number of top elements to look for along "
+             "the last dimension (along each row for matrices).")
+        .AsDispensable();
     AddOutput("Out", "(Tensor) The output tensor of Topk op");
     AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
     AddComment(R"DOC(
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index 99a4b1b7b0..c27039dd0a 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -327,6 +327,17 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     auto* indices = ctx.Output<Tensor>("Indices");
     size_t k = static_cast<int>(ctx.Attr<int>("k"));
 
+    auto* k_t = ctx.Input<Tensor>("K");
+    if (k_t) {
+      Tensor k_host;
+      framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
+      k = k_host.data<int>()[0];
+      framework::DDim output_dims = output->dims();
+      output_dims[output_dims.size() - 1] = k;
+      output->Resize(output_dims);
+      indices->Resize(output_dims);
+    }
+
     const T* input_data = input->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
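
On the GPU path the scalar k lives in device memory, but it is needed on the
host to resize the outputs and size the kernel launch, so TensorCopySync
first copies it to CPUPlace; note this is a synchronous device-to-host copy.
The shape fix-up it enables, as a plain-Python sketch (fix_output_dims is an
illustrative name):

    def fix_output_dims(dims, k):
        dims = list(dims)
        dims[-1] = k  # only the last axis depends on k
        return tuple(dims)

    assert fix_output_dims((32, 10), 3) == (32, 3)
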
diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h
index 76ece57b39..f7bac67300 100644
--- a/paddle/fluid/operators/top_k_op.h
+++ b/paddle/fluid/operators/top_k_op.h
@@ -37,8 +37,16 @@ class TopkKernel : public framework::OpKernel<T> {
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
-    // k is determined by Attr
-    const size_t k = static_cast<int>(ctx.Attr<int>("k"));
+
+    size_t k = static_cast<int>(ctx.Attr<int>("k"));
+    auto* k_t = ctx.Input<Tensor>("K");
+    if (k_t) {
+      k = k_t->data<int>()[0];
+      framework::DDim output_dims = output->dims();
+      output_dims[output_dims.size() - 1] = k;
+      output->Resize(output_dims);
+      indices->Resize(output_dims);
+    }
 
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
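
The CPU kernel can read the K tensor's buffer directly since it already
resides in host memory; otherwise it performs the same output resize as the
CUDA kernel. What both kernels compute per row can be stated as a NumPy
reference (a sketch; ordering among tied values may differ from Paddle's
kernels):

    import numpy as np

    def topk_ref(x, k):
        # indices of the k largest entries along the last axis, descending
        idx = np.argsort(-x, axis=-1)[..., :k]
        return np.take_along_axis(x, idx, axis=-1), idx
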
diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py
index 8ac7efee50..cc1fdbd285 100644
--- a/python/paddle/fluid/layers/nn.py
+++ b/python/paddle/fluid/layers/nn.py
@@ -4530,7 +4530,7 @@ def topk(input, k, name=None):
     Args:
         input(Variable): The input variable which can be a vector or Tensor with
             higher rank.
-        k(int):  The number of top elements to look for along the last dimension
+        k(int|Variable): The number of top elements to look for along the last dimension
                  of input.
         name(str|None): A name for this layer(optional). If set None, the layer
                        will be named automatically.
@@ -4553,12 +4553,18 @@ def topk(input, k, name=None):
     helper = LayerHelper("top_k", **locals())
     values = helper.create_variable_for_type_inference(dtype=input.dtype)
     indices = helper.create_variable_for_type_inference(dtype="int64")
+    inputs = {"X": [input]}
+    attrs = None
+    if isinstance(k, Variable):
+        inputs['K'] = k
+    else:
+        attrs = {'k': k}
     helper.append_op(
         type="top_k",
-        inputs={"X": [input]},
+        inputs=inputs,
         outputs={"Out": [values],
                  "Indices": [indices]},
-        attrs={"k": k})
+        attrs=attrs)
     values.stop_gradient = True
     indices.stop_gradient = True
     return values, indices
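
With this change the Python layer routes k either to the "k" attribute or to
the "K" input depending on its type. A minimal usage sketch against the
fluid 1.x graph-building API of this era (tensor names are illustrative):

    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[50], dtype='float32')
    # k may now be a 1-D int32 tensor rather than a Python int
    k = fluid.layers.fill_constant(shape=[1], dtype='int32', value=5)
    values, indices = fluid.layers.topk(x, k=k)
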
diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py
index 21b5a62baf..9fbf59ed66 100644
--- a/python/paddle/fluid/tests/unittests/test_top_k_op.py
+++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py
@@ -21,6 +21,7 @@ from op_test import OpTest
 
 class TestTopkOp(OpTest):
     def setUp(self):
+        self.variable_k = False
         self.set_args()
         self.op_type = "top_k"
         self.dtype = np.float32
@@ -30,9 +31,12 @@ class TestTopkOp(OpTest):
         input = np.random.random((self.row, k)).astype(self.dtype)
         output = np.ndarray((self.row, k))
         indices = np.ndarray((self.row, k)).astype("int64")
-
         self.inputs = {'X': input}
-        self.attrs = {'k': k}
+
+        if self.variable_k:
+            self.inputs['K'] = np.array([k]).astype("int32")
+        else:
+            self.attrs = {'k': k}
 
         for rowid in range(self.row):
             row = input[rowid]
@@ -118,5 +122,12 @@ class TestTopkOp4(TestTopkOp):
         self.top_k = 1
 
 
+class TestTopkOp5(TestTopkOp):
+    def set_args(self):
+        self.row = 40000
+        self.top_k = 3
+        self.variable_k = True
+
+
 if __name__ == "__main__":
     unittest.main()
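
TestTopkOp5 reuses the TestTopkOp machinery: variable_k=True makes setUp feed
k through the new "K" input instead of the attribute. Further variable-k
cases only need a set_args override, e.g. (a hypothetical extra case building
on the TestTopkOp class above, not part of the patch):

    class TestTopkOp6(TestTopkOp):
        def set_args(self):
            self.row = 16
            self.top_k = 5
            self.variable_k = True
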
-- 
GitLab