未验证 提交 2314f2eb 编写于 作者: W whs 提交者: GitHub

Make topk op support variable k. (#15044)

* Make topk op support variable k.
test=develop

* Fix tensor type.
test=develop
上级 09bd8fa6
...@@ -21,7 +21,7 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -21,7 +21,7 @@ class TopkOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TopkOp should not be null."); "Input(X) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
...@@ -44,12 +44,25 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -44,12 +44,25 @@ class TopkOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", "Out"); ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices"); ctx->ShareLoD("X", "Indices");
} }
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context(), layout_, library_);
}
}; };
class TopkOpMaker : public framework::OpProtoAndCheckerMaker { class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor) The input of Topk op"); AddInput("X", "(Tensor) The input of Topk op");
AddInput("K",
"(Tensor) Number of top elements to look for along "
"the last dimension (along each row for matrices).")
.AsDispensable();
AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Out", "(Tensor) The output tensor of Topk op");
AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
AddComment(R"DOC( AddComment(R"DOC(
......
...@@ -327,6 +327,17 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -327,6 +327,17 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
size_t k = static_cast<int>(ctx.Attr<int>("k")); size_t k = static_cast<int>(ctx.Attr<int>("k"));
auto* k_t = ctx.Input<Tensor>("K");
if (k_t) {
Tensor k_host;
framework::TensorCopySync(*k_t, platform::CPUPlace(), &k_host);
k = k_host.data<int>()[0];
framework::DDim output_dims = output->dims();
output_dims[output_dims.size() - 1] = k;
output->Resize(output_dims);
indices->Resize(output_dims);
}
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T? // FIXME(typhoonzero): data is always converted to type T?
......
...@@ -37,8 +37,16 @@ class TopkKernel : public framework::OpKernel<T> { ...@@ -37,8 +37,16 @@ class TopkKernel : public framework::OpKernel<T> {
auto* input = ctx.Input<Tensor>("X"); auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
// k is determined by Attr
const size_t k = static_cast<int>(ctx.Attr<int>("k")); size_t k = static_cast<int>(ctx.Attr<int>("k"));
auto* k_t = ctx.Input<Tensor>("K");
if (k_t) {
k = k_t->data<int>()[0];
framework::DDim output_dims = output->dims();
output_dims[output_dims.size() - 1] = k;
output->Resize(output_dims);
indices->Resize(output_dims);
}
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
......
...@@ -4530,7 +4530,7 @@ def topk(input, k, name=None): ...@@ -4530,7 +4530,7 @@ def topk(input, k, name=None):
Args: Args:
input(Variable): The input variable which can be a vector or Tensor with input(Variable): The input variable which can be a vector or Tensor with
higher rank. higher rank.
k(int): The number of top elements to look for along the last dimension k(int | Variable): The number of top elements to look for along the last dimension
of input. of input.
name(str|None): A name for this layer(optional). If set None, the layer name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically. will be named automatically.
...@@ -4553,12 +4553,18 @@ def topk(input, k, name=None): ...@@ -4553,12 +4553,18 @@ def topk(input, k, name=None):
helper = LayerHelper("top_k", **locals()) helper = LayerHelper("top_k", **locals())
values = helper.create_variable_for_type_inference(dtype=input.dtype) values = helper.create_variable_for_type_inference(dtype=input.dtype)
indices = helper.create_variable_for_type_inference(dtype="int64") indices = helper.create_variable_for_type_inference(dtype="int64")
inputs = {"X": [input]}
attrs = None
if isinstance(k, Variable):
inputs['K'] = k
else:
attrs = {'k': k}
helper.append_op( helper.append_op(
type="top_k", type="top_k",
inputs={"X": [input]}, inputs=inputs,
outputs={"Out": [values], outputs={"Out": [values],
"Indices": [indices]}, "Indices": [indices]},
attrs={"k": k}) attrs=attrs)
values.stop_gradient = True values.stop_gradient = True
indices.stop_gradient = True indices.stop_gradient = True
return values, indices return values, indices
......
...@@ -21,6 +21,7 @@ from op_test import OpTest ...@@ -21,6 +21,7 @@ from op_test import OpTest
class TestTopkOp(OpTest): class TestTopkOp(OpTest):
def setUp(self): def setUp(self):
self.variable_k = False
self.set_args() self.set_args()
self.op_type = "top_k" self.op_type = "top_k"
self.dtype = np.float32 self.dtype = np.float32
...@@ -30,9 +31,12 @@ class TestTopkOp(OpTest): ...@@ -30,9 +31,12 @@ class TestTopkOp(OpTest):
input = np.random.random((self.row, k)).astype(self.dtype) input = np.random.random((self.row, k)).astype(self.dtype)
output = np.ndarray((self.row, k)) output = np.ndarray((self.row, k))
indices = np.ndarray((self.row, k)).astype("int64") indices = np.ndarray((self.row, k)).astype("int64")
self.inputs = {'X': input} self.inputs = {'X': input}
self.attrs = {'k': k}
if self.variable_k:
self.inputs['K'] = np.array([k]).astype("int32")
else:
self.attrs = {'k': k}
for rowid in range(self.row): for rowid in range(self.row):
row = input[rowid] row = input[rowid]
...@@ -118,5 +122,12 @@ class TestTopkOp4(TestTopkOp): ...@@ -118,5 +122,12 @@ class TestTopkOp4(TestTopkOp):
self.top_k = 1 self.top_k = 1
class TestTopkOp5(TestTopkOp):
def set_args(self):
self.row = 40000
self.top_k = 3
self.variable_k = True
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册