未验证 提交 acf07c74 编写于 作者: H houj04 提交者: GitHub

[XPU] add top_k op (#44656)

* [XPU] add top_k op. test=kunlun

* [XPU] add top_k op. test=kunlun

* use PADDLE_ENFORCE_XDNN_NOT_NULL to check pointer. test=kunlun
上级 8ee9140b
...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") ...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
set(XPU_BASE_URL_WITHOUT_DATE set(XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220722") set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220727")
else() else()
set(XPU_BASE_URL "${XPU_BASE_URL}") set(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
...@@ -19,7 +19,7 @@ endif() ...@@ -19,7 +19,7 @@ endif()
if(NOT DEFINED XPU_XDNN_BASE_URL) if(NOT DEFINED XPU_XDNN_BASE_URL)
set(XPU_XDNN_BASE_URL_WITHOUT_DATE set(XPU_XDNN_BASE_URL_WITHOUT_DATE
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220722") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220727")
else() else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif() endif()
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include "paddle/fluid/operators/top_k_op.h" #include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "xpu/refactor/math.h" #include "xpu/refactor/math.h"
namespace paddle { namespace paddle {
...@@ -25,17 +26,26 @@ namespace operators { ...@@ -25,17 +26,26 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
class TopkXPUKernel : public framework::OpKernel<T> { class TopkXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor // Get the top k elements of each row of input tensor
auto* input = ctx.Input<Tensor>("X"); const auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices"); auto* indices = ctx.Output<Tensor>("Indices");
size_t k = static_cast<int>(ctx.Attr<int>("k")); // get k from attr
int k = static_cast<int>(ctx.Attr<int>("k"));
// get k from input tensor
auto* k_t = ctx.Input<Tensor>("K"); auto* k_t = ctx.Input<Tensor>("K");
if (k_t) { if (k_t) {
k = k_t->data<int>()[0]; memory::Copy(platform::CPUPlace(),
static_cast<void*>(&k),
ctx.GetPlace(),
static_cast<const void*>(k_t->data<int>()),
sizeof(int));
framework::DDim output_dims = output->dims(); framework::DDim output_dims = output->dims();
output_dims[output_dims.size() - 1] = k; output_dims[output_dims.size() - 1] = k;
output->Resize(output_dims); output->Resize(output_dims);
...@@ -44,43 +54,36 @@ class TopkXPUKernel : public framework::OpKernel<T> { ...@@ -44,43 +54,36 @@ class TopkXPUKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace()); T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
Tensor indices_32_data_tensor;
int32_t* indices_int_data = indices_32_data_tensor.mutable_data<int32_t>( auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
ctx.GetPlace(), indices->numel()); // allocate temp memory for int32 index
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* indices_int_data = RAII_GUARD.alloc_l3_or_gm<int>(indices->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(indices_int_data);
// reshape input to a flattern matrix(like flat_inner_dims) // reshape input to a flattern matrix(like flat_inner_dims)
framework::DDim inputdims = input->dims(); framework::DDim inputdims = input->dims();
const size_t row = const size_t row =
phi::product(phi::slice_ddim(inputdims, 0, inputdims.size() - 1)); phi::product(phi::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t col = inputdims[inputdims.size() - 1]; const size_t col = inputdims[inputdims.size() - 1];
auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
int ret = xpu::sorted_topk<T>(dev_ctx.x_context(), // int sorted_topk(Context* ctx, const T* x, T* y, int* index, int m, int n,
input->data<T>(), // int k, bool largest = true);
output_data, int r = xpu::sorted_topk(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(input->data<T>()),
reinterpret_cast<XPUType*>(output_data),
indices_int_data, indices_int_data,
row, row,
col, col,
k); k);
PADDLE_ENFORCE_EQ(ret, PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_topk");
XPU_SUCCESS,
platform::errors::External( // cast to int64 as final result
"XPU API return wrong value[%d] in call kernel name " r = xpu::cast_v2<int32_t, int64_t>(dev_ctx.x_context(),
"[%s], please check "
"where Baidu Kunlun Card is properly installed.",
ret,
"sorted_topk"));
ret = xpu::cast_v2<int32_t, int64_t>(dev_ctx.x_context(),
(const int32_t*)indices_int_data, (const int32_t*)indices_int_data,
indices_data, indices_data,
indices->numel()); indices->numel());
PADDLE_ENFORCE_EQ(ret, PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
XPU_SUCCESS,
platform::errors::External(
"XPU API return wrong value[%d] in call kernel name "
"[%s], please check "
"where Baidu Kunlun Card is properly installed.",
ret,
"cast_v2"));
} }
}; };
...@@ -88,5 +91,7 @@ class TopkXPUKernel : public framework::OpKernel<T> { ...@@ -88,5 +91,7 @@ class TopkXPUKernel : public framework::OpKernel<T> {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(top_k, ops::TopkXPUKernel<float>); REGISTER_OP_XPU_KERNEL(top_k,
ops::TopkXPUKernel<float>,
ops::TopkXPUKernel<paddle::platform::float16>);
#endif #endif
...@@ -105,13 +105,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -105,13 +105,9 @@ XPUOpMap& get_kl2_ops() {
{"elementwise_add", {"elementwise_add",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div_grad", {"elementwise_div_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"elementwise_div", {"elementwise_div",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
...@@ -495,6 +491,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -495,6 +491,9 @@ XPUOpMap& get_kl2_ops() {
{"transpose", {"transpose",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"top_k",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"top_k_v2", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"unsqueeze2_grad", {"unsqueeze2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -18,29 +18,44 @@ import numpy as np ...@@ -18,29 +18,44 @@ import numpy as np
import sys import sys
sys.path.append("..") sys.path.append("..")
from paddle.fluid.op import Operator
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle import paddle
from op_test import OpTest from op_test import OpTest
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
paddle.enable_static() paddle.enable_static()
@unittest.skipIf(not paddle.is_compiled_with_xpu(), def random_unique_float(row, k, dtype):
"core is not compiled with XPU") # create a random float array with 10x length
class TestTopkOp(OpTest): arr = np.random.uniform(-10.0, 10.0, int(row * k * 10)).astype(dtype)
arr = np.unique(arr)
assert arr.shape[
0] >= row * k, "failed to create enough unique values: %d vs %d" % (
arr.shape[0], row * k)
arr = arr[:row * k]
np.random.shuffle(arr)
arr = arr.reshape(row, k)
return arr
class XPUTestTopkOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'top_k'
self.use_dynamic_create_class = False
class TestXPUTopkOP(XPUOpTest):
def setUp(self): def setUp(self):
self.variable_k = False self.place = paddle.XPUPlace(0)
self.use_xpu = True
self.set_args()
self.op_type = "top_k"
self.dtype = np.float32
self.init_dtype() self.init_dtype()
self.op_type = 'top_k'
self.set_case()
# generate UNIQUE float values as input, in order to prevent the following potential problem: x[i] and x[j] are IDENTICAL float values, the result of cpu index is [i, j] while the xpu result is [j, i]. Both of them are correct but diff in numpy compare.
k = self.top_k k = self.top_k
input = np.random.random((self.row, k)).astype(self.dtype) input = random_unique_float(self.row, k, self.dtype)
output = np.ndarray((self.row, k)) output = np.ndarray((self.row, k))
indices = np.ndarray((self.row, k)).astype("int64") indices = np.ndarray((self.row, k)).astype("int64")
self.inputs = {'X': input} self.inputs = {'X': input}
...@@ -58,22 +73,51 @@ class TestTopkOp(OpTest): ...@@ -58,22 +73,51 @@ class TestTopkOp(OpTest):
self.outputs = {'Out': output, 'Indices': indices} self.outputs = {'Out': output, 'Indices': indices}
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = self.in_type
def set_args(self): def set_case(self):
self.row = 100 self.variable_k = False
self.top_k = 1 self.row = 16
self.top_k = 8
def test_check_output(self): def test_check_output(self):
if paddle.is_compiled_with_xpu(): self.check_output_with_place(self.place)
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self): def test_check_grad(self):
if paddle.is_compiled_with_xpu(): self.check_grad_with_place(self.place, ['X'], 'Out')
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out') class TestTopk1(TestXPUTopkOP):
def set_case(self):
self.variable_k = True
self.row = 100
self.top_k = 1
class TestTopk2(TestXPUTopkOP):
def set_case(self):
self.variable_k = False
self.row = 16
self.top_k = 256
class TestTopk3(TestXPUTopkOP):
def set_case(self):
self.variable_k = True
self.row = 10
self.top_k = 512
class TestTopk4(TestXPUTopkOP):
def set_case(self):
self.variable_k = False
self.row = 5
self.top_k = 511
support_types = get_xpu_op_support_types('top_k')
for stype in support_types:
create_test_class(globals(), XPUTestTopkOP, stype)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册