From 2314f2ebb3489d891b895a22a1495d5ba2a08381 Mon Sep 17 00:00:00 2001 From: whs <wanghaoshuang@baidu.com> Date: Wed, 26 Dec 2018 12:00:23 +0800 Subject: [PATCH] Make topk op support variable k. (#15044) * Make topk op support variable k. test=develop * Fix tensor type. test=develop --- paddle/fluid/operators/top_k_op.cc | 15 ++++++++++++++- paddle/fluid/operators/top_k_op.cu | 11 +++++++++++ paddle/fluid/operators/top_k_op.h | 12 ++++++++++-- python/paddle/fluid/layers/nn.py | 12 +++++++++--- .../paddle/fluid/tests/unittests/test_top_k_op.py | 15 +++++++++++++-- 5 files changed, 57 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index c17d1afc30..9e77f7252d 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -21,7 +21,7 @@ class TopkOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of TopkOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -44,12 +44,25 @@ class TopkOp : public framework::OperatorWithKernel { ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Indices"); } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), + ctx.device_context(), layout_, library_); + } }; class TopkOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of Topk op"); + AddInput("K", + "(Tensor) Number of top elements to look for along " + "the last dimension (along each row for matrices).") + .AsDispensable(); AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddComment(R"DOC( diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 99a4b1b7b0..c27039dd0a 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -327,6 +327,17 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { auto* indices = ctx.Output<Tensor>("Indices"); size_t k = static_cast<int>(ctx.Attr<int>("k")); + auto* k_t = ctx.Input<Tensor>("K"); + if (k_t) { + Tensor k_host; + framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host); + k = k_host.data<int>()[0]; + framework::DDim output_dims = output->dims(); + output_dims[output_dims.size() - 1] = k; + output->Resize(output_dims); + indices->Resize(output_dims); + } + const T* input_data = input->data<T>(); T* output_data = output->mutable_data<T>(ctx.GetPlace()); // FIXME(typhoonzero): data is always converted to type T? diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index 76ece57b39..f7bac67300 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -37,8 +37,16 @@ class TopkKernel : public framework::OpKernel<T> { auto* input = ctx.Input<Tensor>("X"); auto* output = ctx.Output<Tensor>("Out"); auto* indices = ctx.Output<Tensor>("Indices"); - // k is determined by Attr - const size_t k = static_cast<int>(ctx.Attr<int>("k")); + + size_t k = static_cast<int>(ctx.Attr<int>("k")); + auto* k_t = ctx.Input<Tensor>("K"); + if (k_t) { + k = k_t->data<int>()[0]; + framework::DDim output_dims = output->dims(); + output_dims[output_dims.size() - 1] = k; + output->Resize(output_dims); + indices->Resize(output_dims); + } T* output_data = output->mutable_data<T>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 8ac7efee50..cc1fdbd285 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -4530,7 +4530,7 @@ def topk(input, k, name=None): Args: input(Variable): The input variable which can be a vector or Tensor with higher rank. - k(int): The number of top elements to look for along the last dimension + k(int | Variable): The number of top elements to look for along the last dimension of input. name(str|None): A name for this layer(optional). If set None, the layer will be named automatically. @@ -4553,12 +4553,18 @@ def topk(input, k, name=None): helper = LayerHelper("top_k", **locals()) values = helper.create_variable_for_type_inference(dtype=input.dtype) indices = helper.create_variable_for_type_inference(dtype="int64") + inputs = {"X": [input]} + attrs = None + if isinstance(k, Variable): + inputs['K'] = k + else: + attrs = {'k': k} helper.append_op( type="top_k", - inputs={"X": [input]}, + inputs=inputs, outputs={"Out": [values], "Indices": [indices]}, - attrs={"k": k}) + attrs=attrs) values.stop_gradient = True indices.stop_gradient = True return values, indices diff --git a/python/paddle/fluid/tests/unittests/test_top_k_op.py b/python/paddle/fluid/tests/unittests/test_top_k_op.py index 21b5a62baf..9fbf59ed66 100644 --- a/python/paddle/fluid/tests/unittests/test_top_k_op.py +++ b/python/paddle/fluid/tests/unittests/test_top_k_op.py @@ -21,6 +21,7 @@ from op_test import OpTest class TestTopkOp(OpTest): def setUp(self): + self.variable_k = False self.set_args() self.op_type = "top_k" self.dtype = np.float32 @@ -30,9 +31,12 @@ class TestTopkOp(OpTest): input = np.random.random((self.row, k)).astype(self.dtype) output = np.ndarray((self.row, k)) indices = np.ndarray((self.row, k)).astype("int64") - self.inputs = {'X': input} - self.attrs = {'k': k} + + if self.variable_k: + self.inputs['K'] = np.array([k]).astype("int32") + else: + self.attrs = {'k': k} for rowid in range(self.row): row = input[rowid] @@ -118,5 +122,12 @@ class TestTopkOp4(TestTopkOp): self.top_k = 1 +class TestTopkOp5(TestTopkOp): + def set_args(self): + self.row = 40000 + self.top_k = 3 + self.variable_k = True + + if __name__ == "__main__": unittest.main() -- GitLab