/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU

#include <string>

#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"

namespace paddle {
namespace operators {

using paddle::framework::Tensor;

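// XPUActivationKernel/XPUActivationGradKernel wrap an activation functor:
// every float attribute declared by the functor's GetAttrs() is filled in
// from the op's attributes before the functor runs on the XPU device.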
template <typename Functor>
class XPUActivationKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto &attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};

template <typename Functor>
class XPUActivationGradKernel
    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    Functor functor;

    auto attrs = functor.GetAttrs();
    for (auto &attr : attrs) {
      *attr.second = context.Attr<float>(attr.first);
    }
    functor(context);
  }
};

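// Shared forward helper: reads input X, allocates output Out, reinterprets
// both buffers as the XPU element type, then runs the given XDNN unary
// function and checks its return code.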
template <typename DeviceContext, typename T, typename XPUT>
void xpu_activation_forward(
    const framework::ExecutionContext &ctx,
    std::function<int(xpu::Context *, const XPUT *, XPUT *, int)> func) {
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Output<Tensor>("Out");
  const XPUT *x_data = reinterpret_cast<const XPUT *>(x->data<T>());
  XPUT *y_data = reinterpret_cast<XPUT *>(y->mutable_data<T>(ctx.GetPlace()));

  auto xpu_context = ctx.device_context<DeviceContext>().x_context();
  int r = func(xpu_context, x_data, y_data, x->numel());
  PADDLE_ENFORCE_EQ(
      r,
      xpu::Error_t::SUCCESS,
      platform::errors::External("XPU activation op return wrong value[%d %s].",
                                 r,
                                 XPUAPIErrorMsg[r]));
}

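// Shared backward helper: gathers whichever of X, Out and dOut the grad op
// provides (missing inputs stay nullptr), allocates dX, then runs the given
// XDNN gradient function and checks its return code.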
template <typename DeviceContext, typename T, typename XPUT>
void xpu_activation_backward(
    const framework::ExecutionContext &ctx,
    std::function<int(
        xpu::Context *, const XPUT *, const XPUT *, const XPUT *, XPUT *, int)>
        func) {
  /* TODO: relu tanh sigmoid are inplace */
  const auto *x = ctx.Input<Tensor>("X");
  auto *y = ctx.Input<Tensor>("Out");
  auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
  auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
  const XPUT *x_data = nullptr;
  const XPUT *y_data = nullptr;
  const XPUT *y_grad = nullptr;
  if (x != nullptr) x_data = reinterpret_cast<const XPUT *>(x->data<T>());
  if (y != nullptr) y_data = reinterpret_cast<const XPUT *>(y->data<T>());
  if (dOut != nullptr) y_grad = reinterpret_cast<const XPUT *>(dOut->data<T>());
  XPUT *x_grad = reinterpret_cast<XPUT *>(dX->mutable_data<T>(ctx.GetPlace()));
  auto xpu_context = ctx.device_context<DeviceContext>().x_context();

  int r = func(xpu_context, x_data, y_data, y_grad, x_grad, dX->numel());
  PADDLE_ENFORCE_EQ(r,
                    xpu::Error_t::SUCCESS,
                    platform::errors::External(
                        "XPU activation grad op return wrong value[%d %s].",
                        r,
                        XPUAPIErrorMsg[r]));
}

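// Each functor below simply binds one XDNN primitive (abs, exp, log, ...)
// to the shared forward/backward helpers above.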
template <typename T>
struct XPUAbsFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::abs<XPUType>);
  }
};

template <typename T>
struct XPUAbsGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::abs_grad<XPUType>);
  }
};

template <typename T>
struct XPUExpFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::exp<XPUType>);
  }
};

template <typename T>
struct XPULogFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::log<XPUType>);
  }
};

template <typename T>
struct XPUReciprocalFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::reciprocal<XPUType>);
  }
};

template <typename T>
struct XPUReciprocalGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::reciprocal_grad<XPUType>);
  }
};

template <typename T>
struct XPUReluFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::relu<XPUType>);
  }
};

template <typename T>
struct XPUReluGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::relu_grad<XPUType>);
  }
};

template <typename T>
struct XPUSigmoidFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sigmoid<XPUType>);
  }
};

template <typename T>
struct XPUSigmoidGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sigmoid_grad<XPUType>);
  }
};

template <typename T>
struct XPUSqrtFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sqrt<XPUType>);
  }
};

template <typename T>
struct XPUSqrtGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::sqrt_grad<XPUType>);
  }
};

template <typename T>
struct XPUSquareFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::square<XPUType>);
  }
};

template <typename T>
struct XPUSquareGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::square_grad<XPUType>);
  }
};

template <typename T>
struct XPUTanhFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::tanh<XPUType>);
  }
};

template <typename T>
struct XPUTanhGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::tanh_grad<XPUType>);
  }
};

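// hard_swish on XPU is only wired up for the default parameterization
// (threshold = 6, scale = 6, offset = 3); other attribute values are rejected.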
template <typename T>
struct XPUHardSwishFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    float threshold = ctx.Attr<float>("threshold");
    float scale = ctx.Attr<float>("scale");
    float offset = ctx.Attr<float>("offset");
    PADDLE_ENFORCE_EQ(threshold,
                      6.0f,
                      platform::errors::External(
                          "Not support threshold [%f] in XPU", threshold));
    PADDLE_ENFORCE_EQ(
        scale,
        6.0f,
        platform::errors::External("Not support scale [%f] in XPU", scale));
    PADDLE_ENFORCE_EQ(
        offset,
        3.0f,
        platform::errors::External("Not support offset [%f] in XPU", offset));
    xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::hard_swish<XPUType>);
  }
};

template <typename T>
struct XPUHardSwishGradFunctor : public BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  void operator()(const framework::ExecutionContext &ctx) const {
    float threshold = ctx.Attr<float>("threshold");
    float scale = ctx.Attr<float>("scale");
    float offset = ctx.Attr<float>("offset");
    PADDLE_ENFORCE_EQ(threshold,
                      6.0f,
                      platform::errors::External(
                          "Not support threshold [%f] in XPU", threshold));
    PADDLE_ENFORCE_EQ(
        scale,
        6.0f,
        platform::errors::External("Not support scale [%f] in XPU", scale));
    PADDLE_ENFORCE_EQ(
        offset,
        3.0f,
        platform::errors::External("Not support offset [%f] in XPU", offset));
    xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
        ctx, xpu::hard_swish_grad<XPUType>);
  }
};

template <typename T>
struct XPULeakyReluFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    float alpha = ctx.Attr<float>("alpha");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r = xpu::leaky_relu(xpu_context, x_data, y_data, x->numel(), alpha);
    PADDLE_ENFORCE_EQ(
        r,
        xpu::Error_t::SUCCESS,
        platform::errors::External(
            "XPU leaky_relu return wrong value[%d %s].", r, XPUAPIErrorMsg[r]));
  }
};

template <typename T>
struct XPULeakyReluGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    float alpha = ctx.Attr<float>("alpha");
    const T *x_data = nullptr;
    const T *y_grad = nullptr;
    if (x != nullptr) x_data = x->data<T>();
    if (dOut != nullptr) y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();

    // leaky_relu_grad only uses the sign of its inputs, and x and y share the
    // same sign. Out is not an input of this grad op (y == nullptr here), so
    // x is passed for both the x and y arguments of the API.
    int r = xpu::leaky_relu_grad(xpu_context,
                                 reinterpret_cast<const float *>(x_data),
                                 reinterpret_cast<const float *>(x_data),
                                 reinterpret_cast<const float *>(y_grad),
                                 reinterpret_cast<float *>(x_grad),
                                 dX->numel(),
                                 alpha);
    PADDLE_ENFORCE_EQ(r,
                      xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "XPU leaky_relu_grad return wrong value[%d %s].",
                          r,
                          XPUAPIErrorMsg[r]));
  }
};

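// pow is implemented with broadcast_pow: the scalar exponent is staged in a
// one-element device buffer and broadcast against X.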
template <typename T>
struct XPUPowFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    auto pow_factor = ctx.Attr<float>("factor");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    // allocate temp memory for factor on xpu
    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    xpu::ctx_guard RAII_GUARD(xpu_context);
    T *factor_data = RAII_GUARD.alloc_l3_or_gm<T>(1);
    PADDLE_ENFORCE_NOT_NULL(
        factor_data,
        platform::errors::External("XPU alloc_l3_or_gm returns nullptr"));
    memory::Copy(ctx.GetPlace(),
                 static_cast<void *>(factor_data),
                 platform::CPUPlace(),
                 static_cast<void *>(&pow_factor),
                 sizeof(T));

    // broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
    // std::vector<int>& xshape, const std::vector<int>& yshape);
    auto x_dims = phi::vectorize<int>(x->dims());
    int r = xpu::broadcast_pow(
        xpu_context, x_data, factor_data, y_data, x_dims, {1});
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow");
  }
};

template <typename T>
struct XPUPowGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    const T *x_data = x->data<T>();
    const T *y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());

    // check dims: all dims should equal
    auto x_dims = phi::vectorize<int>(x->dims());
    auto dy_dims = phi::vectorize<int>(dOut->dims());
    auto dx_dims = phi::vectorize<int>(dX->dims());
    PADDLE_ENFORCE_EQ(
        x_dims,
        dy_dims,
        platform::errors::PreconditionNotMet("x_dims should match dy_dims."));
    PADDLE_ENFORCE_EQ(
        x_dims,
        dx_dims,
        platform::errors::PreconditionNotMet("x_dims should match dx_dims."));
    float pow_factor = ctx.Attr<float>("factor");

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    // int pow_grad(Context* ctx, const T* x, const T* dy, T* dx, int len, float
    // factor);
    int r = xpu::pow_grad(
        xpu_context, x_data, y_grad, x_grad, x->numel(), pow_factor);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "pow_grad");
  }
};

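// softplus passes the beta/threshold attributes straight through to the XDNN
// softplus/softplus_grad calls.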
template <typename T>
struct XPUSoftPlusFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    float beta = ctx.Attr<float>("beta");
    float threshold = ctx.Attr<float>("threshold");

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r =
        xpu::softplus(xpu_context, x_data, y_data, x->numel(), beta, threshold);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus");
  }
};

template <typename T>
struct XPUSoftPlusGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    const T *x_data = x->data<T>();
    const T *y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());

    float beta = ctx.Attr<float>("beta");
    float threshold = ctx.Attr<float>("threshold");

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    int r = xpu::softplus_grad(xpu_context,
                               reinterpret_cast<const float *>(x_data),
                               reinterpret_cast<const float *>(
                                   x_data),  // softplus_grad do not need y_data
                               reinterpret_cast<const float *>(y_grad),
                               reinterpret_cast<float *>(x_grad),
                               dX->numel(),
                               beta,
                               threshold);
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "softplus_grad");
  }
};

template <typename T>
struct XPUSwishFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *y = ctx.Output<Tensor>("Out");
    const T *x_data = x->data<T>();
    T *y_data = y->mutable_data<T>(ctx.GetPlace());

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    // int swish(Context* ctx, const T* x, T* y, int len);
    int r = xpu::swish(xpu_context, x_data, y_data, x->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish");
  }
};

template <typename T>
struct XPUSwishGradFunctor : public BaseActivationFunctor<T> {
  void operator()(const framework::ExecutionContext &ctx) const {
    const auto *x = ctx.Input<Tensor>("X");
    auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
    const T *x_data = x->data<T>();
    const T *y_grad = dOut->data<T>();
    T *x_grad = dX->mutable_data<T>(ctx.GetPlace());

    auto xpu_context =
        ctx.device_context<paddle::platform::XPUDeviceContext>().x_context();
    // int swish_grad(Context* ctx, const T* x, const T* dy, T* dx, int len);
    int r = xpu::swish_grad(xpu_context, x_data, y_grad, x_grad, dX->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

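// Registers the float32 forward kernel and the matching gradient kernel for
// an activation op.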
#define REGISTER_ACTIVATION_XPU_KERNEL(act_type, functor, grad_functor)  \
  REGISTER_OP_XPU_KERNEL(act_type,                                       \
                         ops::XPUActivationKernel<ops::functor<float>>); \
  REGISTER_OP_XPU_KERNEL(                                                \
      act_type##_grad,                                                   \
      ops::XPUActivationGradKernel<ops::grad_functor<float>>);

REGISTER_ACTIVATION_XPU_KERNEL(abs, XPUAbsFunctor, XPUAbsGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(hard_swish,
                               XPUHardSwishFunctor,
                               XPUHardSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu,
                               XPULeakyReluFunctor,
                               XPULeakyReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(reciprocal,
                               XPUReciprocalFunctor,
                               XPUReciprocalGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sigmoid,
                               XPUSigmoidFunctor,
                               XPUSigmoidGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSquareGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(softplus,
                               XPUSoftPlusFunctor,
                               XPUSoftPlusGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)

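// relu and tanh also provide float16 kernels, so they are registered directly
// rather than through the float-only macro above.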
REGISTER_OP_XPU_KERNEL(
    relu,
    ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
    ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
    relu_grad,
    ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
    ops::XPUActivationGradKernel<
        ops::XPUReluGradFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
    tanh,
    ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
    ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
    tanh_grad,
    ops::XPUActivationGradKernel<ops::XPUTanhGradFunctor<float>>,
    ops::XPUActivationGradKernel<
        ops::XPUTanhGradFunctor<paddle::platform::float16>>);

REGISTER_OP_XPU_KERNEL(exp,
                       ops::XPUActivationKernel<ops::XPUExpFunctor<float>>);
REGISTER_OP_XPU_KERNEL(log,
                       ops::XPUActivationKernel<ops::XPULogFunctor<float>>);

#endif  // PADDLE_WITH_XPU