From 3055d71abecbfdce5d9486b26c05e07a1ade480a Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Fri, 26 Aug 2022 13:45:47 +0800 Subject: [PATCH] [XPU] add load_combine_op_xpu. test=kunlun (#45436) --- paddle/fluid/operators/load_combine_op_xpu.cc | 25 +++++++++++++++++++ .../fluid/platform/device/xpu/xpu2_op_list.h | 6 +++++ 2 files changed, 31 insertions(+) create mode 100644 paddle/fluid/operators/load_combine_op_xpu.cc diff --git a/paddle/fluid/operators/load_combine_op_xpu.cc b/paddle/fluid/operators/load_combine_op_xpu.cc new file mode 100644 index 00000000000..9fa7ba3f752 --- /dev/null +++ b/paddle/fluid/operators/load_combine_op_xpu.cc @@ -0,0 +1,25 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/load_combine_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + load_combine, + ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel, + ops::LoadCombineOpKernel); diff --git a/paddle/fluid/platform/device/xpu/xpu2_op_list.h b/paddle/fluid/platform/device/xpu/xpu2_op_list.h index 84e075f37c5..7eb90515785 100644 --- a/paddle/fluid/platform/device/xpu/xpu2_op_list.h +++ b/paddle/fluid/platform/device/xpu/xpu2_op_list.h @@ -326,6 +326,12 @@ XPUOpMap& get_kl2_ops() { XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace())})}, + {"load_combine", + XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), + pOpKernelType(vartype::FP64, XPUPlace()), + pOpKernelType(vartype::INT8, XPUPlace()), + pOpKernelType(vartype::INT32, XPUPlace()), + pOpKernelType(vartype::INT64, XPUPlace())})}, {"log", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"log_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"log_softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, -- GitLab