/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include <string>

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

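// Forward kernel shared by the XPU activations: it copies each float
// attribute declared by the functor from the op's attribute map, then invokes
// the functor on the execution context.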
template <typename Functor>
class XPUActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto &attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};

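// Backward counterpart of the kernel above; attribute handling is identical,
// but the functor computes the gradient instead of the forward value.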
template <typename Functor>
class XPUActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto &attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};

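// Adapts a single-input XDNN activation with signature
// func(ctx, x, y, len) to the operator's X/Out tensors.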
template <typename DeviceContext, typename T, typename XPUT>
void xpu_activation_forward(
    const framework::ExecutionContext &ctx,
    std::function<int(xpu::Context *, const XPUT *, XPUT *, int)> func) {
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");
  const XPUT *x_data = reinterpret_cast<const XPUT *>(x->data<T>());
  XPUT *y_data = reinterpret_cast<XPUT *>(y->mutable_data<T>(ctx.GetPlace()));

  auto xpu_context = ctx.device_context<DeviceContext>().x_context();
  int r = func(xpu_context, x_data, y_data, x->numel());
  PADDLE_ENFORCE_EQ(
      r,
      xpu::Error_t::SUCCESS,
      platform::errors::External("XPU activation op return wrong value[%d %s].",
                                 r,
                                 XPUAPIErrorMsg[r]));
}

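// Gradient-side helper: X, Out and dOut may each be absent depending on what
// the specific activation declares, so missing inputs are passed as nullptr.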
template <typename DeviceContext, typename T, typename XPUT>
void xpu_activation_backward(
    const framework::ExecutionContext &ctx,
    std::function<int(
        xpu::Context *, const XPUT *, const XPUT *, const XPUT *, XPUT *, int)>
        func) {
  /* TODO: relu tanh sigmoid are inplace */
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Input<Tensor>("Out");
  auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  const XPUT *x_data = nullptr;
  const XPUT *y_data = nullptr;
  const XPUT *y_grad = nullptr;
  if (x != nullptr) x_data = reinterpret_cast<const XPUT *>(x->data<T>());
  if (y != nullptr) y_data = reinterpret_cast<const XPUT *>(y->data<T>());
  if (dOut != nullptr) y_grad = reinterpret_cast<const XPUT *>(dOut->data<T>());
  XPUT *x_grad = reinterpret_cast<XPUT *>(dX->mutable_data<T>(ctx.GetPlace()));
  auto xpu_context = ctx.device_context<DeviceContext>().x_context();

  int r = func(xpu_context, x_data, y_data, y_grad, x_grad, dX->numel());
  PADDLE_ENFORCE_EQ(r,
                    xpu::Error_t::SUCCESS,
                    platform::errors::External(
                        "XPU activation grad op return wrong value[%d %s].",
                        r,
                        XPUAPIErrorMsg[r]));
}

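// Simple element-wise activations: each functor just binds the matching XDNN
// function to the shared forward/backward helper above.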
template <typename T>
struct XPUAbsFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::abs<XPUType>);
  }
};

template <typename T>
struct XPUAbsGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::abs_grad<XPUType>);
  }
};

template <typename T>
struct XPUExpFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::exp<XPUType>);
  }
};

template <typename T>
struct XPULogFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::log<XPUType>);
  }
};

template <typename T>
struct XPUReciprocalFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::reciprocal<XPUType>);
  }
};

template <typename T>
struct XPUReciprocalGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::reciprocal_grad<XPUType>);
  }
};

template <typename T>
struct XPUReluGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::relu_grad<XPUType>);
  }
};

template <typename T>
struct XPUSigmoidFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sigmoid<XPUType>);
  }
};

template <typename T>
struct XPUSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sigmoid_grad<XPUType>);
  }
};

template <typename T>
struct XPUSqrtFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sqrt<XPUType>);
  }
};

template <typename T>
struct XPUSqrtGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sqrt_grad<XPUType>);
  }
};

template <typename T>
struct XPUSquareFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::square<XPUType>);
  }
};

template <typename T>
struct XPUSquareGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::square_grad<XPUType>);
  }
};

template <typename T>
struct XPUTanhFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::tanh<XPUType>);
  }
};

template <typename T>
struct XPUTanhGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::tanh_grad<XPUType>);
  }
};

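// The XPU hard_swish kernel only supports the default parameters
// (threshold = 6, scale = 6, offset = 3); other values are rejected below.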
template <typename T>
struct XPUHardSwishFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    float threshold = ctx.Attr<float>("threshold");
    float scale = ctx.Attr<float>("scale");
    float offset = ctx.Attr<float>("offset");
    PADDLE_ENFORCE_EQ(threshold,
                      6.0f,
                      platform::errors::External(
                          "Not support threshold [%f] in XPU", threshold));
    PADDLE_ENFORCE_EQ(
        scale,
        6.0f,
        platform::errors::External("Not support scale [%f] in XPU", scale));
    PADDLE_ENFORCE_EQ(
        offset,
        3.0f,
        platform::errors::External("Not support offset [%f] in XPU", offset));
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::hard_swish<XPUType>);
  }
};

template <typename T>
struct XPUHardSwishGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    float threshold = ctx.Attr<float>("threshold");
    float scale = ctx.Attr<float>("scale");
    float offset = ctx.Attr<float>("offset");
    PADDLE_ENFORCE_EQ(threshold,
                      6.0f,
                      platform::errors::External(
                          "Not support threshold [%f] in XPU", threshold));
    PADDLE_ENFORCE_EQ(
        scale,
        6.0f,
        platform::errors::External("Not support scale [%f] in XPU", scale));
    PADDLE_ENFORCE_EQ(
        offset,
        3.0f,
        platform::errors::External("Not support offset [%f] in XPU", offset));
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::hard_swish_grad<XPUType>);
  }
};

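// leaky_relu takes an extra alpha argument, so it calls the XDNN API directly
// instead of going through the shared forward helper.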
template <typename T>
struct XPULeakyReluFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    float alpha = ctx.Attr<float>("alpha");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r = xpu::leaky_relu(xpu_context, x_data, y_data, x->numel(), alpha);
    PADDLE_ENFORCE_EQ(
        r,
        xpu::Error_t::SUCCESS,
        platform::errors::External(
            "XPU leaky_relu return wrong value[%d %s].", r, XPUAPIErrorMsg[r]));
  }
};

template <typename T>
struct XPULeakyReluGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    float alpha = ctx.Attr<float>("alpha");
    const T *x_data = nullptr;
    const T *y_grad = nullptr;
    if (x != nullptr) x_data = x->data<T>();
    if (dOut != nullptr) y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();

    // leaky_relu_grad only needs the sign of its inputs, and x and y share
    // the same sign; y is nullptr here, so x is passed for both arguments.
    int r = xpu::leaky_relu_grad(xpu_context,
                                 reinterpret_cast<const float *>(x_data),
                                 reinterpret_cast<const float *>(x_data),
                                 reinterpret_cast<const float *>(y_grad),
                                 reinterpret_cast<float *>(x_grad),
                                 dX->numel(),
                                 alpha);
    PADDLE_ENFORCE_EQ(r,
                      xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "XPU leaky_relu_grad return wrong value[%d %s].",
                          r,
                          XPUAPIErrorMsg[r]));
  }
};

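// pow: xpu::broadcast_pow expects the exponent as device memory, so the
// scalar "factor" attribute is first staged into a 1-element XPU buffer.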
template <typename T>
struct XPUPowFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    auto pow_factor = ctx.Attr<float>("factor");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    // allocate temp memory for factor on xpu
    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    xpu::ctx_guard RAII_GUARD(xpu_context);
    T *factor_data = RAII_GUARD.alloc_l3_or_gm<T>(1);
    PADDLE_ENFORCE_NOT_NULL(
        factor_data,
        platform::errors::External("XPU alloc_l3_or_gm returns nullptr"));
    memory::Copy(ctx.GetPlace(),
                 static_cast<void *>(factor_data),
                 platform::CPUPlace(),
                 static_cast<void *>(&pow_factor),
                 sizeof(T));

    // broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
    // std::vector<int>& xshape, const std::vector<int>& yshape);
    auto x_dims = phi::vectorize<int>(x->dims());
    int r = xpu::broadcast_pow(
        xpu_context, x_data, factor_data, y_data, x_dims, {1});
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow");
  }
};

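// pow backward: the XDNN call is element-wise, so x, dOut and dX are checked
// to have identical shapes before calling xpu::pow_grad.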
template <typename T>
struct XPUPowGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    const T *x_data = x->data<T>();
    const T *y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());

    // check dims: x, dOut and dX must all have the same shape
    auto x_dims = phi::vectorize<int>(x->dims());
    auto dy_dims = phi::vectorize<int>(dOut->dims());
    auto dx_dims = phi::vectorize<int>(dX->dims());
    PADDLE_ENFORCE_EQ(
        x_dims,
        dy_dims,
        platform::errors::PreconditionNotMet("x_dims should match dy_dims."));
    PADDLE_ENFORCE_EQ(
        x_dims,
        dx_dims,
        platform::errors::PreconditionNotMet("x_dims should match dx_dims."));
    float pow_factor = ctx.Attr<float>("factor");

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    // int pow_grad(Context* ctx, const T* x, const T* dy, T* dx, int len, float
    // factor);
    int r = xpu::pow_grad(
        xpu_context, x_data, y_grad, x_grad, x->numel(), pow_factor);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow_grad");
  }
};

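// relu calls xpu::relu directly so it can be instantiated for both float and
// float16 (see the registrations at the bottom of the file); the two trailing
// nullptrs are optional arguments that are left unused here.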
template <typename T>
struct XPUReluFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    const XPUType *x_data = reinterpret_cast<const XPUType *>(x->data<T>());
    XPUType *y_data =
        reinterpret_cast<XPUType *>(y->mutable_data<T>(ctx.GetPlace()));

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r =
        xpu::relu(xpu_context, x_data, y_data, x->numel(), nullptr, nullptr);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "relu");
  }
};

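// softplus forwards the beta/threshold attributes straight to xpu::softplus.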
template <typename T>
struct XPUSoftPlusFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    float beta = ctx.Attr<float>("beta");
    float threshold = ctx.Attr<float>("threshold");

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r =
        xpu::softplus(xpu_context, x_data, y_data, x->numel(), beta, threshold);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus");
  }
};

template <typename T>
struct XPUSoftPlusGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    const T *x_data = x->data<T>();
    const T *y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());

    float beta = ctx.Attr<float>("beta");
    float threshold = ctx.Attr<float>("threshold");

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r = xpu::softplus_grad(xpu_context,
                               reinterpret_cast<const float *>(x_data),
                               reinterpret_cast<const float *>(
                                   x_data),  // softplus_grad does not need y_data
                               reinterpret_cast<const float *>(y_grad),
                               reinterpret_cast<float *>(x_grad),
                               dX->numel(),
                               beta,
                               threshold);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus_grad");
  }
};

template <typename T>
struct XPUSwishFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    // int swish(Context* ctx, const T* x, T* y, int len);
    int r = xpu::swish(xpu_context, x_data, y_data, x->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish");
  }
};

template <typename T>
struct XPUSwishGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    const T *x_data = x->data<T>();
    const T *y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    // int swish_grad(Context* ctx, const T* x, const T* dy, T* dx, int len);
    int r = xpu::swish_grad(xpu_context, x_data, y_grad, x_grad, dX->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

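// Registers the float forward kernel and the matching <act>_grad kernel for
// an activation in one macro invocation.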
#define REGISTER_ACTIVATION_XPU_KERNEL(act_type, functor, grad_functor)  \
  REGISTER_OP_XPU_KERNEL(act_type,                                       \
                         ops::XPUActivationKernel<ops::functor<float>>); \
  REGISTER_OP_XPU_KERNEL(                                                \
      act_type##_grad,                                                   \
      ops::XPUActivationGradKernel<ops::grad_functor<float>>);

REGISTER_ACTIVATION_XPU_KERNEL(abs, XPUAbsFunctor, XPUAbsGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(hard_swish,
                               XPUHardSwishFunctor,
                               XPUHardSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu,
                               XPULeakyReluFunctor,
                               XPULeakyReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(reciprocal,
                               XPUReciprocalFunctor,
                               XPUReciprocalGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sigmoid,
                               XPUSigmoidFunctor,
                               XPUSigmoidGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSquareGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(softplus,
                               XPUSoftPlusFunctor,
                               XPUSoftPlusGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)

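// relu and tanh are additionally registered with float16 kernels; exp and log
// currently have float kernels only.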
REGISTER_OP_XPU_KERNEL(
    relu,
    ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
    ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
    relu_grad,
    ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
    ops::XPUActivationGradKernel<
        ops::XPUReluGradFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
    tanh,
    ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
    ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
    tanh_grad,
    ops::XPUActivationGradKernel<ops::XPUTanhGradFunctor<float>>,
    ops::XPUActivationGradKernel<
        ops::XPUTanhGradFunctor<paddle::platform::float16>>);

REGISTER_OP_XPU_KERNEL(exp,
                       ops::XPUActivationKernel<ops::XPUExpFunctor<float>>);
REGISTER_OP_XPU_KERNEL(log,
                       ops::XPUActivationKernel<ops::XPULogFunctor<float>>);

#endif  // PADDLE_WITH_XPU