From eacbd488c07cff044a9897db9f81872ad23ff40c Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Mon, 11 Jan 2021 10:22:08 +0800 Subject: [PATCH] add aarch64 and sunway kunlun lib (#30027) (#30237) * add aarch64 and sunway kunlun lib * minor * optimize elementwise_add for kunlun * update kunlun dependence * minor * minor --- cmake/external/xpu.cmake | 10 +- .../elementwise/elementwise_add_op_xpu.cc | 131 +++++++++++++++++- 2 files changed, 136 insertions(+), 5 deletions(-) diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index bbd065c0a5e..6516b861a9c 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -4,7 +4,15 @@ endif() INCLUDE(ExternalProject) SET(XPU_PROJECT "extern_xpu") -SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_15.tar.gz" CACHE STRING "" FORCE) + +if (WITH_AARCH64) + SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/aarch64/xpu_2020_1229.tar.gz" CACHE STRING "" FORCE) +elseif(WITH_SUNWAY) + SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2020_1227.tar.gz" CACHE STRING "" FORCE) +else() + SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_0105.tar.gz" CACHE STRING "" FORCE) +endif() + SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc index 625e66d5f39..8d99aa27985 100644 --- a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc @@ -26,17 +26,140 @@ namespace operators { template <typename T> class ElementwiseAddXPUKernel : public framework::OpKernel<T> { public: - void Compute(const framework::ExecutionContext &ctx) const override { - XPUElementwise<T>(ctx, xpu::add<T>); + void Compute(const
framework::ExecutionContext& ctx) const override { + // XPUElementwise<T>(ctx, xpu::add<T>); + // ToDo(QingshuChen): update this optimization to elementwise_xpu.h + auto x_var = ctx.InputVar("X"); + PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument( + "Cannot get input Variable X")); + PADDLE_ENFORCE_EQ( + x_var->IsType<framework::LoDTensor>(), true, + platform::errors::InvalidArgument( + "XPU only support LoDTensor, Input(X) is not LoDTensor")); + + auto x = x_var->Get<framework::LoDTensor>(); + auto* y = ctx.Input<framework::LoDTensor>("Y"); + auto* z = ctx.Output<framework::LoDTensor>("Out"); + z->mutable_data<T>(ctx.GetPlace()); + auto x_dims = x.dims(); + auto y_dims = y->dims(); + int max_dim = std::max(x_dims.size(), y_dims.size()); + int axis = ctx.Attr<int>("axis"); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + + PADDLE_ENFORCE_GE( + axis, 0, + platform::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT( + axis, max_dim, + platform::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", max_dim, + axis)); + std::vector<int> x_dims_vec(max_dim, 1); + std::vector<int> y_dims_vec(max_dim, 1); + if (x_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + x_dims_vec[i] = x_dims[i]; + } + } else { + for (int i = 0; i < x_dims.size(); i++) { + x_dims_vec[i + axis] = x_dims[i]; + } + } + if (y_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + y_dims_vec[i] = y_dims[i]; + } + } else { + for (int i = 0; i < y_dims.size(); i++) { + y_dims_vec[i + axis] = y_dims[i]; + } + } + const T* x_data = x.data<T>(); + const T* y_data = y->data<T>(); + T* z_data = z->data<T>(); + + auto& dev_ctx = + ctx.template device_context<paddle::platform::XPUDeviceContext>(); + int ret = xpu::SUCCESS; + ret = xpu::broadcast_add<T>(dev_ctx.x_context(), x_data, y_data, z_data, + x_dims_vec, y_dims_vec); + PADDLE_ENFORCE_EQ( + ret, xpu::SUCCESS, + platform::errors::External( + "XPU kernel Elementwise occur error in XPUElementwise error code ", + ret, XPUAPIErrorMsg[ret])); 
} }; template <typename T> class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { public: - void Compute(const framework::ExecutionContext &ctx) const override { + void Compute(const framework::ExecutionContext& ctx) const override { ElemwiseGradKernel<T>::Compute(ctx); - XPUElementwiseGrad<T>(ctx, xpu::add_grad<T>, false); + // XPUElementwiseGrad<T>(ctx, xpu::add_grad<T>, false); + auto* x = ctx.Input<framework::Tensor>("X"); + auto* y = ctx.Input<framework::Tensor>("Y"); + auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); + auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); + auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); + int axis = ctx.Attr<int>("axis"); + const framework::DDim& x_dims = x->dims(); + const framework::DDim& y_dims = y->dims(); + int max_dim = std::max(x_dims.size(), y_dims.size()); + axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis); + PADDLE_ENFORCE_GE( + axis, 0, + platform::errors::InvalidArgument( + "Axis should be great than or equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT( + axis, max_dim, + platform::errors::InvalidArgument( + "Axis should be less than %d, but received axis is %d.", max_dim, + axis)); + std::vector<int> x_dims_vec(max_dim, 1); + std::vector<int> y_dims_vec(max_dim, 1); + if (x_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + x_dims_vec[i] = x_dims[i]; + } + } else { + for (int i = 0; i < x_dims.size(); i++) { + x_dims_vec[i + axis] = x_dims[i]; + } + } + if (y_dims.size() == max_dim) { + for (int i = 0; i < max_dim; i++) { + y_dims_vec[i] = y_dims[i]; + } + } else { + for (int i = 0; i < y_dims.size(); i++) { + y_dims_vec[i + axis] = y_dims[i]; + } + } + + T* dx_data = nullptr; + T* dy_data = nullptr; + if (dx) { + dx_data = dx->mutable_data<T>(ctx.GetPlace()); + } + if (dy) { + dy_data = dy->mutable_data<T>(ctx.GetPlace()); + } + + auto& dev_ctx = + ctx.template device_context<paddle::platform::XPUDeviceContext>(); + int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dx_data, dx_data, + dx_data, dz->data<T>(), dy_data, + dx_data, x_dims_vec, y_dims_vec); + 
PADDLE_ENFORCE_EQ( + ret, xpu::SUCCESS, + platform::errors::External( + "XPU kernel Elementwise occur error in XPUElementwise error code ", + ret, XPUAPIErrorMsg[ret])); } }; -- GitLab