未验证 提交 4812d8e4 编写于 作者: H houj04 提交者: GitHub

[XPU] add numel op (#53041)

上级 abc44b40
......@@ -508,6 +508,13 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT32,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"numel",
XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32,
phi::DataType::INT16,
phi::DataType::BOOL,
phi::DataType::FLOAT16,
phi::DataType::FLOAT32})},
{"one_hot", XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
{"one_hot_v2",
XPUKernelSet({phi::DataType::INT32, phi::DataType::INT64})},
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/numel_kernel.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/numel_kernel_impl.h"
PD_REGISTER_KERNEL(numel,
XPU,
ALL_LAYOUT,
phi::NumelKernel,
int16_t,
int,
int64_t,
phi::dtype::float16,
float,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::INT64);
}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
class XPUTestNumelOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'size'
self.use_dynamic_create_class = False
class TestXPUNumelOp(XPUOpTest):
def setUp(self):
self.place = paddle.XPUPlace(0)
self.init_dtype()
self.op_type = 'size'
self.initTestCase()
x = np.random.random(self.shape).astype(self.dtype)
self.inputs = {
'Input': x,
}
self.outputs = {'Out': np.array([np.size(x)])}
def initTestCase(self):
self.shape = (6, 56, 8, 55)
def init_dtype(self):
self.dtype = self.in_type
def test_check_output(self):
self.check_output_with_place(self.place)
class TestNumel1(TestXPUNumelOp):
def initTestCase(self):
self.shape = (11, 66)
class TestNumel2(TestXPUNumelOp):
def initTestCase(self):
self.shape = (0,)
class TestNumel3(TestXPUNumelOp):
def initTestCase(self):
self.shape = (2, 3, 4, 5, 6)
class TestNumel4(TestXPUNumelOp):
def initTestCase(self):
self.shape = (12, 24)
class TestNumel5(TestXPUNumelOp):
def initTestCase(self):
self.shape = (1, 64, 16)
support_types = get_xpu_op_support_types('numel')
for stype in support_types:
create_test_class(globals(), XPUTestNumelOP, stype)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册