From 3b9b4c341bcd9d1f5cf22a5062580125770ce12c Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Fri, 2 Sep 2022 16:25:20 +0800 Subject: [PATCH] migrate shaple sgd, split,sign xpu kernels to phi, test=kunlun (#45607) --- .../fluid/operators/optimizers/sgd_op_xpu.cc | 102 ------------------ paddle/fluid/operators/shape_op_xpu.cc | 54 ---------- paddle/fluid/operators/sign_op_xpu.cc | 43 -------- paddle/fluid/operators/split_op_xpu.cc | 73 ------------- paddle/phi/kernels/shape_kernel.cc | 14 +++ paddle/phi/kernels/xpu/sgd_kernel.cc | 86 +++++++++++++++ paddle/phi/kernels/xpu/sign_kernel.cc | 33 ++++++ paddle/phi/kernels/xpu/split_kernel.cc | 67 ++++++++++++ 8 files changed, 200 insertions(+), 272 deletions(-) delete mode 100644 paddle/fluid/operators/optimizers/sgd_op_xpu.cc delete mode 100644 paddle/fluid/operators/shape_op_xpu.cc delete mode 100644 paddle/fluid/operators/sign_op_xpu.cc delete mode 100644 paddle/fluid/operators/split_op_xpu.cc create mode 100644 paddle/phi/kernels/xpu/sgd_kernel.cc create mode 100644 paddle/phi/kernels/xpu/sign_kernel.cc create mode 100644 paddle/phi/kernels/xpu/split_kernel.cc diff --git a/paddle/fluid/operators/optimizers/sgd_op_xpu.cc b/paddle/fluid/operators/optimizers/sgd_op_xpu.cc deleted file mode 100644 index 268da2e4e3e..00000000000 --- a/paddle/fluid/operators/optimizers/sgd_op_xpu.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ -#ifdef PADDLE_WITH_XPU -#include - -#include "paddle/fluid/operators/optimizers/sgd_op.h" -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -template -class SGDOpXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext &ctx) const override { - const auto *learning_rate = ctx.Input("LearningRate"); - - const auto *param_var = ctx.InputVar("Param"); - const auto *grad_var = ctx.InputVar("Grad"); - - if (param_var->IsType() && - grad_var->IsType()) { - const auto *param = ctx.Input("Param"); - auto *param_out = ctx.Output("ParamOut"); - // Actually, all tensors are LoDTensor except SelectedRows. - const auto *grad = ctx.Input("Grad"); - auto sz = param_out->numel(); - PADDLE_ENFORCE_EQ(param->numel(), - sz, - platform::errors::InvalidArgument( - "The input tensor Param's numel of SgdOp " - "should be equal with ParamOut's numel. " - "But received Param's " - "numel = [%s], ParamOut's numel = [%s]", - param->numel(), - sz)); - PADDLE_ENFORCE_EQ(grad->numel(), - sz, - platform::errors::InvalidArgument( - "The input tensor Grad's numel of SgdOp " - "should be equal with ParamOut's numel. 
" - "But received Grad's " - "numel = [%s], ParamOut's numel = [%s]", - grad->numel(), - sz)); - - const T *lr_t = learning_rate->data(); - auto &dev_ctx = ctx.template device_context(); - xpu::ctx_guard RAII_GUARD(dev_ctx.x_context()); - const float *lr = nullptr; - if (std::is_same::value) { - float *lr_float = - RAII_GUARD.alloc_l3_or_gm(learning_rate->numel()); - int r = xpu::cast_v2( - dev_ctx.x_context(), - reinterpret_cast(lr_t), - lr_float, - learning_rate->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2"); - lr = lr_float; - } else { - lr = reinterpret_cast(lr_t); - } - - const T *param_data = param->data(); - const T *grad_data = grad->data(); - T *out_data = param_out->mutable_data(ctx.GetPlace()); - - int r = xpu::sgd(dev_ctx.x_context(), - reinterpret_cast(grad_data), - reinterpret_cast(param_data), - lr, - reinterpret_cast(out_data), - sz); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "sgd"); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_XPU_KERNEL( - sgd, - ops::SGDOpXPUKernel, - ops::SGDOpXPUKernel); -#endif diff --git a/paddle/fluid/operators/shape_op_xpu.cc b/paddle/fluid/operators/shape_op_xpu.cc deleted file mode 100644 index 98ed9fb0322..00000000000 --- a/paddle/fluid/operators/shape_op_xpu.cc +++ /dev/null @@ -1,54 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * http://www.apache.org/licenses/LICENSE-2.0 - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
*/ - -#ifdef PADDLE_WITH_XPU -#include - -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -using LoDTensor = framework::LoDTensor; -using SelectedRows = phi::SelectedRows; - -template -class ShapeXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in_var = ctx.InputVar("Input"); - framework::DDim in_dims; - if (in_var->IsType()) { - in_dims = in_var->Get().value().dims(); - } else { - in_dims = in_var->Get().dims(); - } - auto* out_t = ctx.Output("Out"); - out_t->Resize({in_dims.size()}); - auto out_data = out_t->mutable_data(platform::CPUPlace()); - for (int i = 0; i < in_dims.size(); ++i) { - out_data[i] = in_dims[i]; - } - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(shape, - ops::ShapeXPUKernel, - ops::ShapeXPUKernel, - ops::ShapeXPUKernel, - ops::ShapeXPUKernel, - ops::ShapeXPUKernel); - -#endif diff --git a/paddle/fluid/operators/sign_op_xpu.cc b/paddle/fluid/operators/sign_op_xpu.cc deleted file mode 100644 index a00aa4bb7ce..00000000000 --- a/paddle/fluid/operators/sign_op_xpu.cc +++ /dev/null @@ -1,43 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef PADDLE_WITH_XPU -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" -namespace paddle { -namespace operators { - -template -class SignXPUKernel : public framework::OpKernel { - public: - virtual void Compute(const framework::ExecutionContext& context) const { - auto* out = context.Output("Out"); - auto* in = context.Input("X"); - out->mutable_data(in->place()); - auto xpu_context = context.device_context().x_context(); - // int sign(Context* ctx, const T* x , T* y, int len); - int r = xpu::sign(xpu_context, in->data(), out->data(), in->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "sign"); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - sign, ops::SignXPUKernel); - -#endif diff --git a/paddle/fluid/operators/split_op_xpu.cc b/paddle/fluid/operators/split_op_xpu.cc deleted file mode 100644 index d051978c0cc..00000000000 --- a/paddle/fluid/operators/split_op_xpu.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef PADDLE_WITH_XPU -#include -#include - -#include "paddle/fluid/operators/split_op.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class SplitXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* input = ctx.Input("X"); - auto output = ctx.MultiOutput("Out"); - int num = ctx.Attr("num"); - std::vector sections = ctx.Attr>("sections"); - int axis = ctx.Attr("axis"); - auto& dev_ctx = ctx.template device_context(); - auto in_dims = input->dims(); - - auto input_shape = phi::vectorize(in_dims); - std::vector split_lists; - std::vector out_ptrs; - auto outs_number = output.size(); - std::vector outs_dims = - UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number); - for (size_t i = 0; i < output.size(); ++i) { - output[i]->Resize(outs_dims[i]); - out_ptrs.push_back(output[i]->mutable_data(ctx.GetPlace())); - split_lists.push_back(output[i]->dims()[axis]); - } - - int r = xpu::split(dev_ctx.x_context(), - input->data(), - out_ptrs, - input_shape, - split_lists, - axis); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External("XPU split kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - split, - ops::SplitXPUKernel, - ops::SplitXPUKernel); -#endif diff --git a/paddle/phi/kernels/shape_kernel.cc b/paddle/phi/kernels/shape_kernel.cc index d4dbdbaf178..2c2b41e3c66 100644 --- a/paddle/phi/kernels/shape_kernel.cc +++ b/paddle/phi/kernels/shape_kernel.cc @@ -67,3 +67,17 @@ PD_REGISTER_KERNEL(shape, kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); } #endif + +#if defined(PADDLE_WITH_XPU) +PD_REGISTER_KERNEL(shape, + XPU, + ALL_LAYOUT, + phi::ShapeKernel, + bool, + int, + int64_t, + float, + double) { + 
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
+}
+#endif
diff --git a/paddle/phi/kernels/xpu/sgd_kernel.cc b/paddle/phi/kernels/xpu/sgd_kernel.cc
new file mode 100644
index 00000000000..1bfd790893a
--- /dev/null
+++ b/paddle/phi/kernels/xpu/sgd_kernel.cc
@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/sgd_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+// Feeds grad, param and the learning rate to xpu::sgd, writing into param_out.
+template <typename T, typename Context>
+void SGDDenseKernel(const Context &dev_ctx,
+                    const DenseTensor &param,
+                    const DenseTensor &learning_rate,
+                    const DenseTensor &grad,
+                    const paddle::optional<DenseTensor> &master_param,
+                    bool multi_precision,
+                    DenseTensor *param_out,
+                    DenseTensor *master_param_out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  auto sz = param_out->numel();
+  PADDLE_ENFORCE_EQ(
+      param.numel(),
+      sz,
+      errors::InvalidArgument("The input tensor Param's numel of SgdOp "
+                              "should be equal with ParamOut's numel. "
+                              "But received Param's "
+                              "numel = [%s], ParamOut's numel = [%s]",
+                              param.numel(),
+                              sz));
+  PADDLE_ENFORCE_EQ(
+      grad.numel(),
+      sz,
+      errors::InvalidArgument("The input tensor Grad's numel of SgdOp "
+                              "should be equal with ParamOut's numel. "
+                              "But received Grad's "
+                              "numel = [%s], ParamOut's numel = [%s]",
+                              grad.numel(),
+                              sz));
+  // lr is handed to xpu::sgd as float: a float16 rate is cast into a buffer.
+  const T *lr_t = learning_rate.data<T>();
+  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+  const float *lr = nullptr;
+  if (std::is_same<T, dtype::float16>::value) {
+    float *lr_float = RAII_GUARD.alloc_l3_or_gm<float>(learning_rate.numel());
+    int r =
+        xpu::cast_v2<XPUType, float>(dev_ctx.x_context(),
+                                     reinterpret_cast<const XPUType *>(lr_t),
+                                     lr_float,
+                                     learning_rate.numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");  // label names the failing call
+    lr = lr_float;
+  } else {
+    lr = reinterpret_cast<const float *>(lr_t);
+  }
+
+  const T *param_data = param.data<T>();
+  const T *grad_data = grad.data<T>();
+
+  dev_ctx.template Alloc<T>(param_out);
+  T *out_data = param_out->data<T>();
+
+  int r = xpu::sgd(dev_ctx.x_context(),
+                   reinterpret_cast<const XPUType *>(grad_data),
+                   reinterpret_cast<const XPUType *>(param_data),
+                   lr,
+                   reinterpret_cast<XPUType *>(out_data),
+                   sz);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "sgd");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    sgd, XPU, ALL_LAYOUT, phi::SGDDenseKernel, phi::dtype::float16, float) {}
diff --git a/paddle/phi/kernels/xpu/sign_kernel.cc b/paddle/phi/kernels/xpu/sign_kernel.cc
new file mode 100644
index 00000000000..28223948dcc
--- /dev/null
+++ b/paddle/phi/kernels/xpu/sign_kernel.cc
@@ -0,0 +1,33 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
*/
+
+#include "paddle/phi/kernels/sign_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+// Allocates out, then calls xpu::sign over all x.numel() elements of x.
+template <typename T, typename Context>
+void SignKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  auto xpu_context = dev_ctx.x_context();
+  int r = xpu::sign(xpu_context, x.data<T>(), out->data<T>(), x.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "sign");
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(sign, XPU, ALL_LAYOUT, phi::SignKernel, float) {}
diff --git a/paddle/phi/kernels/xpu/split_kernel.cc b/paddle/phi/kernels/xpu/split_kernel.cc
new file mode 100644
index 00000000000..352d6f857c0
--- /dev/null
+++ b/paddle/phi/kernels/xpu/split_kernel.cc
@@ -0,0 +1,67 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/split_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+// Splits x along axis via xpu::split; section sizes are read from outs' dims.
+template <typename T, typename Context>
+void SplitKernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 const IntArray& sections,
+                 const Scalar& axis_scalar,
+                 std::vector<DenseTensor*> outs) {
+  int axis = axis_scalar.to<int>();
+  if (axis < 0) axis += x.dims().size();  // negative axis counts from the end
+  auto input_shape = vectorize<int>(x.dims());
+  std::vector<T*> out_ptrs;
+  std::vector<int> split_lists;
+  for (size_t j = 0; j < outs.size(); ++j) {
+    dev_ctx.template Alloc<T>(outs[j]);
+    out_ptrs.push_back(outs[j]->data<T>());
+    split_lists.push_back(outs[j]->dims()[axis]);
+  }
+  int r = xpu::split<T>(dev_ctx.x_context(),
+                        x.data<T>(),
+                        out_ptrs,
+                        input_shape,
+                        split_lists,
+                        axis);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "split");
+}
+// Equal split: builds num sections of (axis extent / num), then SplitKernel.
+template <typename T, typename Context>
+void SplitWithNumKernel(const Context& dev_ctx,
+                        const DenseTensor& x,
+                        int num,
+                        const Scalar& axis_scalar,
+                        std::vector<DenseTensor*> outs) {
+  int axis_value = axis_scalar.to<int>();
+  if (axis_value < 0) axis_value += x.dims().size();
+  auto input_axis_dim = x.dims().at(axis_value);
+  // Only divide when num > 0; a non-positive num yields no sections.
+  std::vector<int64_t> sections_vec;
+  if (num > 0) sections_vec.assign(num, input_axis_dim / num);
+  IntArray sections(sections_vec);
+  SplitKernel<T, Context>(dev_ctx, x, sections, axis_scalar, outs);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(split, XPU, ALL_LAYOUT, phi::SplitKernel, float, int) {}
+PD_REGISTER_KERNEL(
+    split_with_num, XPU, ALL_LAYOUT, phi::SplitWithNumKernel, float, int) {}
-- 
GitLab