Unverified · Commit 4f0c3278 authored by ronnywang, committed by GitHub

[NPU] add randperm_op_npu (#35763)

* add randperm_op_npu

* fix test_set_value_op_npu
Parent 0c6ee945
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/randperm_op.h"
#include "paddle/fluid/framework/op_registry.h"
template <typename T>
using kernel =
    paddle::operators::RandpermKernel<paddle::platform::NPUDeviceContext, T>;

REGISTER_OP_NPU_KERNEL(randperm, kernel<int64_t>, kernel<int>, kernel<float>,
                       kernel<double>);
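With these kernels registered, paddle.randperm can dispatch to the NPU for the four dtypes above. A minimal usage sketch, assuming a PaddlePaddle build with Ascend CL support and an available NPU device 0 (illustrative only, not part of this commit):

import paddle

# Assumption: this build was compiled with Ascend (NPU) support.
paddle.disable_static(paddle.NPUPlace(0))

# paddle.randperm returns a random permutation of [0, n); dtype may be
# any of the registered kernel types: int32, int64, float32, float64.
perm = paddle.randperm(10, dtype='int64')
print(perm.numpy())  # e.g. [3 7 0 9 1 5 2 8 6 4]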
@@ -114,6 +114,11 @@ inline T GetValue(const framework::Tensor* x) {
   if (!platform::is_cpu_place(x->place())) {
     framework::Tensor cpu_x;
     framework::TensorCopy(*x, platform::CPUPlace(), &cpu_x);
+#ifdef PADDLE_WITH_ASCEND_CL
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    const platform::DeviceContext* dev_ctx = pool.Get(x->place());
+    dev_ctx->Wait();
+#endif
     value = cpu_x.data<T>()[0];
   } else {
     value = x->data<T>()[0];
...
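A note on the hunk above: on Ascend devices framework::TensorCopy runs asynchronously, so without the added dev_ctx->Wait() the cpu_x.data<T>()[0] read below the guarded block could observe the host buffer before the device-to-host copy has completed.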
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid.core as core
from paddle.static import program_guard, Program
from test_randperm_op import check_randperm_out, error_msg, convert_dtype
paddle.enable_static()


class TestRandpermOp(OpTest):
    """ Test randperm op."""

    def setUp(self):
        self.set_npu()
        self.op_type = "randperm"
        self.n = 200
        self.dtype = "int64"
        self.inputs = {}
        self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)}
        self.init_attrs()
        self.attrs = {
            "n": self.n,
            "dtype": convert_dtype(self.dtype),
        }

    def set_npu(self):
        self.__class__.use_npu = True

    def _get_places(self):
        return [paddle.NPUPlace(0)]

    def init_attrs(self):
        pass

    def test_check_output(self):
        self.check_output_customized(self.verify_output)

    def verify_output(self, outs):
        out_np = np.array(outs[0])
        self.assertTrue(
            check_randperm_out(self.n, out_np), msg=error_msg(out_np))


class TestRandpermOpN(TestRandpermOp):
    def init_attrs(self):
        self.n = 10000


class TestRandpermOpInt32(TestRandpermOp):
    def init_attrs(self):
        self.dtype = "int32"


class TestRandpermOpFloat32(TestRandpermOp):
    def init_attrs(self):
        self.dtype = "float32"


class TestRandpermOpFloat64(TestRandpermOp):
    def init_attrs(self):
        self.dtype = "float64"


class TestRandpermOpError(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
            self.assertRaises(ValueError, paddle.randperm, -3)
            self.assertRaises(TypeError, paddle.randperm, 10, 'int8')


class TestRandpermAPI(unittest.TestCase):
    def test_out(self):
        n = 10
        place = paddle.NPUPlace(0)
        with program_guard(Program(), Program()):
            x1 = paddle.randperm(n)
            x2 = paddle.randperm(n, 'float32')
            exe = paddle.static.Executor(place)
            res = exe.run(fetch_list=[x1, x2])
            self.assertEqual(res[0].dtype, np.int64)
            self.assertEqual(res[1].dtype, np.float32)
            self.assertTrue(check_randperm_out(n, res[0]))
            self.assertTrue(check_randperm_out(n, res[1]))


class TestRandpermImperative(unittest.TestCase):
    def test_out(self):
        paddle.disable_static(paddle.NPUPlace(0))
        n = 10
        for dtype in ['int32', np.int64, 'float32', 'float64']:
            data_p = paddle.randperm(n, dtype)
            data_np = data_p.numpy()
            self.assertTrue(
                check_randperm_out(n, data_np), msg=error_msg(data_np))
        paddle.enable_static()


if __name__ == "__main__":
    unittest.main()
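The tests above accept any ordering, as long as the output is a permutation of [0, n). A rough sketch of what the imported check_randperm_out helper verifies, with a hypothetical stand-in name (the real helper lives in test_randperm_op.py):

import numpy as np

def check_randperm_out_sketch(n, out_np):
    # Hypothetical equivalent of check_randperm_out: a valid output
    # contains each value in [0, n) exactly once, in any order.
    return np.array_equal(np.sort(out_np), np.arange(n))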
@@ -68,7 +68,7 @@ class TestSetValueApi(TestSetValueBase):
         return out

     def _run_dynamic(self):
-        paddle.disable_static()
+        paddle.disable_static(paddle.NPUPlace(0))
         x = paddle.ones(shape=self.shape, dtype=self.dtype)
         self._call_setitem(x)
         out = x.numpy()
...
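This one-line fix passes paddle.NPUPlace(0) to paddle.disable_static so that _run_dynamic actually exercises the NPU; without a place argument, the dynamic-graph half of the set_value tests would run on the framework's default place instead.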