From 6c5f9aa813031e395f628fc550f0134abdf037a3 Mon Sep 17 00:00:00 2001 From: ykkk2333 <77383312+ykkk2333@users.noreply.github.com> Date: Thu, 1 Sep 2022 10:47:04 +0800 Subject: [PATCH] migrate xpu activation/activation_grad/transpose/transpose_grad/tril_triu/tril_triu_grad kernel to PHI, test=kunlun (#45554) --- paddle/fluid/operators/activation_op_xpu.cc | 679 ------------------ .../fluid/operators/instance_norm_op_xpu.cc | 99 --- paddle/fluid/operators/transpose_op_xpu.cc | 128 ---- paddle/fluid/operators/tril_triu_op_xpu.cc | 86 --- paddle/phi/kernels/xpu/abs_grad_kernel.cc | 38 + paddle/phi/kernels/xpu/abs_kernel.cc | 30 + .../phi/kernels/xpu/activation_grad_kernel.cc | 576 +++++++++++++++ paddle/phi/kernels/xpu/activation_kernel.cc | 431 +++++++++++ .../phi/kernels/xpu/transpose_grad_kernel.cc | 53 ++ paddle/phi/kernels/xpu/transpose_kernel.cc | 52 ++ .../phi/kernels/xpu/tril_triu_grad_kernel.cc | 52 ++ paddle/phi/kernels/xpu/tril_triu_kernel.cc | 52 ++ 12 files changed, 1284 insertions(+), 992 deletions(-) delete mode 100644 paddle/fluid/operators/activation_op_xpu.cc delete mode 100644 paddle/fluid/operators/instance_norm_op_xpu.cc delete mode 100644 paddle/fluid/operators/transpose_op_xpu.cc delete mode 100644 paddle/fluid/operators/tril_triu_op_xpu.cc create mode 100644 paddle/phi/kernels/xpu/abs_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/abs_kernel.cc create mode 100644 paddle/phi/kernels/xpu/activation_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/activation_kernel.cc create mode 100644 paddle/phi/kernels/xpu/transpose_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/transpose_kernel.cc create mode 100644 paddle/phi/kernels/xpu/tril_triu_grad_kernel.cc create mode 100644 paddle/phi/kernels/xpu/tril_triu_kernel.cc diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc deleted file mode 100644 index 6dfe2945caf..00000000000 --- a/paddle/fluid/operators/activation_op_xpu.cc +++ /dev/null @@ -1,679 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef PADDLE_WITH_XPU - -#include - -#include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" - -namespace paddle { -namespace operators { - -using paddle::framework::Tensor; - -template -class XPUActivationKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - Functor functor; - - auto attrs = functor.GetAttrs(); - for (auto &attr : attrs) { - *attr.second = context.Attr(attr.first); - } - functor(context); - } -}; - -template -class XPUActivationGradKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &context) const override { - Functor functor; - - auto attrs = functor.GetAttrs(); - for (auto &attr : attrs) { - *attr.second = context.Attr(attr.first); - } - functor(context); - } -}; - -template -void xpu_activation_forward( - const framework::ExecutionContext &ctx, - std::function func) { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - const XPUT *x_data = reinterpret_cast(x->data()); - XPUT *y_data = reinterpret_cast(y->mutable_data(ctx.GetPlace())); - - auto xpu_context = ctx.device_context().x_context(); - int r = func(xpu_context, x_data, y_data, x->numel()); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External("XPU activation op return wrong value[%d %s].", - r, - XPUAPIErrorMsg[r])); -} - -template -void xpu_activation_backward( - const framework::ExecutionContext &ctx, - std::function - func) { - /* TODO: relu tanh sigmoid are inplace */ - const auto *x = ctx.Input("X"); - auto *y = ctx.Input("Out"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - const XPUT *x_data = nullptr; - const XPUT *y_data = nullptr; - const XPUT *y_grad = nullptr; - if (x != nullptr) x_data = reinterpret_cast(x->data()); - if (y != nullptr) y_data = reinterpret_cast(y->data()); - if (dOut != nullptr) y_grad = reinterpret_cast(dOut->data()); - XPUT *x_grad = reinterpret_cast(dX->mutable_data(ctx.GetPlace())); - auto xpu_context = ctx.device_context().x_context(); - - int r = func(xpu_context, x_data, y_data, y_grad, x_grad, dX->numel()); - PADDLE_ENFORCE_EQ(r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU activation grad op return wrong value[%d %s].", - r, - XPUAPIErrorMsg[r])); -} - -template -struct XPUAbsFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::abs); - } -}; - -template -struct XPUAbsGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::abs_grad); - } -}; - -template -struct XPUExpFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::exp); - } -}; - -template -struct XPULogFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::log); - } -}; - -template -struct XPUReciprocalFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const 
framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::reciprocal); - } -}; - -template -struct XPUReciprocalGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::reciprocal_grad); - } -}; - -template -struct XPUReluGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::relu_grad); - } -}; - -template -struct XPURelu6Functor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::relu6); - } -}; - -template -struct XPURelu6GradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::relu6_grad); - } -}; - -template -struct XPUSigmoidFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::sigmoid); - } -}; - -template -struct XPUSigmoidGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::sigmoid_grad); - } -}; - -template -struct XPUSqrtFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::sqrt); - } -}; - -template -struct XPUSqrtGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::sqrt_grad); - } -}; - -template -struct XPUSquareFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::square); - } -}; - -template -struct XPUSquareGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::square_grad); - } -}; - -template -struct XPUTanhFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_forward( - ctx, xpu::tanh); - } -}; - -template -struct XPUTanhGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - xpu_activation_backward( - ctx, xpu::tanh_grad); - } -}; - -template -struct XPUHardSwishFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - float threshold = ctx.Attr("threshold"); - float scale = ctx.Attr("scale"); - float offset = ctx.Attr("offset"); - PADDLE_ENFORCE_EQ(threshold, - 6.0f, - platform::errors::External( - "Not support threshold [%f] in XPU", threshold)); - PADDLE_ENFORCE_EQ( - scale, - 6.0f, - platform::errors::External("Not support scale [%f] in XPU", scale)); - 
PADDLE_ENFORCE_EQ( - offset, - 3.0f, - platform::errors::External("Not support offset [%f] in XPU", offset)); - xpu_activation_forward( - ctx, xpu::hard_swish); - } -}; - -template -struct XPUHardSwishGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - float threshold = ctx.Attr("threshold"); - float scale = ctx.Attr("scale"); - float offset = ctx.Attr("offset"); - PADDLE_ENFORCE_EQ(threshold, - 6.0f, - platform::errors::External( - "Not support threshold [%f] in XPU", threshold)); - PADDLE_ENFORCE_EQ( - scale, - 6.0f, - platform::errors::External("Not support scale [%f] in XPU", scale)); - PADDLE_ENFORCE_EQ( - offset, - 3.0f, - platform::errors::External("Not support offset [%f] in XPU", offset)); - xpu_activation_backward( - ctx, xpu::hard_swish_grad); - } -}; - -template -struct XPULeakyReluFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - float alpha = ctx.Attr("alpha"); - const T *x_data = x->data(); - T *y_data = y->mutable_data(ctx.GetPlace()); - - auto xpu_context = - ctx.device_context().x_context(); - int r = xpu::leaky_relu(xpu_context, x_data, y_data, x->numel(), alpha); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU leaky_relu return wrong value[%d %s].", r, XPUAPIErrorMsg[r])); - } -}; - -template -struct XPULeakyReluGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - float alpha = ctx.Attr("alpha"); - const T *x_data = nullptr; - const T *y_grad = nullptr; - if (x != nullptr) x_data = x->data(); - if (dOut != nullptr) y_grad = dOut->data(); - T *x_grad = dX->mutable_data(ctx.GetPlace()); - auto xpu_context = - ctx.device_context().x_context(); - - // The signs of x and y are the same, - // y == nullptr here, - // so we give 2 x to the api - int r = xpu::leaky_relu_grad(xpu_context, - reinterpret_cast(x_data), - reinterpret_cast(x_data), - reinterpret_cast(y_grad), - reinterpret_cast(x_grad), - dX->numel(), - alpha); - PADDLE_ENFORCE_EQ(r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "XPU leaky_relu_grad return wrong value[%d %s].", - r, - XPUAPIErrorMsg[r])); - } -}; - -template -struct XPULogGradFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - const T *x_data = nullptr; - const T *y_grad = nullptr; - if (x != nullptr) x_data = x->data(); - if (dOut != nullptr) y_grad = dOut->data(); - T *x_grad = dX->mutable_data(ctx.GetPlace()); - auto dev_ctx = - ctx.device_context().x_context(); - const auto x_dims = x->dims(); - auto xshape = phi::vectorize(x_dims); - int len = x->dims()[x_dims.size() - 1]; - std::vector yshape(1, len); - - xpu::ctx_guard RAII_GUARD(dev_ctx); - T *y_data = RAII_GUARD.alloc_l3_or_gm(len); - PADDLE_ENFORCE_XDNN_NOT_NULL(y_data); - T *tmp_grad = RAII_GUARD.alloc_l3_or_gm(x->numel()); - PADDLE_ENFORCE_XDNN_NOT_NULL(tmp_grad); - int r = xpu::constant(dev_ctx, y_data, len, static_cast(1.0)); - PADDLE_ENFORCE_XDNN_SUCCESS(r, 
"constant"); - - // dx.device(d) = dout * (static_cast(1) / x); - r = xpu::broadcast_div(dev_ctx, - reinterpret_cast(y_data), - reinterpret_cast(x_data), - reinterpret_cast(tmp_grad), - yshape, - xshape); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_div"); - - r = xpu::broadcast_mul(dev_ctx, - reinterpret_cast(y_grad), - reinterpret_cast(tmp_grad), - reinterpret_cast(x_grad), - xshape, - xshape); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul"); - } -}; - -template -struct XPUMishFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - const T *x_data = x->data(); - T *y_data = y->mutable_data(ctx.GetPlace()); - - float threshold = ctx.Attr("threshold"); - - auto xpu_context = - ctx.device_context().x_context(); - int r = xpu::mish(xpu_context, x_data, y_data, x->numel(), threshold); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish"); - } -}; - -template -struct XPUMishGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - const T *x_data = x->data(); - const T *y_grad = dOut->data(); - T *x_grad = dX->mutable_data(ctx.GetPlace()); - - float threshold = ctx.Attr("threshold"); - - auto xpu_context = - ctx.device_context().x_context(); - int r = xpu::mish_grad(xpu_context, - reinterpret_cast(x_data), - reinterpret_cast( - x_data), // mish_grad do not need y_data - reinterpret_cast(y_grad), - reinterpret_cast(x_grad), - dX->numel(), - threshold); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish_grad"); - } -}; - -template -struct XPUPowFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - auto pow_factor = ctx.Attr("factor"); - const T *x_data = x->data(); - T *y_data = y->mutable_data(ctx.GetPlace()); - - // allocate temp memory for factor on xpu - auto xpu_context = - ctx.device_context().x_context(); - xpu::ctx_guard RAII_GUARD(xpu_context); - T *factor_data = RAII_GUARD.alloc_l3_or_gm(1); - PADDLE_ENFORCE_NOT_NULL( - factor_data, - platform::errors::External("XPU alloc_l3_or_gm returns nullptr")); - memory::Copy(ctx.GetPlace(), - static_cast(factor_data), - platform::CPUPlace(), - static_cast(&pow_factor), - sizeof(T)); - - // broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const - // std::vector& xshape, const std::vector& yshape); - auto x_dims = phi::vectorize(x->dims()); - int r = xpu::broadcast_pow( - xpu_context, x_data, factor_data, y_data, x_dims, {1}); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow"); - } -}; - -template -struct XPUPowGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - - const T *x_data = x->data(); - const T *y_grad = dOut->data(); - T *x_grad = dX->mutable_data(ctx.GetPlace()); - - // check dims: all dims should equal - auto x_dims = phi::vectorize(x->dims()); - auto dy_dims = phi::vectorize(dOut->dims()); - auto dx_dims = phi::vectorize(dX->dims()); - PADDLE_ENFORCE_EQ( - x_dims, - dy_dims, - platform::errors::PreconditionNotMet("x_dims should match dy_dims.")); - PADDLE_ENFORCE_EQ( - x_dims, - dx_dims, - 
platform::errors::PreconditionNotMet("x_dims should match dx_dims.")); - float pow_factor = ctx.Attr("factor"); - - auto xpu_context = - ctx.device_context().x_context(); - // int pow_grad(Context* ctx, const T* x, const T* dy, T* dx, int len, float - // factor); - int r = xpu::pow_grad( - xpu_context, x_data, y_grad, x_grad, x->numel(), pow_factor); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow_grad"); - } -}; - -template -struct XPUReluFunctor : public BaseActivationFunctor { - using XPUType = typename XPUTypeTrait::Type; - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - const XPUType *x_data = reinterpret_cast(x->data()); - XPUType *y_data = - reinterpret_cast(y->mutable_data(ctx.GetPlace())); - - auto xpu_context = - ctx.device_context().x_context(); - int r = - xpu::relu(xpu_context, x_data, y_data, x->numel(), nullptr, nullptr); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu"); - } -}; - -template -struct XPUSoftPlusFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - const T *x_data = x->data(); - T *y_data = y->mutable_data(ctx.GetPlace()); - - float beta = ctx.Attr("beta"); - float threshold = ctx.Attr("threshold"); - - auto xpu_context = - ctx.device_context().x_context(); - int r = - xpu::softplus(xpu_context, x_data, y_data, x->numel(), beta, threshold); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus"); - } -}; - -template -struct XPUSoftPlusGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - const T *x_data = x->data(); - const T *y_grad = dOut->data(); - T *x_grad = dX->mutable_data(ctx.GetPlace()); - - float beta = ctx.Attr("beta"); - float threshold = ctx.Attr("threshold"); - - auto xpu_context = - ctx.device_context().x_context(); - int r = xpu::softplus_grad(xpu_context, - reinterpret_cast(x_data), - reinterpret_cast( - x_data), // softplus_grad do not need y_data - reinterpret_cast(y_grad), - reinterpret_cast(x_grad), - dX->numel(), - beta, - threshold); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus_grad"); - } -}; - -template -struct XPUSwishFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *y = ctx.Output("Out"); - const T *x_data = x->data(); - T *y_data = y->mutable_data(ctx.GetPlace()); - - auto xpu_context = - ctx.device_context().x_context(); - // int swish(Context* ctx, const T* x, T* y, int len); - int r = xpu::swish(xpu_context, x_data, y_data, x->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish"); - } -}; - -template -struct XPUSwishGradFunctor : public BaseActivationFunctor { - void operator()(const framework::ExecutionContext &ctx) const { - const auto *x = ctx.Input("X"); - auto *dOut = ctx.Input(framework::GradVarName("Out")); - auto *dX = ctx.Output(framework::GradVarName("X")); - const T *x_data = x->data(); - const T *y_grad = dOut->data(); - T *x_grad = dX->mutable_data(ctx.GetPlace()); - - auto xpu_context = - ctx.device_context().x_context(); - // int swish_grad(Context* ctx, const T* x, const T* dy, T* dx, int len); - int r = xpu::swish_grad(xpu_context, x_data, y_grad, x_grad, dX->numel()); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad"); - } -}; - -} // namespace 
operators -} // namespace paddle - -namespace ops = paddle::operators; - -#define REGISTER_ACTIVATION_XPU_KERNEL(act_type, functor, grad_functor) \ - REGISTER_OP_XPU_KERNEL(act_type, \ - ops::XPUActivationKernel>); \ - REGISTER_OP_XPU_KERNEL( \ - act_type##_grad, \ - ops::XPUActivationGradKernel>); - -REGISTER_ACTIVATION_XPU_KERNEL(abs, XPUAbsFunctor, XPUAbsGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(hard_swish, - XPUHardSwishFunctor, - XPUHardSwishGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, - XPULeakyReluFunctor, - XPULeakyReluGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(mish, XPUMishFunctor, XPUMishGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, - XPUReciprocalFunctor, - XPUReciprocalGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, - XPUSigmoidFunctor, - XPUSigmoidGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSquareGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(softplus, - XPUSoftPlusFunctor, - XPUSoftPlusGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor) -REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor) - -REGISTER_OP_XPU_KERNEL( - relu, - ops::XPUActivationKernel>, - ops::XPUActivationKernel>); -REGISTER_OP_XPU_KERNEL( - relu_grad, - ops::XPUActivationGradKernel>, - ops::XPUActivationGradKernel< - ops::XPUReluGradFunctor>); -REGISTER_OP_XPU_KERNEL(relu6, - ops::XPUActivationKernel>); -REGISTER_OP_XPU_KERNEL( - relu6_grad, ops::XPUActivationKernel>); -REGISTER_OP_XPU_KERNEL( - tanh, - ops::XPUActivationKernel>, - ops::XPUActivationKernel>); -REGISTER_OP_XPU_KERNEL( - tanh_grad, - ops::XPUActivationGradKernel>, - ops::XPUActivationGradKernel< - ops::XPUTanhGradFunctor>); - -REGISTER_OP_XPU_KERNEL(exp, - ops::XPUActivationKernel>); -REGISTER_OP_XPU_KERNEL(log, - ops::XPUActivationKernel>); -REGISTER_OP_XPU_KERNEL( - log_grad, ops::XPUActivationGradKernel>); -#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/instance_norm_op_xpu.cc b/paddle/fluid/operators/instance_norm_op_xpu.cc deleted file mode 100644 index 429c5c47d68..00000000000 --- a/paddle/fluid/operators/instance_norm_op_xpu.cc +++ /dev/null @@ -1,99 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/device_wrapper.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" -#include "paddle/phi/kernels/instance_norm_grad_kernel.h" -#include "paddle/phi/kernels/instance_norm_kernel.h" - -namespace paddle { -namespace operators { -using Tensor = framework::Tensor; - -template -class InstanceNormXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto epsilon = ctx.Attr("epsilon"); - const auto* x = ctx.Input("X"); - const auto* scale = ctx.Input("Scale"); - const auto* bias = ctx.Input("Bias"); - auto* y = ctx.Output("Y"); - auto* mean = ctx.Output("SavedMean"); - auto* variance = ctx.Output("SavedVariance"); - auto& dev_ctx = ctx.template device_context(); - - // call phi kernel - phi::InstanceNormKernel( - static_cast::TYPE&>(dev_ctx), - *x, - *scale, - *bias, - epsilon, - y, - mean, - variance); - } -}; -template -class InstanceNormGradXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& ctx) const override { - const auto epsilon = ctx.Attr("epsilon"); - const auto* x = ctx.Input("X"); - const auto* mean = ctx.Input("SavedMean"); - const auto* variance = ctx.Input("SavedVariance"); - const auto* scale = ctx.Input("Scale"); - const auto* dy = ctx.Input(framework::GradVarName("Y")); - auto* dx = ctx.Output(framework::GradVarName("X")); - auto* dscale = ctx.Output(framework::GradVarName("Scale")); - auto* dbias = ctx.Output(framework::GradVarName("Bias")); - auto& dev_ctx = ctx.template device_context(); - - // call phi kernel - phi::InstanceNormGradKernel( - static_cast::TYPE&>(dev_ctx), - *x, - *dy, - *scale, - *mean, - *variance, - epsilon, - dx, - dbias, - dscale); - } -}; -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_XPU_KERNEL( - instance_norm, - ops::InstanceNormXPUKernel); -REGISTER_OP_XPU_KERNEL( - instance_norm_grad, - ops::InstanceNormGradXPUKernel); - -#endif // PADDLE_WITH_XPU} diff --git a/paddle/fluid/operators/transpose_op_xpu.cc b/paddle/fluid/operators/transpose_op_xpu.cc deleted file mode 100644 index 45f5a1ed005..00000000000 --- a/paddle/fluid/operators/transpose_op_xpu.cc +++ /dev/null @@ -1,128 +0,0 @@ -/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#ifdef PADDLE_WITH_XPU -#include -#include -#include - -#include "paddle/fluid/operators/transpose_op.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; - -template -class TransposeXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& context) const override { - auto x = context.Input("X"); - auto out = context.Output("Out"); - - // axis is permute - auto axis = context.Attr>("axis"); - int ndims = axis.size(); - const auto x_dims = x->dims(); - const T* x_data = x->data(); - T* y_data = out->mutable_data(context.GetPlace()); - if (out->numel() == 0) { - return; - } - - std::vector x_shape_host(ndims, 0); - for (int i = 0; i < ndims; ++i) { - x_shape_host[i] = x_dims[i]; - } - auto& dev_ctx = context.template device_context(); - int r = xpu::transpose(dev_ctx.x_context(), - reinterpret_cast(x_data), - reinterpret_cast(y_data), - x_shape_host, - axis); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External("XPU kernel error! error code=%d", r)); - } -}; - -template -class TransposeGradXPUKernel : public framework::OpKernel { - using XPUType = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* out_grad = - context.Input(framework::GradVarName("Out")); - auto* x_grad = - context.Output(framework::GradVarName("X")); - if (!x_grad) return; - - x_grad->mutable_data(context.GetPlace()); - std::vector axis = context.Attr>("axis"); - std::vector reversed_axis(axis); - for (size_t i = 0; i < axis.size(); i++) { - reversed_axis[axis[i]] = i; - } - - int ndims = axis.size(); - std::vector out_shape_host(ndims, 0); - for (int i = 0; i < ndims; ++i) { - out_shape_host[i] = out_grad->dims()[i]; - } - auto& dev_ctx = context.template device_context(); - int r = xpu::transpose( - dev_ctx.x_context(), - reinterpret_cast(out_grad->data()), - reinterpret_cast(x_grad->data()), - out_shape_host, - reversed_axis); - PADDLE_ENFORCE_EQ( - r, - xpu::Error_t::SUCCESS, - platform::errors::External("XPU kernel error! error code=%d", r)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; - -REGISTER_OP_XPU_KERNEL( - transpose, - ops::TransposeXPUKernel, - ops::TransposeXPUKernel); -REGISTER_OP_XPU_KERNEL( - transpose_grad, - ops::TransposeGradXPUKernel, - ops::TransposeGradXPUKernel); -REGISTER_OP_XPU_KERNEL( - transpose2, - ops::TransposeXPUKernel, - ops::TransposeXPUKernel); -REGISTER_OP_XPU_KERNEL( - transpose2_grad, - ops::TransposeGradXPUKernel, - ops::TransposeGradXPUKernel); - -#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/tril_triu_op_xpu.cc b/paddle/fluid/operators/tril_triu_op_xpu.cc deleted file mode 100644 index 1cca6034082..00000000000 --- a/paddle/fluid/operators/tril_triu_op_xpu.cc +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under -the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/platform/device/device_wrapper.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class TrilTriuXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* x = context.Input("X"); - const auto* x_data = x->data(); - auto* out = context.Output("Out"); - auto* out_data = out->mutable_data(context.GetPlace()); - - const int diagonal = context.Attr("diagonal"); - const bool lower = context.Attr("lower"); - auto xshape = phi::vectorize(x->dims()); - auto& dev_ctx = context.template device_context(); - int r = 0; - if (lower) { - r = xpu::tril(dev_ctx.x_context(), x_data, out_data, xshape, diagonal); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "tril_op"); - } else { - r = xpu::triu(dev_ctx.x_context(), x_data, out_data, xshape, diagonal); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "triu_op"); - } - } -}; - -template -class TrilTriuGradXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const auto* d_out = - context.Input(framework::GradVarName("Out")); - const auto* dout_data = d_out->data(); - auto* d_x = context.Output(framework::GradVarName("X")); - auto* dx_data = d_x->mutable_data(context.GetPlace()); - - const int diagonal = context.Attr("diagonal"); - const bool lower = context.Attr("lower"); - - auto dy_shape = phi::vectorize(d_out->dims()); - auto& dev_ctx = context.template device_context(); - int r = 0; - if (lower) { - r = xpu::tril( - dev_ctx.x_context(), dout_data, dx_data, dy_shape, diagonal); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "tril_op"); - } else { - r = xpu::triu( - dev_ctx.x_context(), dout_data, dx_data, dy_shape, diagonal); - PADDLE_ENFORCE_XDNN_SUCCESS(r, "triu_op"); - } - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - tril_triu, - ops::TrilTriuXPUKernel, - ops::TrilTriuXPUKernel); -REGISTER_OP_XPU_KERNEL( - tril_triu_grad, - ops::TrilTriuGradXPUKernel, - ops::TrilTriuGradXPUKernel); -#endif diff --git a/paddle/phi/kernels/xpu/abs_grad_kernel.cc b/paddle/phi/kernels/xpu/abs_grad_kernel.cc new file mode 100644 index 00000000000..e49beee6847 --- /dev/null +++ b/paddle/phi/kernels/xpu/abs_grad_kernel.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/abs_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void AbsGradKernel(const Context& ctx, + const DenseTensor& x, + const DenseTensor& dout, + DenseTensor* dx) { + ctx.template Alloc(dx); + int r = xpu::abs_grad(ctx.x_context(), + x.data(), + dout.data(), + dout.data(), + dx->data(), + x.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "abs_grad"); +} +} // namespace phi + +PD_REGISTER_KERNEL(abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float) {} diff --git a/paddle/phi/kernels/xpu/abs_kernel.cc b/paddle/phi/kernels/xpu/abs_kernel.cc new file mode 100644 index 00000000000..4213c92a1eb --- /dev/null +++ b/paddle/phi/kernels/xpu/abs_kernel.cc @@ -0,0 +1,30 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/abs_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) { + ctx.template Alloc(out); + int r = xpu::abs(ctx.x_context(), x.data(), out->data(), x.numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "abs"); +} +} // namespace phi + +PD_REGISTER_KERNEL(abs, XPU, ALL_LAYOUT, phi::AbsKernel, float) {} diff --git a/paddle/phi/kernels/xpu/activation_grad_kernel.cc b/paddle/phi/kernels/xpu/activation_grad_kernel.cc new file mode 100644 index 00000000000..875a91d2a73 --- /dev/null +++ b/paddle/phi/kernels/xpu/activation_grad_kernel.cc @@ -0,0 +1,576 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/phi/kernels/activation_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" + +namespace phi { + +template +void ActivationGradXPUImpl(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* d_out, + DenseTensor* d_x, + const Functor& functor) { + PADDLE_ENFORCE_NOT_NULL( + d_out, errors::NotFound("The input DenseTensor dOut can not be nullptr")); + PADDLE_ENFORCE_NOT_NULL( + d_x, errors::NotFound("The output DenseTensor dX can not be nullptr")); + if (!out) { + out = d_out; // fake out + } + dev_ctx.template Alloc(d_x); + functor(dev_ctx, x, out, d_out, d_x); +} + +#define DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(name, functor_class) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradXPUImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + +#define DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + float attr, \ + DenseTensor* dx) { \ + functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradXPUImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + +#define DEFINE_XPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX( \ + name, functor_class, attr1, attr2) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& dout, \ + float attr1, \ + float attr2, \ + DenseTensor* dx) { \ + functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr1; \ + *(attrs[1].second) = attr2; \ + ActivationGradXPUImpl>( \ + dev_ctx, &x, nullptr, &dout, dx, functor); \ + } + +#define DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(name, functor_class) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradXPUImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + +#define DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT( \ + name, functor_class, attr) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + float attr, \ + DenseTensor* dx) { \ + functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationGradXPUImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + +#define DEFINE_XPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPOUT( \ + name, functor_class, attr1, attr2) \ + template \ + void name##GradKernel(const Context& dev_ctx, \ + const DenseTensor& out, \ + const DenseTensor& dout, \ + float attr1, \ + float attr2, \ + DenseTensor* dx) { \ + functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr1; \ + *(attrs[1].second) = attr2; \ + ActivationGradXPUImpl>( \ + dev_ctx, nullptr, &out, &dout, dx, functor); \ + } + +#define DEFINE_XPU_ACTIVATION_GRAD_KERNEL_NODEP(name, functor_class) \ + template \ + void name##GradKernel( \ + const Context& dev_ctx, const DenseTensor& dout, DenseTensor* dx) { \ + functor_class functor; \ + ActivationGradXPUImpl>( \ + dev_ctx, nullptr, nullptr, &dout, dx, functor); \ + } + +template +int 
+int xpu_activation_backward(const Context& dev_ctx,
+                            const DenseTensor* x,
+                            const DenseTensor* out,
+                            const DenseTensor* dout,
+                            DenseTensor* dx,
+                            std::function<int(xpu::Context*,
+                                              const XPUType*,
+                                              const XPUType*,
+                                              const XPUType*,
+                                              XPUType*,
+                                              int)> func) {
+  /* TODO: relu tanh sigmoid are inplace */
+  const XPUType* x_data = nullptr;
+  const XPUType* y_data = nullptr;
+  const XPUType* y_grad = nullptr;
+  if (x != nullptr) x_data = reinterpret_cast<const XPUType*>(x->data<T>());
+  if (out != nullptr) y_data = reinterpret_cast<const XPUType*>(out->data<T>());
+  if (dout != nullptr)
+    y_grad = reinterpret_cast<const XPUType*>(dout->data<T>());
+  XPUType* x_grad = reinterpret_cast<XPUType*>(dx->data<T>());
+
+  int r =
+      func(dev_ctx.x_context(), x_data, y_data, y_grad, x_grad, dx->numel());
+  return r;
+}
+
+template <typename T>
+struct XPULogGradFunctor : public funcs::BaseActivationFunctor<T> {
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dOut,
+                  DenseTensor* dX) const {
+    const T* x_data = nullptr;
+    const T* y_grad = nullptr;
+    if (x != nullptr) x_data = x->data<T>();
+    if (dOut != nullptr) y_grad = dOut->data<T>();
+    T* x_grad = dX->data<T>();
+    const auto x_dims = x->dims();
+    auto xshape = vectorize<int>(x_dims);
+    int len = x->dims()[x_dims.size() - 1];
+    std::vector<int> yshape(1, len);
+
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+    T* y_data = RAII_GUARD.alloc_l3_or_gm<T>(len);
+    PADDLE_ENFORCE_XDNN_NOT_NULL(y_data);
+    T* tmp_grad = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
+    PADDLE_ENFORCE_XDNN_NOT_NULL(tmp_grad);
+    int r =
+        xpu::constant(dev_ctx.x_context(), y_data, len, static_cast<T>(1.0));
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
+
+    // dx.device(d) = dout * (static_cast<T>(1) / x);
+    r = xpu::broadcast_div(dev_ctx.x_context(),
+                           reinterpret_cast<const float*>(y_data),
+                           reinterpret_cast<const float*>(x_data),
+                           reinterpret_cast<float*>(tmp_grad),
+                           yshape,
+                           xshape);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_div");
+
+    r = xpu::broadcast_mul(dev_ctx.x_context(),
+                           reinterpret_cast<const float*>(y_grad),
+                           reinterpret_cast<const float*>(tmp_grad),
+                           reinterpret_cast<float*>(x_grad),
+                           xshape,
+                           xshape);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_mul");
+  }
+};
+
+template <typename T>
+struct XPULeakyReluGradFunctor : public funcs::BaseActivationFunctor<T> {
+  float alpha;
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"alpha", &alpha}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dout,
+                  DenseTensor* dx) const {
+    const T* x_data = nullptr;
+    const T* y_grad = nullptr;
+    if (x != nullptr) x_data = x->data<T>();
+    if (dout != nullptr) y_grad = dout->data<T>();
+    T* x_grad = dx->data<T>();
+    auto xpu_context = dev_ctx.x_context();
+
+    // The signs of x and y are the same,
+    // y == nullptr here,
+    // so we give 2 x to the api
+    int r = xpu::leaky_relu_grad(xpu_context,
+                                 reinterpret_cast<const float*>(x_data),
+                                 reinterpret_cast<const float*>(x_data),
+                                 reinterpret_cast<const float*>(y_grad),
+                                 reinterpret_cast<float*>(x_grad),
+                                 dx->numel(),
+                                 alpha);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "leaky_relu_grad");
+  }
+};
+
+template <typename T>
+struct XPUHardSwishGradFunctor : public funcs::BaseActivationFunctor<T> {
+  float threshold;
+  float scale;
+  float offset;
+
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dout,
+                  DenseTensor* dx) const {
+    using XPUType = typename XPUTypeTrait<T>::Type;
+    PADDLE_ENFORCE_EQ(
+        threshold,
+        6.0f,
+        errors::External("Not support threshold [%f] in XPU", threshold));
+    PADDLE_ENFORCE_EQ(
scale, 6.0f, errors::External("Not support scale [%f] in XPU", scale)); + PADDLE_ENFORCE_EQ( + offset, + 3.0f, + errors::External("Not support offset [%f] in XPU", offset)); + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::hard_swish_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_swish_grad"); + } +}; + +template +struct XPUReciprocalGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::reciprocal_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal_grad"); + } +}; + +template +struct XPUReluGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::relu_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu_grad"); + } +}; + +template +struct XPURelu6GradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + float threshold; + typename funcs::BaseActivationFunctor::AttrPair GetAttrs() { + return {{"threshold", &threshold}}; + } + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::relu6_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu6_grad"); + } +}; + +template +struct XPUSigmoidGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::sigmoid_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid_grad"); + } +}; + +template +struct XPUTanhGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::tanh_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "tanh_grad"); + } +}; + +template +struct XPUSquareGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::square_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "square_grad"); + } +}; + +template +struct XPUSqrtGradFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor* x, + const DenseTensor* out, + const DenseTensor* dout, + DenseTensor* dx) const { + int r = xpu_activation_backward( + dev_ctx, x, out, dout, dx, xpu::sqrt_grad); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sqrt_grad"); + } +}; + +template +void PowGradKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& dout, + const Scalar& factor, + 
+                   DenseTensor* dx) {
+  dev_ctx.template Alloc<T>(dx);
+  const T* x_data = x.data<T>();
+  const T* y_grad = dout.data<T>();
+  T* x_grad = dx->data<T>();
+
+  // check dims: all dims should equal
+  auto x_dims = vectorize<int>(x.dims());
+  auto dy_dims = vectorize<int>(dout.dims());
+  auto dx_dims = vectorize<int>(dx->dims());
+  PADDLE_ENFORCE_EQ(x_dims,
+                    dy_dims,
+                    errors::PreconditionNotMet("x_dims should match dy_dims."));
+  PADDLE_ENFORCE_EQ(x_dims,
+                    dx_dims,
+                    errors::PreconditionNotMet("x_dims should match dx_dims."));
+  float pow_factor = factor.to<float>();
+
+  auto xpu_context = dev_ctx.x_context();
+  // int pow_grad(Context* ctx, const T* x, const T* dy, T* dx, int len, float
+  // factor);
+  int r =
+      xpu::pow_grad(xpu_context, x_data, y_grad, x_grad, x.numel(), pow_factor);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow_grad");
+}
+
+template <typename T>
+struct XPUSwishGradFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  float beta;
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dout,
+                  DenseTensor* dx) const {
+    const XPUType* x_data = reinterpret_cast<const XPUType*>(x->data<T>());
+    const XPUType* y_grad = reinterpret_cast<const XPUType*>(dout->data<T>());
+    XPUType* x_grad = reinterpret_cast<XPUType*>(dx->data<T>());
+
+    auto xpu_context = dev_ctx.x_context();
+    int r = xpu::swish_grad(xpu_context, x_data, y_grad, x_grad, dx->numel());
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad");
+  }
+};
+
+template <typename T>
+struct XPUMishGradFunctor : public funcs::BaseActivationFunctor<T> {
+  float threshold;
+
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dout,
+                  DenseTensor* dx) const {
+    const T* x_data = x->data<T>();
+    const T* y_grad = dout->data<T>();
+    T* x_grad = dx->data<T>();
+
+    auto xpu_context = dev_ctx.x_context();
+    int r = xpu::mish_grad(
+        xpu_context,
+        reinterpret_cast<const float*>(x_data),
+        reinterpret_cast<const float*>(x_data),  // mish_grad do not need y_data
+        reinterpret_cast<const float*>(y_grad),
+        reinterpret_cast<float*>(x_grad),
+        dx->numel(),
+        threshold);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish_grad");
+  }
+};
+
+template <typename T>
+struct XPUSoftPlusGradFunctor : public funcs::BaseActivationFunctor<T> {
+  float beta;
+  float threshold;
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor* x,
+                  const DenseTensor* out,
+                  const DenseTensor* dOut,
+                  DenseTensor* dX) const {
+    const T* x_data = x->data<T>();
+    const T* y_grad = dOut->data<T>();
+    T* x_grad = dX->data<T>();
+
+    auto xpu_context = dev_ctx.x_context();
+    int r = xpu::softplus_grad(xpu_context,
+                               reinterpret_cast<const float*>(x_data),
+                               reinterpret_cast<const float*>(
+                                   x_data),  // softplus_grad do not need y_data
+                               reinterpret_cast<const float*>(y_grad),
+                               reinterpret_cast<float*>(x_grad),
+                               dX->numel(),
+                               beta,
+                               threshold);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus_grad");
+  }
+};
+
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Reciprocal, XPUReciprocalGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sigmoid, XPUSigmoidGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Sqrt, XPUSqrtGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Tanh, XPUTanhGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPOUT(Relu, XPUReluGradFunctor);
+
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Log, XPULogGradFunctor);
+DEFINE_XPU_ACTIVATION_GRAD_KERNEL_DEPX(Square, XPUSquareGradFunctor);
+
+DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Swish,
+                                               XPUSwishGradFunctor,
+                                               beta);
+DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(Mish,
+                                               XPUMishGradFunctor,
+                                               threshold);
+DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPX(LeakyRelu,
+                                               XPULeakyReluGradFunctor,
+                                               alpha);
+
+DEFINE_XPU_ACT_GRAD_KERNEL_WITH_ONE_ATTRS_DEPOUT(Relu6,
+                                                 XPURelu6GradFunctor,
+                                                 threshold);
+
+DEFINE_XPU_ACT_GRAD_KERNEL_WITH_TWO_ATTRS_DEPX(Softplus,
+                                               XPUSoftPlusGradFunctor,
+                                               beta,
+                                               threshold)
+
+template <typename T, typename Context>
+void HardSwishGradKernel(const Context& dev_ctx,
+                         const DenseTensor& x,
+                         const DenseTensor& dout,
+                         float threshold,
+                         float scale,
+                         float offset,
+                         DenseTensor* dx) {
+  XPUHardSwishGradFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = threshold;
+  *(attrs[1].second) = scale;
+  *(attrs[2].second) = offset;
+  ActivationGradXPUImpl<T, Context, XPUHardSwishGradFunctor<T>>(
+      dev_ctx, &x, nullptr, &dout, dx, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(relu_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::ReluGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+
+#define PD_REGISTER_ACTIVATION_GRAD_KERNEL(name, func) \
+  PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
+
+PD_REGISTER_KERNEL(tanh_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TanhGradKernel,
+                   float,
+                   phi::dtype::float16) {}
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(log_grad, LogGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(leaky_relu_grad, LeakyReluGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(hard_swish_grad, HardSwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(reciprocal_grad, ReciprocalGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(relu6_grad, Relu6GradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(sigmoid_grad, SigmoidGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(sqrt_grad, SqrtGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(mish_grad, MishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(swish_grad, SwishGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(softplus_grad, SoftplusGradKernel)
+PD_REGISTER_ACTIVATION_GRAD_KERNEL(square_grad, SquareGradKernel)
+PD_REGISTER_KERNEL(pow_grad, XPU, ALL_LAYOUT, phi::PowGradKernel, float) {}
diff --git a/paddle/phi/kernels/xpu/activation_kernel.cc b/paddle/phi/kernels/xpu/activation_kernel.cc
new file mode 100644
index 00000000000..514d5e0b281
--- /dev/null
+++ b/paddle/phi/kernels/xpu/activation_kernel.cc
@@ -0,0 +1,431 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
*/ + +#include "paddle/phi/kernels/activation_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/activation_functor.h" + +#include "paddle/fluid/memory/memory.h" + +namespace phi { + +template +void ActivationXPUImpl(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out, + const Functor& functor) { + PADDLE_ENFORCE_NOT_NULL(out, + errors::NotFound("Output Out should not be nullptr")); + dev_ctx.template Alloc(out); + functor(dev_ctx, x, out); +} + +#define DEFINE_XPU_ACTIVATION_KERNEL(name, functor_class) \ + template \ + void name##Kernel( \ + const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \ + functor_class functor; \ + ActivationXPUImpl>(dev_ctx, x, out, functor); \ + } + +#define DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr, \ + DenseTensor* out) { \ + functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr; \ + ActivationXPUImpl>(dev_ctx, x, out, functor); \ + } + +#define DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS( \ + name, functor_class, attr1, attr2) \ + template \ + void name##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + float attr1, \ + float attr2, \ + DenseTensor* out) { \ + functor_class functor; \ + auto attrs = functor.GetAttrs(); \ + *(attrs[0].second) = attr1; \ + *(attrs[1].second) = attr2; \ + ActivationXPUImpl>(dev_ctx, x, out, functor); \ + } + +template +int xpu_activation_func( + const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out, + std::function func) { + int r = func(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + x.numel()); + return r; +} + +template +int xpu_activation_1attr_func( + const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out, + float attr, + std::function + func) { + int r = func(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + x.numel(), + attr); + return r; +} + +template +int xpu_activation_2attr_func( + const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out, + float attr1, + float attr2, + std::function< + int(xpu::Context*, const XPUType*, XPUType*, int, float, float)> func) { + int r = func(dev_ctx.x_context(), + reinterpret_cast(x.data()), + reinterpret_cast(out->data()), + x.numel(), + attr1, + attr2); + return r; +} + +template +struct XPUExpFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) const { + int r = xpu_activation_func( + dev_ctx, x, out, xpu::exp); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "exp"); + } +}; + +template +struct XPULogFunctor : public funcs::BaseActivationFunctor { + using XPUType = typename XPUTypeTrait::Type; + template + void operator()(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) const { + int r = xpu_activation_func( + dev_ctx, x, out, xpu::log); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "log"); + } +}; + +template +struct XPULeakyReluFunctor : public funcs::BaseActivationFunctor { + float alpha; + typename funcs::BaseActivationFunctor::AttrPair GetAttrs() { + return {{"alpha", &alpha}}; + } + template + void operator()(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) const { + using XPUType = typename XPUTypeTrait::Type; + int r = 
+    int r = xpu_activation_1attr_func<Context, T, XPUType>(
+        dev_ctx, x, out, alpha, xpu::leaky_relu<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "leaky_relu");
+  }
+};
+
+template <typename T, typename Context>
+void PowKernel(const Context& dev_ctx,
+               const DenseTensor& x,
+               const Scalar& factor,
+               DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  float pow_factor = factor.to<float>();
+  const T* x_data = x.data<T>();
+  T* y_data = out->data<T>();
+
+  auto xpu_context = dev_ctx.x_context();
+  // allocate temp memory for factor on xpu
+  xpu::ctx_guard RAII_GUARD(xpu_context);
+  T* factor_data = RAII_GUARD.alloc_l3_or_gm<T>(1);
+  PADDLE_ENFORCE_NOT_NULL(
+      factor_data, errors::External("XPU alloc_l3_or_gm returns nullptr"));
+  paddle::memory::Copy(dev_ctx.GetPlace(),
+                       static_cast<void*>(factor_data),
+                       phi::CPUPlace(),
+                       static_cast<void*>(&pow_factor),
+                       sizeof(T));
+
+  // broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
+  // std::vector<int>& xshape, const std::vector<int>& yshape);
+  auto x_dims = vectorize<int>(x.dims());
+  int r =
+      xpu::broadcast_pow(xpu_context, x_data, factor_data, y_data, x_dims, {1});
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow");
+}
+
+template <typename T>
+struct XPUHardSwishFunctor : public funcs::BaseActivationFunctor<T> {
+  float threshold;
+  float scale;
+  float offset;
+
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    using XPUType = typename XPUTypeTrait<T>::Type;
+    PADDLE_ENFORCE_EQ(
+        threshold,
+        6.0f,
+        errors::External("Not support threshold [%f] in XPU", threshold));
+    PADDLE_ENFORCE_EQ(
+        scale, 6.0f, errors::External("Not support scale [%f] in XPU", scale));
+    PADDLE_ENFORCE_EQ(
+        offset,
+        3.0f,
+        errors::External("Not support offset [%f] in XPU", offset));
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::hard_swish<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_swish");
+  }
+};
+
+template <typename T>
+struct XPUReciprocalFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::reciprocal<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal");
+  }
+};
+
+template <typename T>
+struct XPUReluFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+    XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
+
+    auto xpu_context = dev_ctx.x_context();
+    int r = xpu::relu(xpu_context, x_data, y_data, x.numel(), nullptr, nullptr);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
+  }
+};
+
+template <typename T>
+struct XPURelu6Functor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  float threshold;
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::relu6<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu6");
+  }
+};
+
+template <typename T>
+struct XPUSigmoidFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+
+template <typename T>
+struct XPUHardSwishFunctor : public funcs::BaseActivationFunctor<T> {
+  float threshold;
+  float scale;
+  float offset;
+
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    using XPUType = typename XPUTypeTrait<T>::Type;
+    PADDLE_ENFORCE_EQ(
+        threshold,
+        6.0f,
+        errors::External("Unsupported threshold [%f] in XPU", threshold));
+    PADDLE_ENFORCE_EQ(
+        scale, 6.0f, errors::External("Unsupported scale [%f] in XPU", scale));
+    PADDLE_ENFORCE_EQ(
+        offset,
+        3.0f,
+        errors::External("Unsupported offset [%f] in XPU", offset));
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::hard_swish<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_swish");
+  }
+};
+
+template <typename T>
+struct XPUReciprocalFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::reciprocal<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "reciprocal");
+  }
+};
+
+template <typename T>
+struct XPUReluFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    const XPUType* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
+    XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
+
+    auto xpu_context = dev_ctx.x_context();
+    int r =
+        xpu::relu(xpu_context, x_data, y_data, x.numel(), nullptr, nullptr);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
+  }
+};
+
+template <typename T>
+struct XPURelu6Functor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  float threshold;
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::relu6<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu6");
+  }
+};
+
+template <typename T>
+struct XPUSigmoidFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::sigmoid<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "sigmoid");
+  }
+};
+
+template <typename T>
+struct XPUSquareFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::square<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "square");
+  }
+};
+
+template <typename T>
+struct XPUSqrtFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::sqrt<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "sqrt");
+  }
+};
+
+template <typename T>
+struct XPUMishFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  float threshold;
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"threshold", &threshold}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_1attr_func<Context, T, XPUType>(
+        dev_ctx, x, out, threshold, xpu::mish<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "mish");
+  }
+};
+
+template <typename T, typename Context>
+void SwishKernel(const Context& dev_ctx,
+                 const DenseTensor& x,
+                 float beta,
+                 DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  dev_ctx.template Alloc<T>(out);
+  int r = xpu::swish(dev_ctx.x_context(),
+                     reinterpret_cast<const XPUType*>(x.data<T>()),
+                     reinterpret_cast<XPUType*>(out->data<T>()),
+                     x.numel());
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish");
+}
+
+template <typename T>
+struct XPUSoftplusFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  float beta;
+  float threshold;
+
+  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
+    return {{"beta", &beta}, {"threshold", &threshold}};
+  }
+
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_2attr_func<Context, T, XPUType>(
+        dev_ctx, x, out, beta, threshold, xpu::softplus<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus");
+  }
+};
+
+template <typename T>
+struct XPUTanhFunctor : public funcs::BaseActivationFunctor<T> {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  template <typename Context>
+  void operator()(const Context& dev_ctx,
+                  const DenseTensor& x,
+                  DenseTensor* out) const {
+    int r = xpu_activation_func<Context, T, XPUType>(
+        dev_ctx, x, out, xpu::tanh<XPUType>);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "tanh");
+  }
+};
+
+DEFINE_XPU_ACTIVATION_KERNEL(Exp, XPUExpFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Log, XPULogFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Reciprocal, XPUReciprocalFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Relu, XPUReluFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Sigmoid, XPUSigmoidFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Square, XPUSquareFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Sqrt, XPUSqrtFunctor)
+DEFINE_XPU_ACTIVATION_KERNEL(Tanh, XPUTanhFunctor)
+
+DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold)
+DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu,
+                                            XPULeakyReluFunctor,
+                                            alpha)
+DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, XPURelu6Functor, threshold)
+
+DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus,
+                                            XPUSoftplusFunctor,
+                                            beta,
+                                            threshold)
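+
+// Each DEFINE_XPU_ACTIVATION_KERNEL(name, functor) invocation above expands
+// to roughly the following kernel definition (sketch for the Exp case):
+//
+//   template <typename T, typename Context>
+//   void ExpKernel(const Context& dev_ctx, const DenseTensor& x,
+//                  DenseTensor* out) {
+//     XPUExpFunctor<T> functor;
+//     ActivationXPUImpl<T, Context, XPUExpFunctor<T>>(dev_ctx, x, out,
+//                                                     functor);
+//   }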
+
+template <typename T, typename Context>
+void HardSwishKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     float threshold,
+                     float scale,
+                     float offset,
+                     DenseTensor* out) {
+  XPUHardSwishFunctor<T> functor;
+  auto attrs = functor.GetAttrs();
+  *(attrs[0].second) = threshold;
+  *(attrs[1].second) = scale;
+  *(attrs[2].second) = offset;
+  ActivationXPUImpl<T, Context, XPUHardSwishFunctor<T>>(
+      dev_ctx, x, out, functor);
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {}
+
+#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
+  PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
+
+PD_REGISTER_KERNEL(
+    tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
+
+PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel)  // no grad
+PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
+PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
+PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
+PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
+PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
+PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
+PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
+PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
+PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
+PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
+PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel)
diff --git a/paddle/phi/kernels/xpu/transpose_grad_kernel.cc b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc
new file mode 100644
index 00000000000..9fce92b8262
--- /dev/null
+++ b/paddle/phi/kernels/xpu/transpose_grad_kernel.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/transpose_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TransposeGradKernel(const Context& dev_ctx,
+                         const DenseTensor& out_grad,
+                         const std::vector<int>& axis,
+                         DenseTensor* x_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  dev_ctx.template Alloc<T>(x_grad);
+  std::vector<int> reversed_axis(axis);
+  for (size_t i = 0; i < axis.size(); i++) {
+    reversed_axis[axis[i]] = i;
+  }
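+  // reversed_axis is the inverse permutation of axis: the forward transpose
+  // wrote input dim axis[i] to output dim i, so the gradient must be
+  // transposed back with the inverse map. For example, axis = {1, 2, 0}
+  // gives reversed_axis = {2, 0, 1}.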
+  int ndims = axis.size();
+  std::vector<int> out_shape_host(ndims, 0);
+  for (int i = 0; i < ndims; ++i) {
+    out_shape_host[i] = out_grad.dims()[i];
+  }
+  int r = xpu::transpose<XPUType>(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+      reinterpret_cast<XPUType*>(x_grad->data<T>()),
+      out_shape_host,
+      reversed_axis);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose_grad");
+}
+}  // namespace phi
+
+PD_REGISTER_KERNEL(transpose_grad,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TransposeGradKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/transpose_kernel.cc b/paddle/phi/kernels/xpu/transpose_kernel.cc
new file mode 100644
index 00000000000..18157f18981
--- /dev/null
+++ b/paddle/phi/kernels/xpu/transpose_kernel.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/transpose_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TransposeKernel(const Context& dev_ctx,
+                     const DenseTensor& x,
+                     const std::vector<int>& axis,
+                     DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  if (out->numel() == 0) {
+    return;
+  }
+  dev_ctx.template Alloc<T>(out);
+  int ndims = axis.size();
+  std::vector<int> x_shape_host(ndims, 0);
+  for (int i = 0; i < ndims; ++i) {
+    x_shape_host[i] = x.dims()[i];
+  }
+  int r = xpu::transpose<XPUType>(
+      dev_ctx.x_context(),
+      reinterpret_cast<const XPUType*>(x.data<T>()),
+      reinterpret_cast<XPUType*>(out->data<T>()),
+      x_shape_host,
+      axis);
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(transpose,
+                   XPU,
+                   ALL_LAYOUT,
+                   phi::TransposeKernel,
+                   float,
+                   phi::dtype::float16) {}
diff --git a/paddle/phi/kernels/xpu/tril_triu_grad_kernel.cc b/paddle/phi/kernels/xpu/tril_triu_grad_kernel.cc
new file mode 100644
index 00000000000..964e9c61742
--- /dev/null
+++ b/paddle/phi/kernels/xpu/tril_triu_grad_kernel.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/tril_triu_grad_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TrilTriuGradKernel(const Context& ctx,
+                        const DenseTensor& out_grad,
+                        int diagonal,
+                        bool lower,
+                        DenseTensor* x_grad) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  ctx.template Alloc<T>(x_grad);
+  auto dy_shape = vectorize<int>(out_grad.dims());
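+  // tril/triu is an elementwise mask, so the backward pass applies the same
+  // mask to the upstream gradient: x_grad = tril(out_grad) when lower is
+  // true, and x_grad = triu(out_grad) otherwise.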
+  int r = 0;
+  if (lower) {
+    r = xpu::tril(ctx.x_context(),
+                  reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+                  reinterpret_cast<XPUType*>(x_grad->data<T>()),
+                  dy_shape,
+                  diagonal);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "tril_op");
+  } else {
+    r = xpu::triu(ctx.x_context(),
+                  reinterpret_cast<const XPUType*>(out_grad.data<T>()),
+                  reinterpret_cast<XPUType*>(x_grad->data<T>()),
+                  dy_shape,
+                  diagonal);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "triu_op");
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    tril_triu_grad, XPU, ALL_LAYOUT, phi::TrilTriuGradKernel, int, float) {}
diff --git a/paddle/phi/kernels/xpu/tril_triu_kernel.cc b/paddle/phi/kernels/xpu/tril_triu_kernel.cc
new file mode 100644
index 00000000000..3d9ae98a238
--- /dev/null
+++ b/paddle/phi/kernels/xpu/tril_triu_kernel.cc
@@ -0,0 +1,55 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/tril_triu_kernel.h"
+
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void TrilTriuKernel(const Context& ctx,
+                    const DenseTensor& x,
+                    int diagonal,
+                    bool lower,
+                    DenseTensor* out) {
+  using XPUType = typename XPUTypeTrait<T>::Type;
+  ctx.template Alloc<T>(out);
+  auto xshape = vectorize<int>(x.dims());
+  // diagonal shifts the mask boundary: tril keeps elements on and below the
+  // diagonal-th diagonal, triu keeps those on and above it (0 is the main
+  // diagonal, positive is above it, negative below).
+  int r = 0;
+  if (lower) {
+    r = xpu::tril(ctx.x_context(),
+                  reinterpret_cast<const XPUType*>(x.data<T>()),
+                  reinterpret_cast<XPUType*>(out->data<T>()),
+                  xshape,
+                  diagonal);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "tril_op");
+  } else {
+    r = xpu::triu(ctx.x_context(),
+                  reinterpret_cast<const XPUType*>(x.data<T>()),
+                  reinterpret_cast<XPUType*>(out->data<T>()),
+                  xshape,
+                  diagonal);
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "triu_op");
+  }
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(
+    tril_triu, XPU, ALL_LAYOUT, phi::TrilTriuKernel, int, float) {}
--
GitLab