未验证 提交 e61f48c1 编写于 作者: H houj04 提交者: GitHub

[XPU] add sampling_id op, add top_k op, update xdnn api. test=kunlun (#44704)

上级 72b65d6b
...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") ...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
set(XPU_BASE_URL_WITHOUT_DATE set(XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220727") set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220728")
else() else()
set(XPU_BASE_URL "${XPU_BASE_URL}") set(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
...@@ -19,7 +19,7 @@ endif() ...@@ -19,7 +19,7 @@ endif()
if(NOT DEFINED XPU_XDNN_BASE_URL) if(NOT DEFINED XPU_XDNN_BASE_URL)
set(XPU_XDNN_BASE_URL_WITHOUT_DATE set(XPU_XDNN_BASE_URL_WITHOUT_DATE
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220727") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220728")
else() else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif() endif()
......
...@@ -302,6 +302,11 @@ void TensorFromVector(const std::vector<T>& src, ...@@ -302,6 +302,11 @@ void TensorFromVector(const std::vector<T>& src,
size, size,
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream()); reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream());
} }
#endif
#ifdef PADDLE_WITH_XPU
else if (platform::is_xpu_place(dst_place)) { // NOLINT
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#endif #endif
else { // NOLINT else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
...@@ -381,6 +386,11 @@ inline void TensorFromVector(const std::vector<bool>& src, ...@@ -381,6 +386,11 @@ inline void TensorFromVector(const std::vector<bool>& src,
reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream(); reinterpret_cast<const platform::CustomDeviceContext&>(ctx).stream();
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream); memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size, stream);
} }
#endif
#ifdef PADDLE_WITH_XPU
else if (platform::is_xpu_place(dst_place)) { // NOLINT
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
#endif #endif
else { // NOLINT else { // NOLINT
PADDLE_THROW(platform::errors::Unimplemented( PADDLE_THROW(platform::errors::Unimplemented(
......
...@@ -219,20 +219,14 @@ static std::pair<Tensor, Tensor> ProposalForOneImage( ...@@ -219,20 +219,14 @@ static std::pair<Tensor, Tensor> ProposalForOneImage(
// 4. nms // 4. nms
int nms_keep_num = 0; int nms_keep_num = 0;
r = xpu::nms<T>(dev_ctx.x_context(), r = xpu::sorted_nms<T>(dev_ctx.x_context(),
proposals_filter.data<T>(), proposals_filter.data<T>(),
nullptr,
keep_index.data<int>(), keep_index.data<int>(),
1, nms_keep_num,
1,
keep_num, keep_num,
-1,
nms_thresh, nms_thresh,
-1,
0,
&nms_keep_num,
pixel_offset); pixel_offset);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "nms"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "sorted_nms");
if (post_nms_top_n > 0 && post_nms_top_n < nms_keep_num) { if (post_nms_top_n > 0 && post_nms_top_n < nms_keep_num) {
keep_index.Resize({post_nms_top_n}); keep_index.Resize({post_nms_top_n});
} else { } else {
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "paddle/fluid/operators/one_hot_op.h" #include "paddle/fluid/operators/one_hot_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,9 +29,13 @@ template <typename DeviceContext, typename T> ...@@ -28,9 +29,13 @@ template <typename DeviceContext, typename T>
class OneHotXPUKernel : public framework::OpKernel<T> { class OneHotXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X"); const auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out"); auto* out = context.Output<LoDTensor>("Out");
// get depth from attr
int depth = context.Attr<int>("depth"); int depth = context.Attr<int>("depth");
// get depth from input tensor
if (context.HasInput("depth_tensor")) { if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<Tensor>("depth_tensor"); auto* depth_tensor = context.Input<Tensor>("depth_tensor");
auto* depth_data = depth_tensor->data<int32_t>(); auto* depth_data = depth_tensor->data<int32_t>();
...@@ -50,18 +55,14 @@ class OneHotXPUKernel : public framework::OpKernel<T> { ...@@ -50,18 +55,14 @@ class OneHotXPUKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
int len = in->numel(); int len = in->numel();
// int one_hot(Context* ctx, const T* x, float* y, int len, int depth, float
// on_value = 1.0f, float off_value = 0.0f);
int ret = xpu::one_hot<T>(dev_ctx.x_context(), int ret = xpu::one_hot<T>(dev_ctx.x_context(),
in->data<T>(), in->data<T>(),
out->mutable_data<float>(context.GetPlace()), out->mutable_data<float>(context.GetPlace()),
len, len,
depth); depth);
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "one_hot");
PADDLE_ENFORCE_EQ(ret,
XPU_SUCCESS,
platform::errors::External(
"XPU one_hot kernel return wrong value[%d %s]",
ret,
XPUAPIErrorMsg[ret]));
} }
}; };
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/sampling_id_op.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(sampling_id,
paddle::operators::SamplingIdKernel<float>,
paddle::operators::SamplingIdKernel<double>);
...@@ -322,6 +322,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -322,6 +322,9 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"one_hot",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})},
{"one_hot_v2", {"one_hot_v2",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}, pOpKernelType(vartype::INT64, XPUPlace())})},
...@@ -393,6 +396,9 @@ XPUOpMap& get_kl2_ops() { ...@@ -393,6 +396,9 @@ XPUOpMap& get_kl2_ops() {
{"scatter", {"scatter",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"sampling_id",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace())})},
{"sgd", {"sgd",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -13,172 +13,117 @@ ...@@ -13,172 +13,117 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import unittest import unittest
import numpy as np import numpy as np
import paddle
import paddle.fluid.core as core
import sys import sys
sys.path.append("..") sys.path.append("..")
import paddle
import paddle.fluid.core as core
from op_test import OpTest
from op_test_xpu import XPUOpTest from op_test_xpu import XPUOpTest
import paddle.fluid as fluid from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
from paddle.fluid import Program, program_guard
import time
paddle.enable_static() paddle.enable_static()
"""
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
'core is not compiled with XPU')
class TestOneHotOp(XPUOpTest):
def setUp(self):
self.use_xpu = True
self.op_type = 'one_hot'
depth = 10
depth_np = np.array(10).astype('int32')
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]),
depth)).astype('float32')
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np} class XPUTestOneHotOP(XPUOpTestWrapper):
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self): def __init__(self):
place = paddle.XPUPlace(0) self.op_name = 'one_hot'
self.check_output_with_place(place, check_dygraph=False) self.use_dynamic_create_class = False
class TestXPUOneHotOP(XPUOpTest):
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
'core is not compiled with XPU')
class TestOneHotOp_attr(XPUOpTest):
def setUp(self): def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.op_type = 'one_hot' self.op_type = 'one_hot'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]),
depth)).astype('float32')
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)} self.set_data()
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth} self.set_input()
self.outputs = {'Out': (out, x_lod)}
def set_data(self):
self.depth = 10
self.depth_np = np.array(10).astype('int32')
self.x_lod = [[4, 1, 3, 3]]
self.x = [
np.random.randint(0, self.depth - 1)
for i in range(sum(self.x_lod[0]))
]
self.x = np.array(self.x).astype(self.dtype).reshape(
[sum(self.x_lod[0]), 1])
self.out = np.zeros(shape=(np.product(self.x.shape[:-1]),
self.depth)).astype('float32')
for i in range(np.product(self.x.shape)):
self.out[i, self.x[i]] = 1.0
self.outputs = {'Out': (self.out, self.x_lod)}
def set_input(self):
self.inputs = {
'X': (self.x, self.x_lod),
'depth_tensor': self.depth_np
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
def test_check_output(self): def test_check_output(self):
place = paddle.XPUPlace(0) self.check_output(check_dygraph=False)
self.check_output_with_place(place, check_dygraph=False)
def init_dtype(self):
self.dtype = self.in_type
@unittest.skipIf(not paddle.is_compiled_with_xpu(), class TestXPUOneHotOP_attr(TestXPUOneHotOP):
'core is not compiled with XPU')
class TestOneHotOp_default_dtype(XPUOpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
depth_np = np.array(10).astype('int32')
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), def set_input(self):
depth)).astype('float32') self.inputs = {'X': (self.x, self.x_lod)}
self.attrs = {
'dtype': int(core.VarDesc.VarType.FP32),
'depth': self.depth
}
for i in range(np.product(x.shape)): class TestXPUOneHotOP_default_dtype(TestXPUOneHotOP):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np} def set_input(self):
self.inputs = {
'X': (self.x, self.x_lod),
'depth_tensor': self.depth_np
}
self.attrs = {} self.attrs = {}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(), class TestXPUOneHotOP_default_dtype_attr(TestXPUOneHotOP):
'core is not compiled with XPU')
class TestOneHotOp_default_dtype_attr(XPUOpTest):
def setUp(self):
self.op_type = 'one_hot'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), def set_input(self):
depth)).astype('float32') self.inputs = {'X': (self.x, self.x_lod)}
self.attrs = {'depth': self.depth}
for i in range(np.product(x.shape)): class TestXPUOneHotOP_out_of_range(TestXPUOneHotOP):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)} def set_data(self):
self.attrs = {'depth': depth} self.depth = 10
self.outputs = {'Out': (out, x_lod)} self.x_lod = [[4, 1, 3, 3]]
self.x = [
np.random.choice([-1, self.depth])
for i in range(sum(self.x_lod[0]))
]
self.x = np.array(self.x).astype(self.dtype).reshape(
[sum(self.x_lod[0]), 1])
def test_check_output(self): self.out = np.zeros(shape=(np.product(self.x.shape[:-1]),
place = paddle.XPUPlace(0) self.depth)).astype('float32')
self.check_output_with_place(place, check_dygraph=False)
self.outputs = {'Out': (self.out, self.x_lod)}
@unittest.skipIf(not paddle.is_compiled_with_xpu(), def set_input(self):
'core is not compiled with XPU') self.inputs = {'X': (self.x, self.x_lod)}
class TestOneHotOp_out_of_range(XPUOpTest): self.attrs = {'depth': self.depth, 'allow_out_of_range': True}
def setUp(self):
self.op_type = 'one_hot'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]),
depth)).astype('float32')
self.inputs = {'X': (x, x_lod)} support_types = get_xpu_op_support_types('one_hot')
self.attrs = {'depth': depth, 'allow_out_of_range': True} print("support_types: %s" % str(support_types))
self.outputs = {'Out': (out, x_lod)} for stype in support_types:
create_test_class(globals(), XPUTestOneHotOP, stype)
def test_check_output(self): if __name__ == "__main__":
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
'core is not compiled with XPU')
class TestOneHotOpError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
# the input must be Variable
in_w = np.random.random((4, 1)).astype('int32')
self.assertRaises(TypeError, fluid.layers.one_hot, in_w)
# the input must be int32 or int 64
in_w2 = fluid.layers.data(
name='in_w2',
shape=[4, 1],
append_batch_size=False,
dtype='float32')
self.assertRaises(TypeError, fluid.layers.one_hot, in_w2)
# the depth must be int, long or Variable
in_r = fluid.layers.data(
name='in_r',
shape=[4, 1],
append_batch_size=False,
dtype='int32')
depth_w = np.array([4])
self.assertRaises(TypeError, fluid.layers.one_hot, in_r, 4.1)
self.assertRaises(TypeError, fluid.layers.one_hot, in_r, depth_w)
"""
if __name__ == '__main__':
paddle.enable_static()
unittest.main() unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
import paddle
class TestSamplingIdShape(unittest.TestCase):
def test_shape(self):
paddle.enable_static()
x = fluid.layers.data(name='x', shape=[3], dtype='float32')
output = fluid.layers.sampling_id(x)
place = fluid.XPUPlace(0)
exe = fluid.Executor(place=place)
exe.run(fluid.default_startup_program())
feed = {
'x': np.array([[0.2, 0.3, 0.5], [0.2, 0.3, 0.4]], dtype='float32')
}
output_np = exe.run(feed=feed, fetch_list=[output])[0]
self.assertEqual(output.shape[0], -1)
self.assertEqual(len(output.shape), 1)
self.assertEqual(output_np.shape[0], 2)
self.assertEqual(len(output_np.shape), 1)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册