// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/arm/elementwise_compute.h"
#include <string>
#include <vector>
#include "lite/backends/arm/math/funcs.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace arm {

inline DDim trim_trailing_singular_dims(const DDim& dims) {
  // Remove trailing dimensions of size 1 for y
  auto actual_dims_size = dims.size();
  for (; actual_dims_size != 0; --actual_dims_size) {
    if (dims[actual_dims_size - 1] != 1) break;
  }

  std::vector<int64_t> trim_dims;
  trim_dims.resize(actual_dims_size);
  for (int i = 0; i < actual_dims_size; ++i) {
    trim_dims[i] = dims[i];
  }
  if (trim_dims.size() == 0) {
    return DDim();
  }
  return DDim(trim_dims);
}

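// Decides whether x and y need broadcasting. On success it factors the shapes
// into pre * n * post: pre is the product of x's dims before `axis`, n is the
// product of y's dims after trailing 1s are trimmed, and post is the product
// of x's remaining dims. E.g. x_dims = [2, 3, 4, 5], y_dims = [3, 4], axis = 1
// gives pre = 2, n = 12, post = 5. Returns false when x and the trimmed y have
// the same rank, so the plain (non-broadcast) kernel can be used.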
inline bool is_broadcast(const DDim& x_dims,
                         const DDim& y_dims,
                         int axis,
                         int* pre,
                         int* n,
                         int* post) {
  if (axis < 0) {
    axis = x_dims.size() - y_dims.size();
  }
  DDim y_dim_trim = trim_trailing_singular_dims(y_dims);
  axis = (y_dim_trim.size() == 0) ? x_dims.size() : axis;
  if (x_dims.size() == y_dim_trim.size()) {
    return false;
  }
  *pre = 1;
  *n = 1;
  *post = 1;
  for (int i = 0; i < axis; ++i) {
    (*pre) *= x_dims[i];
  }
  for (int i = 0; i < y_dim_trim.size(); ++i) {
    CHECK_EQ(x_dims[i + axis], y_dim_trim[i])
        << "Broadcast dimension mismatch.";
    (*n) *= y_dim_trim[i];
  }
  for (int i = axis + y_dim_trim.size(); i < x_dims.size(); ++i) {
    (*post) *= x_dims[i];
  }
  return true;
}

template <typename T, PrecisionType PType>
void ElementwiseAddCompute<T, PType>::Run() {
  auto& param = this->template Param<operators::ElementwiseParam>();
  const T* x_data = param.X->template data<T>();
  const T* y_data = param.Y->template data<T>();
  T* out_data = param.Out->template mutable_data<T>();
  int axis = param.axis;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
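  // Dispatch: if x has fewer dims than y, call the broadcast kernel with the
  // operands swapped; otherwise, if y broadcasts over x, call it directly;
  // else fall back to the plain elementwise kernel over all elements.
  // The other elementwise kernels below follow the same dispatch pattern.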
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_add_broadcast<T>(
        y_data, x_data, out_data, pre, n, post);
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_add_broadcast<T>(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_add<T>(
        x_data, y_data, out_data, x_dims.production());
  }
}

void ElementwiseAddActivationCompute::Run() {
  auto& param = Param<operators::FusionElementwiseActivationParam>();
  const float* x_data = param.X->data<float>();
  const float* y_data = param.Y->data<float>();
  float* out_data = param.Out->mutable_data<float>();
  int axis = param.axis;
  std::string act_type = param.act_type;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
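  // Of the fused activations, only "relu" has broadcast implementations;
  // "tanh" is supported only in the same-shape branch below.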
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_add_relu_broadcast(
          y_data, x_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_add_relu_broadcast(
          x_data, y_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else {
    if (act_type == "relu") {
      lite::arm::math::elementwise_add_relu(
          x_data, y_data, out_data, x_dims.production());
    } else if (act_type == "tanh") {
      lite::arm::math::elementwise_add_tanh(
          x_data, y_data, out_data, x_dims.production());
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  }
}

void ElementwiseSubCompute::Run() {
  auto& param = Param<operators::ElementwiseParam>();
  const float* x_data = param.X->data<float>();
  const float* y_data = param.Y->data<float>();
  float* out_data = param.Out->mutable_data<float>();
  int axis = param.axis;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_sub_broadcast(
        y_data, x_data, out_data, pre, n, post);
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_sub_broadcast(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_sub(
        x_data, y_data, out_data, x_dims.production());
  }
}

void ElementwiseSubActivationCompute::Run() {
  auto& param = Param<operators::FusionElementwiseActivationParam>();
  const float* x_data = param.X->data<float>();
  const float* y_data = param.Y->data<float>();
  float* out_data = param.Out->mutable_data<float>();
  int axis = param.axis;
  std::string act_type = param.act_type;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;

  if (act_type != "relu") {
    LOG(FATAL) << "unsupported Activation type: " << act_type;
  }
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_sub_relu_broadcast(
        y_data, x_data, out_data, pre, n, post);
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_sub_relu_broadcast(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_sub_relu(
        x_data, y_data, out_data, x_dims.production());
  }
}

template <typename T, PrecisionType PType>
void ElementwiseMulCompute<T, PType>::Run() {
  auto& param = this->template Param<operators::ElementwiseParam>();
  auto* x_data = param.X->template data<T>();
  auto* y_data = param.Y->template data<T>();
  auto* out_data = param.Out->template mutable_data<T>();
  int axis = param.axis;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_mul_broadcast<T>(
        y_data, x_data, out_data, pre, n, post);
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_mul_broadcast<T>(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_mul<T>(
        x_data, y_data, out_data, x_dims.production());
  }
}

template <typename T, PrecisionType PType>
void ElementwiseMulActivationCompute<T, PType>::Run() {
  auto& param =
      this->template Param<operators::FusionElementwiseActivationParam>();
  auto* x_data = param.X->template data<T>();
  auto* y_data = param.Y->template data<T>();
  auto* out_data = param.Out->template mutable_data<T>();
  int axis = param.axis;
  std::string act_type = param.act_type;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_mul_relu_broadcast<T>(
          y_data, x_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_mul_relu_broadcast<T>(
          x_data, y_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else {
    if (act_type == "relu") {
      lite::arm::math::elementwise_mul_relu<T>(
          x_data, y_data, out_data, x_dims.production());
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  }
}

void ElementwiseMaxCompute::Run() {
  auto& param = Param<operators::ElementwiseParam>();
  const float* x_data = param.X->data<float>();
  const float* y_data = param.Y->data<float>();
  float* out_data = param.Out->mutable_data<float>();
  int axis = param.axis;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_max_broadcast(
        y_data, x_data, out_data, pre, n, post);
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_max_broadcast(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_max(
        x_data, y_data, out_data, x_dims.production());
  }
}

void ElementwiseMaxActivationCompute::Run() {
  auto& param = Param<operators::FusionElementwiseActivationParam>();
  const float* x_data = param.X->data<float>();
  const float* y_data = param.Y->data<float>();
  float* out_data = param.Out->mutable_data<float>();
  int axis = param.axis;
  std::string act_type = param.act_type;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_max_relu_broadcast<float>(
          y_data, x_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_max_relu_broadcast(
          x_data, y_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else {
    if (act_type == "relu") {
      lite::arm::math::elementwise_max_relu(
          x_data, y_data, out_data, x_dims.production());
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  }
}

template <typename T, PrecisionType PType>
void ElementwiseDivCompute<T, PType>::Run() {
  auto& param = this->template Param<operators::ElementwiseParam>();
  auto* x_data = param.X->template data<T>();
  auto* y_data = param.Y->template data<T>();
  auto* out_data = param.Out->template mutable_data<T>();
  int axis = param.axis;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
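  // Unlike add/mul/max, div does not handle x_dims.size() < y_dims.size() by
  // swapping operands (division is not commutative); that case is rejected.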
  if (x_dims.size() < y_dims.size()) {
    LOG(FATAL) << "elementwise div does not support x_dims.size() < y_dims.size()";
  }
  if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_div_broadcast<T>(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_div<T>(
        x_data, y_data, out_data, x_dims.production());
  }
}

void ElementwiseDivActivationCompute::Run() {
  auto& param = Param<operators::FusionElementwiseActivationParam>();
  const float* x_data = param.X->data<float>();
  const float* y_data = param.Y->data<float>();
  float* out_data = param.Out->mutable_data<float>();
  int axis = param.axis;
  std::string act_type = param.act_type;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  if (x_dims.size() < y_dims.size()) {
    LOG(FATAL) << "elementwise div does not support x_dims.size() < y_dims.size()";
  }
  int pre, n, post;
  if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    if (act_type == "relu") {
      lite::arm::math::elementwise_div_relu_broadcast(
          x_data, y_data, out_data, pre, n, post);
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  } else {
    if (act_type == "relu") {
      lite::arm::math::elementwise_div_relu(
          x_data, y_data, out_data, x_dims.production());
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  }
}

template <typename T, PrecisionType PType>
void ElementwiseModCompute<T, PType>::Run() {
  auto& param = this->template Param<operators::ElementwiseParam>();
  auto* x_data = param.X->template data<T>();
  auto* y_data = param.Y->template data<T>();
  auto* out_data = param.Out->template mutable_data<T>();
  int axis = param.axis;
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  int pre, n, post;
  if (x_dims.size() < y_dims.size() &&
      is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_mod_broadcast<T>(
        y_data, x_data, out_data, pre, n, post);
  } else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
    lite::arm::math::elementwise_mod_broadcast<T>(
        x_data, y_data, out_data, pre, n, post);
  } else {
    lite::arm::math::elementwise_mod<T>(
        x_data, y_data, out_data, x_dims.production());
  }
}

}  // namespace arm
}  // namespace kernels
}  // namespace lite
}  // namespace paddle

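// Kernel registrations: each elementwise op below is registered for the ARM
// target with NCHW layout, once per supported precision, binding its X and Y
// inputs and Out output to ARM tensors.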
using elementwise_add_float_t =
    paddle::lite::kernels::arm::ElementwiseAddCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
    elementwise_add, kARM, kFloat, kNCHW, elementwise_add_float_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using elementwise_add_int32_t =
    paddle::lite::kernels::arm::ElementwiseAddCompute<int32_t,
                                                      PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(
    elementwise_add, kARM, kInt32, kNCHW, elementwise_add_int32_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .Finalize();

using elementwise_add_int64_t =
    paddle::lite::kernels::arm::ElementwiseAddCompute<int64_t,
                                                      PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
    elementwise_add, kARM, kInt64, kNCHW, elementwise_add_int64_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .Finalize();

REGISTER_LITE_KERNEL(
    fusion_elementwise_add_activation,
    kARM,
    kFloat,
    kNCHW,
    paddle::lite::kernels::arm::ElementwiseAddActivationCompute,
    def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

REGISTER_LITE_KERNEL(elementwise_sub,
                     kARM,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::arm::ElementwiseSubCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

REGISTER_LITE_KERNEL(
    fusion_elementwise_sub_activation,
    kARM,
    kFloat,
    kNCHW,
    paddle::lite::kernels::arm::ElementwiseSubActivationCompute,
    def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using elementwise_mul_float_t =
    paddle::lite::kernels::arm::ElementwiseMulCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
    elementwise_mul, kARM, kFloat, kNCHW, elementwise_mul_float_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using elementwise_mul_int32_t =
    paddle::lite::kernels::arm::ElementwiseMulCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(
    elementwise_mul, kARM, kInt32, kNCHW, elementwise_mul_int32_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .Finalize();

using elementwise_mul_int64_t =
    paddle::lite::kernels::arm::ElementwiseMulCompute<int64_t,
                                                      PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
    elementwise_mul, kARM, kInt64, kNCHW, elementwise_mul_int64_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .Finalize();

using fusion_elementwise_mul_activation_float_t = paddle::lite::kernels::arm::
    ElementwiseMulActivationCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
                     kARM,
                     kFloat,
                     kNCHW,
                     fusion_elementwise_mul_activation_float_t,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using fusion_elementwise_mul_activation_int64_t = paddle::lite::kernels::arm::
    ElementwiseMulActivationCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(fusion_elementwise_mul_activation,
                     kARM,
                     kInt64,
                     kNCHW,
                     fusion_elementwise_mul_activation_int64_t,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .Finalize();

REGISTER_LITE_KERNEL(elementwise_max,
                     kARM,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::arm::ElementwiseMaxCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

REGISTER_LITE_KERNEL(
    fusion_elementwise_max_activation,
    kARM,
    kFloat,
    kNCHW,
    paddle::lite::kernels::arm::ElementwiseMaxActivationCompute,
    def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using elementwise_div_fp32_t =
    paddle::lite::kernels::arm::ElementwiseDivCompute<float, PRECISION(kFloat)>;

REGISTER_LITE_KERNEL(
    elementwise_div, kARM, kFloat, kNCHW, elementwise_div_fp32_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using elementwise_div_int64_t =
    paddle::lite::kernels::arm::ElementwiseDivCompute<int64_t,
                                                      PRECISION(kInt64)>;

REGISTER_LITE_KERNEL(
    elementwise_div, kARM, kInt64, kNCHW, elementwise_div_int64_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .Finalize();

REGISTER_LITE_KERNEL(
    fusion_elementwise_div_activation,
    kARM,
    kFloat,
    kNCHW,
    paddle::lite::kernels::arm::ElementwiseDivActivationCompute,
    def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();

using elementwise_mod_int64_t =
    paddle::lite::kernels::arm::ElementwiseModCompute<int64_t,
                                                      PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(
    elementwise_mod, kARM, kInt64, kNCHW, elementwise_mod_int64_t, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
    .Finalize();