From ba9921369193c177b19053615d1dca5084b95f64 Mon Sep 17 00:00:00 2001 From: leolishaohao <138780481+leolishaohao@users.noreply.github.com> Date: Tue, 8 Aug 2023 19:45:29 +0800 Subject: [PATCH] [XPU] register multiclass_nms3 and norm xpu kernel to optimize model (#56064) --- paddle/phi/backends/xpu/xpu2_op_list.cc | 2 + .../phi/kernels/xpu/multiclass_nms3_kernel.cc | 206 +++++++++++ paddle/phi/kernels/xpu/norm_kernel.cc | 74 ++++ test/xpu/test_multiclass_nms3_op_xpu.py | 348 ++++++++++++++++++ test/xpu/test_norm_op_xpu.py | 96 +++++ 5 files changed, 726 insertions(+) create mode 100644 paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc create mode 100644 paddle/phi/kernels/xpu/norm_kernel.cc create mode 100644 test/xpu/test_multiclass_nms3_op_xpu.py create mode 100644 test/xpu/test_norm_op_xpu.py diff --git a/paddle/phi/backends/xpu/xpu2_op_list.cc b/paddle/phi/backends/xpu/xpu2_op_list.cc index 30b1f103ad2..e55c5549f90 100644 --- a/paddle/phi/backends/xpu/xpu2_op_list.cc +++ b/paddle/phi/backends/xpu/xpu2_op_list.cc @@ -542,6 +542,7 @@ XPUOpMap& get_kl2_ops() { phi::DataType::INT64})}, {"multi_encoder_xpu", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, + {"multiclass_nms3", XPUKernelSet({phi::DataType::FLOAT32})}, {"nearest_interp_v2", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16, @@ -549,6 +550,7 @@ XPUOpMap& get_kl2_ops() { {"nearest_interp_v2_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"nll_loss", XPUKernelSet({phi::DataType::FLOAT32})}, {"nll_loss_grad", XPUKernelSet({phi::DataType::FLOAT32})}, + {"norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"not_equal", XPUKernelSet({phi::DataType::INT64, phi::DataType::INT32, diff --git a/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc b/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc new file mode 100644 index 00000000000..d33f2e793ed --- /dev/null +++ b/paddle/phi/kernels/xpu/multiclass_nms3_kernel.cc @@ -0,0 +1,206 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/multiclass_nms3_kernel.h" + +#include + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/common/memory_utils.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void MultiClassNMSKernel(const Context& ctx, + const DenseTensor& bboxes, + const DenseTensor& scores, + const paddle::optional& rois_num, + float score_threshold, + int nums_top_k, + int keep_top_k, + float nms_threshold, + bool normalized, + float nms_eta, + int background_label, + DenseTensor* out, + DenseTensor* index, + DenseTensor* nms_rois_num) { + using XPUT = typename XPUTypeTrait::Type; + + const XPUT* bboxes_data = reinterpret_cast(bboxes.data()); + const XPUT* scores_data = reinterpret_cast(scores.data()); + + bool return_index = index != nullptr; + bool has_rois_num = rois_num.get_ptr() != nullptr; + bool return_rois_num = nms_rois_num != nullptr; + auto score_dims = phi::vectorize(scores.dims()); + auto score_size = score_dims.size(); + bool is_lod = score_size == 2 ? true : false; + + int n = 0; + int b = 0; + int class_num = scores.dims()[1]; + int out_dim = bboxes.dims()[2] + 2; + int boxes_count = 0; + std::vector rois_num_vec; + rois_num_vec.clear(); + if (is_lod) { + if (has_rois_num) { + n = rois_num.get_ptr()->numel(); + for (int i = 0; i < n; i++) { + rois_num_vec.push_back(rois_num.get_ptr()->data()[i]); + boxes_count += rois_num.get_ptr()->data()[i]; + } + } else { + auto lod = bboxes.lod().back(); + boxes_count = lod[lod.size() - 1]; + n = lod.size() - 1; + for (int i = 0; i < n; i++) { + rois_num_vec.push_back(lod[i + 1] - lod[i]); + } + } + PADDLE_ENFORCE_EQ(boxes_count == bboxes.dims()[0], + true, + phi::errors::InvalidArgument( + "boxes_count should equal boxes->dims()[0].", + "But received: (%d) and (%d)", + boxes_count, + bboxes.dims()[0])); + PADDLE_ENFORCE_EQ( + boxes_count == score_dims[0], + true, + phi::errors::InvalidArgument("boxes_count shuold equal score_dims[0].", + "But received: (%d) and (%d)", + boxes_count, + score_dims[0])); + } else { + n = bboxes.dims()[0]; + b = bboxes.dims()[1]; + boxes_count = n * b; + } + std::vector outs_vec_; + std::vector out_index_vec_; + + outs_vec_.resize(boxes_count * out_dim); + out_index_vec_.resize(boxes_count); + + std::vector batch_starts; + int r = 0; + r = xpu::multiclass_nms(ctx.x_context(), + bboxes_data, + scores_data, + rois_num_vec, + outs_vec_, + out_index_vec_, + batch_starts, + n, + b, + class_num, + out_dim, + nums_top_k, + score_threshold, + keep_top_k, + nms_threshold, + background_label, + normalized, + nms_eta, + return_index, + is_lod); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "multiclass_nms"); + uint64_t num_kept = batch_starts.back(); + + if (num_kept == 0) { + if (return_index) { + // out_dim may be zero when there is no object in picture, so add some + // zeros to it + // caution: results may differ between cpu and xpu due to this operation + out->Resize({1, out_dim}); + ctx.template Alloc(out); + T* out_ptr = out->template data(); + std::vector temp_value(out_dim, 0.0f); + memory_utils::Copy(ctx.GetPlace(), + out_ptr, + phi::CPUPlace(), + temp_value.data(), + 1 * out_dim * sizeof(T)); + + index->Resize({1, 1}); + ctx.template Alloc(index); + int* out_index_ptr = index->template data(); + std::vector temp_idx(1, 0); + memory_utils::Copy(ctx.GetPlace(), + out_index_ptr, + phi::CPUPlace(), + temp_idx.data(), + 1 * sizeof(int)); + } else { + out->Resize({1, 1}); + T* od = ctx.template Alloc(out); + od[0] = -1; 
+      batch_starts = {0, 1};
+    }
+  } else {
+    out->Resize({static_cast<int64_t>(num_kept), out_dim});
+    ctx.template Alloc<T>(out);
+    T* out_ptr = out->template data<T>();
+    memory_utils::Copy(ctx.GetPlace(),
+                       out_ptr,
+                       phi::CPUPlace(),
+                       outs_vec_.data(),
+                       num_kept * out_dim * sizeof(T));
+    if (return_index) {
+      index->Resize({static_cast<int64_t>(num_kept), 1});
+      ctx.template Alloc<int>(index);
+      int* out_index_ptr = index->template data<int>();
+      memory_utils::Copy(ctx.GetPlace(),
+                         out_index_ptr,
+                         phi::CPUPlace(),
+                         out_index_vec_.data(),
+                         num_kept * sizeof(int));
+    }
+  }
+
+  if (return_rois_num) {
+    nms_rois_num->Resize({n});
+    ctx.template Alloc<int>(nms_rois_num);
+
+    DenseTensor nms_rois_num_cpu;
+    nms_rois_num_cpu.Resize({nms_rois_num->numel()});
+    ctx.template HostAlloc<int>(&nms_rois_num_cpu);
+    int* nms_rois_num_cpu_data = nms_rois_num_cpu.data<int>();
+
+    for (int i = 1; i <= n; i++) {
+      nms_rois_num_cpu_data[i - 1] = batch_starts[i] - batch_starts[i - 1];
+    }
+    phi::Copy(ctx, nms_rois_num_cpu, nms_rois_num->place(), true, nms_rois_num);
+  }
+  LoD lod;
+  if (num_kept == 0) {
+    batch_starts[batch_starts.size() - 1] = 1;
+  }
+  lod.emplace_back(batch_starts);
+  if (return_index) {
+    index->set_lod(lod);
+  }
+  out->set_lod(lod);
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    multiclass_nms3, XPU, ALL_LAYOUT, phi::MultiClassNMSKernel, float) {
+  kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
+  kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
+}
diff --git a/paddle/phi/kernels/xpu/norm_kernel.cc b/paddle/phi/kernels/xpu/norm_kernel.cc
new file mode 100644
index 00000000000..a78fa0c4630
--- /dev/null
+++ b/paddle/phi/kernels/xpu/norm_kernel.cc
@@ -0,0 +1,74 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/norm_kernel.h"
+
+#include <vector>
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void NormKernel(const Context& ctx,
+                const DenseTensor& x,
+                int axis,
+                float epsilon,
+                bool is_test,
+                DenseTensor* out,
+                DenseTensor* norm) {
+  ctx.template Alloc<T>(out);
+  ctx.template Alloc<T>(norm);
+
+  std::vector<int> xshape;
+  auto x_dims = x.dims();
+  auto x_dims_size = x_dims.size();
+  xshape.resize(x_dims_size);
+
+  if (axis < 0) {
+    axis += x_dims_size;
+  }
+
+  PADDLE_ENFORCE_GE(
+      axis,
+      0,
+      phi::errors::InvalidArgument("axis must be greater than or equal to 0. "
+ "But received axis: %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + x_dims_size, + phi::errors::InvalidArgument( + "Attr(axis) value must be less than rank of Input(X)" + "But received axis: %d, rank: %d.", + axis, + x_dims_size)); + + for (int i = 0; i < x_dims_size; i++) { + xshape[i] = static_cast(x_dims[i]); + } + + int r = xpu::l2_norm(ctx.x_context(), + x.data(), + out->data(), + norm->data(), + xshape, + axis, + epsilon); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "l2_norm"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(norm, XPU, ALL_LAYOUT, phi::NormKernel, float) {} diff --git a/test/xpu/test_multiclass_nms3_op_xpu.py b/test/xpu/test_multiclass_nms3_op_xpu.py new file mode 100644 index 00000000000..2ffeaf5b9bc --- /dev/null +++ b/test/xpu/test_multiclass_nms3_op_xpu.py @@ -0,0 +1,348 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import unittest + +import numpy as np +from get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) +from op_test_xpu import XPUOpTest + +import paddle + +paddle.enable_static() + + +def softmax(x): + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.0) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + +def iou(box_a, box_b, norm): + """Apply intersection-over-union overlap between box_a and box_b.""" + xmin_a = min(box_a[0], box_a[2]) + ymin_a = min(box_a[1], box_a[3]) + xmax_a = max(box_a[0], box_a[2]) + ymax_a = max(box_a[1], box_a[3]) + + xmin_b = min(box_b[0], box_b[2]) + ymin_b = min(box_b[1], box_b[3]) + xmax_b = max(box_b[0], box_b[2]) + ymax_b = max(box_b[1], box_b[3]) + + area_a = (ymax_a - ymin_a + (not norm)) * (xmax_a - xmin_a + (not norm)) + area_b = (ymax_b - ymin_b + (not norm)) * (xmax_b - xmin_b + (not norm)) + if area_a <= 0 and area_b <= 0: + return 0.0 + + xa = max(xmin_a, xmin_b) + ya = max(ymin_a, ymin_b) + xb = min(xmax_a, xmax_b) + yb = min(ymax_a, ymax_b) + + inter_area = max(xb - xa + (not norm), 0.0) * max(yb - ya + (not norm), 0.0) + + iou_ratio = inter_area / (area_a + area_b - inter_area) + + return iou_ratio + + +def nms( + boxes, + scores, + score_threshold, + nms_threshold, + top_k=200, + normalized=True, + eta=1.0, +): + """Apply non-maximum suppression at test time to avoid detecting too many + overlapping bounding boxes for a given object. + Args: + boxes: (tensor) The location preds for the img, Shape: [num_priors,4]. + scores: (tensor) The class predscores for the img, Shape:[num_priors]. + score_threshold: (float) The confidence thresh for filtering low + confidence boxes. + nms_threshold: (float) The overlap thresh for suppressing unnecessary + boxes. + top_k: (int) The maximum number of box preds to consider. + eta: (float) The parameter for adaptive NMS. + Return: + The indices of the kept boxes with respect to num_priors. 
+ """ + all_scores = copy.deepcopy(scores) + all_scores = all_scores.flatten() + selected_indices = np.argwhere(all_scores > score_threshold) + selected_indices = selected_indices.flatten() + all_scores = all_scores[selected_indices] + + sorted_indices = np.argsort(-all_scores, axis=0, kind='mergesort') + sorted_scores = all_scores[sorted_indices] + sorted_indices = selected_indices[sorted_indices] + if top_k > -1 and top_k < sorted_indices.shape[0]: + sorted_indices = sorted_indices[:top_k] + sorted_scores = sorted_scores[:top_k] + + selected_indices = [] + adaptive_threshold = nms_threshold + for i in range(sorted_scores.shape[0]): + idx = sorted_indices[i] + keep = True + for k in range(len(selected_indices)): + if keep: + kept_idx = selected_indices[k] + overlap = iou(boxes[idx], boxes[kept_idx], normalized) + keep = True if overlap <= adaptive_threshold else False + else: + break + if keep: + selected_indices.append(idx) + if keep and eta < 1 and adaptive_threshold > 0.5: + adaptive_threshold *= eta + return selected_indices + + +def multiclass_nms( + boxes, + scores, + background, + score_threshold, + nms_threshold, + nms_top_k, + keep_top_k, + normalized, + shared, +): + if shared: + class_num = scores.shape[0] + priorbox_num = scores.shape[1] + else: + box_num = scores.shape[0] + class_num = scores.shape[1] + + selected_indices = {} + num_det = 0 + for c in range(class_num): + if c == background: + continue + if shared: + indices = nms( + boxes, + scores[c], + score_threshold, + nms_threshold, + nms_top_k, + normalized, + ) + else: + indices = nms( + boxes[:, c, :], + scores[:, c], + score_threshold, + nms_threshold, + nms_top_k, + normalized, + ) + selected_indices[c] = indices + num_det += len(indices) + + if keep_top_k > -1 and num_det > keep_top_k: + score_index = [] + for c, indices in selected_indices.items(): + for idx in indices: + if shared: + score_index.append((scores[c][idx], c, idx)) + else: + score_index.append((scores[idx][c], c, idx)) + + sorted_score_index = sorted( + score_index, key=lambda tup: tup[0], reverse=True + ) + sorted_score_index = sorted_score_index[:keep_top_k] + selected_indices = {} + + for _, c, _ in sorted_score_index: + selected_indices[c] = [] + for s, c, idx in sorted_score_index: + selected_indices[c].append(idx) + if not shared: + for labels in selected_indices: + selected_indices[labels].sort() + num_det = keep_top_k + + return selected_indices, num_det + + +def batched_multiclass_nms( + boxes, + scores, + background, + score_threshold, + nms_threshold, + nms_top_k, + keep_top_k, + normalized=True, + gpu_logic=False, +): + batch_size = scores.shape[0] + num_boxes = scores.shape[2] + det_outs = [] + index_outs = [] + lod = [] + for n in range(batch_size): + nmsed_outs, nmsed_num = multiclass_nms( + boxes[n], + scores[n], + background, + score_threshold, + nms_threshold, + nms_top_k, + keep_top_k, + normalized, + shared=True, + ) + lod.append(nmsed_num) + + if nmsed_num == 0: + continue + tmp_det_out = [] + for c, indices in nmsed_outs.items(): + for idx in indices: + xmin, ymin, xmax, ymax = boxes[n][idx][:] + tmp_det_out.append( + [ + c, + scores[n][c][idx], + xmin, + ymin, + xmax, + ymax, + idx + n * num_boxes, + ] + ) + if gpu_logic: + sorted_det_out = sorted( + tmp_det_out, key=lambda tup: tup[1], reverse=True + ) + else: + sorted_det_out = sorted( + tmp_det_out, key=lambda tup: tup[0], reverse=False + ) + det_outs.extend(sorted_det_out) + return det_outs, lod + + +class TestIOU(unittest.TestCase): + def test_iou(self): + box1 = 
np.array([4.0, 3.0, 7.0, 5.0]).astype('float32') + box2 = np.array([3.0, 4.0, 6.0, 8.0]).astype('float32') + + expt_output = np.array([2.0 / 16.0]).astype('float32') + calc_output = np.array([iou(box1, box2, True)]).astype('float32') + np.testing.assert_allclose(calc_output, expt_output, rtol=1e-05) + + +class XPUTestMulticlassNMS3Op(XPUOpTestWrapper): + def __init__(self): + self.op_name = 'multiclass_nms3' + self.use_dynamic_create_class = False + + class TestXpuMulticlassNMS3Op(XPUOpTest): + def set_argument(self): + self.score_threshold = 0.01 + + def setUp(self): + self.op_type = "multiclass_nms3" + self.dtype = self.in_type + + self.set_argument() + N = 7 + M = 1200 + C = 21 + BOX_SIZE = 4 + background = 0 + nms_threshold = 0.3 + nms_top_k = 400 + keep_top_k = ( + 200 if not hasattr(self, 'keep_top_k') else self.keep_top_k + ) + score_threshold = self.score_threshold + + scores = np.random.random((N * M, C)).astype(self.dtype) + + scores = np.apply_along_axis(softmax, 1, scores) + scores = np.reshape(scores, (N, M, C)) + scores = np.transpose(scores, (0, 2, 1)) + + boxes = np.random.random((N, M, BOX_SIZE)).astype(self.dtype) + boxes[:, :, 0:2] = boxes[:, :, 0:2] * 0.5 + boxes[:, :, 2:4] = boxes[:, :, 2:4] * 0.5 + 0.5 + + det_outs, lod = batched_multiclass_nms( + boxes, + scores, + background, + score_threshold, + nms_threshold, + nms_top_k, + keep_top_k, + gpu_logic=self.gpu_logic + if hasattr(self, 'gpu_logic') + else None, + ) + det_outs = np.array(det_outs) + nmsed_outs = ( + det_outs[:, :-1].astype(self.dtype) + if len(det_outs) + else np.array([], dtype=np.float32).reshape([0, BOX_SIZE + 2]) + ) + index_outs = ( + det_outs[:, -1:].astype('int') + if len(det_outs) + else np.array([], dtype='int').reshape([0, 1]) + ) + + self.inputs = {'BBoxes': boxes, 'Scores': scores} + self.outputs = { + 'Out': nmsed_outs, + 'Index': index_outs, + 'NmsRoisNum': np.array(lod).astype('int32'), + } + self.attrs = { + 'background_label': 0, + 'nms_threshold': nms_threshold, + 'nms_top_k': nms_top_k, + 'keep_top_k': keep_top_k, + 'score_threshold': score_threshold, + 'nms_eta': 1.0, + 'normalized': True, + } + + def test_check_output(self): + self.check_output_with_place(paddle.XPUPlace(0)) + + +support_types = get_xpu_op_support_types('multiclass_nms3') +for stype in support_types: + create_test_class(globals(), XPUTestMulticlassNMS3Op, stype) + +if __name__ == "__main__": + unittest.main() diff --git a/test/xpu/test_norm_op_xpu.py b/test/xpu/test_norm_op_xpu.py new file mode 100644 index 00000000000..dc225998d9a --- /dev/null +++ b/test/xpu/test_norm_op_xpu.py @@ -0,0 +1,96 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from get_test_cover_info import ( + XPUOpTestWrapper, + create_test_class, + get_xpu_op_support_types, +) +from op_test_xpu import XPUOpTest + +import paddle + +paddle.enable_static() + + +def l2_norm(x, axis, epsilon): + x2 = x**2 + s = np.sum(x2, axis=axis, keepdims=True) + r = np.sqrt(s + epsilon) + y = x / np.broadcast_to(r, x.shape) + return y, r + + +class XPUTestNormOp(XPUOpTestWrapper): + def __init__(self): + self.op_name = "norm" + self.use_dynamic_create_class = False + + class TestXPUNormOp(XPUOpTest): + def setUp(self): + self.op_type = "norm" + self.dtype = self.in_type + self.place = paddle.XPUPlace(0) + self.init_test_case() + x = np.random.random(self.shape).astype(self.dtype) + y, norm = l2_norm(x, self.axis, self.epsilon) + self.inputs = {'X': x} + self.attrs = {'epsilon': self.epsilon, 'axis': self.axis} + self.outputs = {'Out': y, 'Norm': norm} + + def init_test_case(self): + self.shape = [2, 3, 4, 5] + self.axis = 1 + self.epsilon = 1e-8 + + def test_check_output(self): + self.check_output_with_place(self.place) + + def test_check_grad(self): + self.check_grad_with_place(self.place, ['X'], 'Out') + + class TestXPUNormOp2(TestXPUNormOp): + def init_test_case(self): + self.shape = [5, 3, 9, 7] + self.axis = 0 + self.epsilon = 1e-8 + + class TestXPUNormOp3(TestXPUNormOp): + def init_test_case(self): + self.shape = [5, 3, 2, 7] + self.axis = -1 + self.epsilon = 1e-8 + + class TestXPUNormOp4(TestXPUNormOp): + def init_test_case(self): + self.shape = [128, 1024, 14, 14] + self.axis = 2 + self.epsilon = 1e-8 + + class TestXPUNormOp5(TestXPUNormOp): + def init_test_case(self): + self.shape = [2048, 2048] + self.axis = 1 + self.epsilon = 1e-8 + + +support_types = get_xpu_op_support_types('norm') +for stype in support_types: + create_test_class(globals(), XPUTestNormOp, stype) + +if __name__ == "__main__": + unittest.main() -- GitLab
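
For reviewers: below is a minimal, standalone NumPy sketch of the reference math the two new XPU kernels are checked against, distilled from the l2_norm helper in test_norm_op_xpu.py and the iou helper / TestIOU case in test_multiclass_nms3_op_xpu.py. It needs no Paddle build or XPU device; the helper names are illustrative only, not Paddle APIs.

import numpy as np

def l2_norm_ref(x, axis, epsilon=1e-8):
    # norm op semantics: Norm = sqrt(sum(x^2, axis) + epsilon), Out = x / Norm
    r = np.sqrt(np.sum(x * x, axis=axis, keepdims=True) + epsilon)
    return x / r, r

def iou_ref(a, b):
    # normalized IoU of two [xmin, ymin, xmax, ymax] boxes
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

x = np.random.random((2, 3, 4, 5)).astype('float32')
y, norm = l2_norm_ref(x, axis=1)
# slices along `axis` have unit L2 norm, up to epsilon
assert np.allclose(np.sum(y * y, axis=1), 1.0, atol=1e-4)

# matches the 2/16 expectation in TestIOU
assert abs(iou_ref([4.0, 3.0, 7.0, 5.0], [3.0, 4.0, 6.0, 8.0]) - 0.125) < 1e-6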