From 0ff72e5d83a4d1861be9a72b9d5f84d098878ad6 Mon Sep 17 00:00:00 2001
From: Liu-xiandong <85323580+Liu-xiandong@users.noreply.github.com>
Date: Mon, 28 Feb 2022 10:05:33 +0800
Subject: [PATCH] [KP] Unify .cu and .xpu files with .kps files (#39917)

* [KP] Unify .cu and .xpu files with .kps files

* fix CI bug in GPU and modify the list

* fix conflict

* modify the date
---
 cmake/operators.cmake                         |  12 ++
 .../elementwise/elementwise_add_op.cu         |  29 ---
 .../elementwise/elementwise_add_op.h          |  23 ++-
 .../elementwise/elementwise_add_op.kps        | 171 +++---------------
 .../platform/device/xpu/xpu_op_kpfirst_list.h |   5 +-
 5 files changed, 59 insertions(+), 181 deletions(-)
 delete mode 100644 paddle/fluid/operators/elementwise/elementwise_add_op.cu

diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index 8843dd2628..7affd59de1 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -73,6 +73,12 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
       list(APPEND cu_srcs ${TARGET}.cu)
     endif()
+    # rename in KP: .kps -> .cu
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps)
+      file(COPY ${TARGET}.kps DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+      file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.kps ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+      list(APPEND cu_srcs ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+    endif()
     if (WITH_NV_JETSON)
       list(REMOVE_ITEM cu_srcs "decode_jpeg_op.cu")
     endif()
@@ -96,6 +102,12 @@ function(op_library TARGET)
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.cu)
       list(APPEND hip_srcs ${TARGET}.cu)
     endif()
+    # rename in KP: .kps -> .cu
+    if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.kps)
+      file(COPY ${TARGET}.kps DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+      file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.kps ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+      list(APPEND hip_srcs ${CMAKE_CURRENT_BINARY_DIR}/${TARGET}.cu)
+    endif()
     if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu)
       set(PART_CUDA_KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/${TARGET}.part.cu
           ${PART_CUDA_KERNEL_FILES} PARENT_SCOPE)

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.cu b/paddle/fluid/operators/elementwise/elementwise_add_op.cu
deleted file mode 100644
index 52bf9b0e03..0000000000
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.cu
+++ /dev/null
@@ -1,29 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/phi/kernels/gpu/elementwise.h"
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_CUDA_KERNEL(
-    grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::bfloat16>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
-    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.h b/paddle/fluid/operators/elementwise/elementwise_add_op.h
index 1a256f7567..ae2e5b33b5 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.h
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.h
@@ -13,7 +13,14 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
-
+#ifdef __xpu__
+#include <memory>
+#include <utility>
+#include "paddle/fluid/operators/elementwise/elementwise_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
+#include "paddle/fluid/platform/device/device_wrapper.h"
+#else
 #include <memory>
 #include <utility>
 #include "paddle/fluid/operators/elementwise/elementwise_op.h"
@@ -21,6 +28,7 @@ limitations under the License. */
 // only can include the headers in paddle/phi/include dirs
 #include "paddle/phi/kernels/elementwise_grad_kernel.h"
 #include "paddle/phi/kernels/math_kernel.h"
+#endif
 
 namespace paddle {
 namespace operators {
@@ -28,7 +36,17 @@ namespace operators {
 template <typename DeviceContext, typename T>
 class ElementwiseAddKernel : public framework::OpKernel<T> {
  public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
+  void Compute(const framework::ExecutionContext& ctx) const override {
+#ifdef __xpu__
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+    const auto& xpu_ctx =
+        ctx.template device_context<paddle::platform::XPUDeviceContext>();
+    paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
+                                                   T, kps::AddFunctor<T>, 1>(
+        xpu_ctx, ins, &outs, axis, kps::AddFunctor<T>());
+#else
     auto *x = ctx.Input<framework::LoDTensor>("X");
     auto *y = ctx.Input<framework::LoDTensor>("Y");
     auto *z = ctx.Output<framework::LoDTensor>("Out");
@@ -40,6 +58,7 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
         static_cast<typename framework::ConvertToPhiContext<
             DeviceContext>::TYPE &>(dev_ctx),
         *x, *y, axis, z);
+#endif
   }
 };
 

diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op.kps b/paddle/fluid/operators/elementwise/elementwise_add_op.kps
index a3fea0d7b3..d6e0749318 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op.kps
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op.kps
@@ -1,14 +1,19 @@
 /* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at
+
     http://www.apache.org/licenses/LICENSE-2.0
+
 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
+#ifdef PADDLE_WITH_XPU_KP
+
 // Please do not modify the following code
 #if defined(__CUDA_ARCH__)
 #undef __CUDA_ARCH__
 #endif
@@ -26,163 +31,31 @@ limitations under the License. */
 #undef __NVCC__
 #endif
 
-#ifdef PADDLE_WITH_XPU_KP
 #include <xpu/runtime.h>        // NOLINT
 #include "xpu/kernel/cluster_header.h"  // NOLINT
 #include "xpu/kernel/debug.h"   // NOLINT
 #include "xpu/kernel/math.h"    // NOLINT
 
-#include <memory>
-#include <utility>
 #include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/elementwise/elementwise_xpu.h"
-#include "paddle/fluid/platform/device/device_wrapper.h"
-
-namespace paddle {
-namespace operators {
-
-template <typename T>
-class ElementwiseAddXPUKPKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    std::vector<const framework::Tensor*> ins;
-    std::vector<framework::Tensor*> outs;
-    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
-    const auto& xpu_ctx =
-        ctx.template device_context<paddle::platform::XPUDeviceContext>();
-    paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
-                                                   T, kps::AddFunctor<T>, 1>(
-        xpu_ctx, ins, &outs, axis, kps::AddFunctor<T>());
-  }
-};
-
-static std::vector<int> get_rdims(const std::vector<int>& xdims,
-                                  const std::vector<int>& ydims) {
-  std::vector<int> rdims;
-  for (size_t i = 0; i < xdims.size(); i++) {
-    if (xdims[i] != ydims[i]) {
-      rdims.push_back(i);
-    }
-  }
-  return rdims;
-}
-
-template <typename T>
-class ElementwiseAddGradXPUKPKernel : public ElemwiseGradKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    ElemwiseGradKernel<T>::Compute(ctx);
-    auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
-    const framework::DDim& x_dims = x->dims();
-    const framework::DDim& y_dims = y->dims();
-    const framework::DDim& dz_dims = dz->dims();
-    int axis = ctx.Attr<int>("axis");
-    axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
-    int max_dim = std::max(x_dims.size(), y_dims.size());
-    PADDLE_ENFORCE_GE(
-        axis, 0,
-        platform::errors::InvalidArgument(
-            "Axis should be great than or equal to 0, but received axis is %d.",
-            axis));
-    PADDLE_ENFORCE_LT(
-        axis, max_dim,
-        platform::errors::InvalidArgument(
-            "Axis should be less than %d, but received axis is %d.", max_dim,
-            axis));
-
-    std::vector<int> x_dims_vec(max_dim, 1);
-    std::vector<int> y_dims_vec(max_dim, 1);
-    std::vector<int> z_dims_vec(max_dim, 1);
-    if (x_dims.size() == max_dim) {
-      for (int i = 0; i < max_dim; i++) {
-        x_dims_vec[i] = x_dims[i];
-      }
-    } else {
-      for (int i = 0; i < x_dims.size(); i++) {
-        x_dims_vec[i + axis] = x_dims[i];
-      }
-    }
-
-    if (y_dims.size() == max_dim) {
-      for (int i = 0; i < max_dim; i++) {
-        y_dims_vec[i] = y_dims[i];
-      }
-    } else {
-      for (int i = 0; i < y_dims.size(); i++) {
-        y_dims_vec[i + axis] = y_dims[i];
-      }
-    }
-
-    for (int i = 0; i < max_dim; i++) {
-      z_dims_vec[i] = dz_dims[i];
-    }
-    std::vector<int> rdims_for_x;
-    std::vector<int> rdims_for_y;
-    rdims_for_x = get_rdims(x_dims_vec, z_dims_vec);
-    rdims_for_y = get_rdims(y_dims_vec, z_dims_vec);
-    const T* dz_data = dz->data<T>();
-    auto& dev_ctx =
-        ctx.template device_context<platform::XPUDeviceContext>();
-
-    if (dx != nullptr) {
-      T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
-      if (rdims_for_x.size() == 0) {
-        if (dx_data != dz_data) {
-          framework::TensorCopy(
-              *dz, ctx.GetPlace(),
-              ctx.template device_context<platform::DeviceContext>(), dx);
-        }
-      } else {
-        // For inplace strategy, dx will be stored in addr of dz, which makes
-        // the result of dy wrong.
-        if (dx->IsSharedBufferWith(*dz)) {
-          dx->clear();
-          dx->mutable_data<T>(x->dims(), ctx.GetPlace());
-        }
-
-        int ret = xpu::reduce_sum<XPUType>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
-            reinterpret_cast<XPUType*>(dx_data), z_dims_vec, rdims_for_x);
-        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
-      }
-    }
-
-    if (dy != nullptr) {
-      T* dy_data = dy->mutable_data<T>(ctx.GetPlace());
-      if (rdims_for_y.size() == 0) {
-        if (dy_data != dz_data) {
-          framework::TensorCopy(
-              *dz, ctx.GetPlace(),
-              ctx.template device_context<platform::DeviceContext>(), dy);
-        }
-      } else {
-        int ret = xpu::reduce_sum<XPUType>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUType*>(dz_data),
-            reinterpret_cast<XPUType*>(dy_data), z_dims_vec, rdims_for_y);
-        PADDLE_ENFORCE_XDNN_SUCCESS(ret, "reduce_sum ");
-      }
-    }
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
+#else
+#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
+#include "paddle/phi/kernels/gpu/elementwise.h"
+#endif
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
+#ifdef PADDLE_WITH_XPU_KP
 REGISTER_OP_KERNEL(elementwise_add, KP, plat::XPUPlace,
-                   ops::ElementwiseAddXPUKPKernel<float>);
-
-REGISTER_OP_KERNEL(elementwise_add_grad, KP, plat::XPUPlace,
-                   ops::ElementwiseAddGradXPUKPKernel<float>);
-
-#endif  // PADDLE_WITH_XPU_KP
+                   ops::ElementwiseAddKernel<plat::XPUDeviceContext, float>);
+#else
+REGISTER_OP_CUDA_KERNEL(
+    grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::bfloat16>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<float>>,
+    ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex<double>>);
+#endif
\ No newline at end of file

diff --git a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h
index aa02059345..f79ef8505d 100644
--- a/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h
+++ b/paddle/fluid/platform/device/xpu/xpu_op_kpfirst_list.h
@@ -27,7 +27,10 @@ using XPUKernelSet =
 using XPUOpMap = std::unordered_map<std::string, XPUKernelSet>;
 
 XPUOpMap& get_kp_ops() {
-  static XPUOpMap s_xpu_kp_kernels{};
+  static XPUOpMap s_xpu_kp_kernels{
+      {"elementwise_add",
+       XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
+  };
   return s_xpu_kp_kernels;
 }
-- 
GitLab
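
Note on the mechanism (illustrative, not part of the patch): the change makes one source file serve two toolchains. op_library() copies elementwise_add_op.kps into the build tree, renames it to .cu, and hands it to whichever device toolchain is enabled; the PADDLE_WITH_XPU_KP guard then decides whether that translation unit registers the KP (XPU) kernel or the CUDA kernels. The standalone C++ sketch below shows only this compile-time selection pattern; apart from the PADDLE_WITH_XPU_KP macro name, every identifier in it is invented for illustration and nothing here is Paddle code.

// sketch.cc - hypothetical single-source file whose backend is picked at
// compile time, mirroring the unified .kps pattern adopted by this patch.
#include <cstdio>

#ifdef PADDLE_WITH_XPU_KP
// KP build: CMake has copied foo.kps to foo.cu and defined this macro,
// so the XPU registration path is the one that gets compiled.
static void register_kernels() { std::printf("register KP (XPU) kernel\n"); }
#else
// Default build: the CUDA registration path is compiled instead.
static void register_kernels() { std::printf("register CUDA kernels\n"); }
#endif

int main() {
  register_kernels();
  return 0;
}

Compiling the same file with -DPADDLE_WITH_XPU_KP flips the branch, which is how elementwise_add_op.kps above chooses between REGISTER_OP_KERNEL(elementwise_add, KP, ...) and REGISTER_OP_CUDA_KERNEL(grad_add, ...) from a single source.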