broadcast_function.h 22.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

17
#include "paddle/phi/kernels/funcs/elementwise_base.h"
18

19
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
20

21
namespace kps = phi::kps;
22 23 24

#endif

25
namespace phi {
26 27
namespace funcs {

28 29
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)

30 31 32 33
struct DimensionsTransform {
  using DimVector = std::vector<int64_t>;
  typedef void (*MergeFunctor)(
      bool &, std::vector<DimVector> &, DimVector &, int, int);
34
  int64_t N;
35 36 37 38 39
  int64_t dim_size;
  DimVector out_dims;
  std::vector<DimVector> in_dims;

 private:
40 41
  // To compensate the lackage of input_tensors` dimension with input
  // variable 'axis'.
42 43 44 45 46 47 48 49 50 51 52
  void InputDimensionsExtend(int N, int axis) {
    for (auto &in_dim : in_dims) {
      int64_t in_idx = 0;
      if (in_dim.size() < dim_size) {
        DimVector tmp_dim(dim_size, 1);
        do {
          if (in_dim[in_idx] == out_dims[axis] || in_dim[in_idx] == 1) {
            tmp_dim[axis] = in_dim[in_idx];
            in_idx++;
            axis++;
          } else {
53
            PADDLE_THROW(phi::errors::InvalidArgument(
54 55
                "The %d-th dimension of input tensor is expected to be equal "
                "with the %d-th dimension of output tensor %d or 1, but "
56
                "received %d.",
57 58 59 60 61 62 63 64 65 66 67 68 69
                in_idx + 1,
                axis + 1,
                out_dims[axis],
                in_dim[in_idx]));
          }
        } while (in_idx < in_dim.size());
        in_dim.resize(dim_size);
        std::copy(tmp_dim.begin(), tmp_dim.end(), in_dim.begin());
      } else {
        do {
          if (in_dim[in_idx] == out_dims[in_idx] || in_dim[in_idx] == 1) {
            in_idx++;
          } else {
70
            PADDLE_THROW(phi::errors::InvalidArgument(
71 72
                "The %d-th dimension of input tensor is expected to be equal "
                "with the %d-th dimension of output tensor %d or 1, but "
73
                "received %d.",
74 75 76 77 78 79 80 81 82 83 84 85
                in_idx + 1,
                in_idx + 1,
                out_dims[in_idx],
                in_dim[in_idx]));
          }
        } while (in_idx < dim_size);
      }
      std::reverse(in_dim.begin(), in_dim.end());
    }
    std::reverse(out_dims.begin(), out_dims.end());
  }

86 87
  // Merge sequential dimension to shrink calculation cost for
  // offset computation in CUDA Kernel.
88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
  template <typename MergeFunctor>
  __inline__ void MergeDimensions(MergeFunctor merge_func, int N) {
    auto VectorReorganise = [](DimVector *vec, int l_idx, int m_idx) {
      (*vec)[m_idx - 1] = std::accumulate(vec->begin() + l_idx,
                                          vec->begin() + m_idx,
                                          1,
                                          std::multiplies<int64_t>());
      vec->erase(vec->begin() + l_idx, vec->begin() + m_idx - 1);
    };

    int64_t i = 0;
    while (i < dim_size) {
      int cnt = 0;
      int low_idx = i;
      bool equal = true;
      do {
        merge_func(equal, in_dims, out_dims, i, N);
        if (equal) {
          i++;
          cnt++;
        } else {
          break;
        }
      } while (i < dim_size);

      if (cnt > 1) {
        for (auto &in_dim : in_dims) {
          VectorReorganise(&in_dim, low_idx, i);
        }
        VectorReorganise(&out_dims, low_idx, i);
        dim_size -= --cnt;
        i -= cnt;
      } else if (cnt < 1) {
        i++;
      }
    }
  }

126 127
  // To judge whether shape of any input tensors is sequential
  // 1-value-dimensions, and metric the length of it.
128
  bool FindSequentialOneDim(int *swap_index) {
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147
    int index = 0;
    int max_one_length = 0;
    for (int j = 0; j < N; ++j) {
      int seq_one_length = 0;
      bool active_seq = false;

      for (int i = 0; i < dim_size; ++i) {
        if (!active_seq && in_dims[j][i] == 1) {
          seq_one_length = 1;
          active_seq = true;
        } else if (active_seq) {
          if (in_dims[j][i] == 1) {
            seq_one_length++;
          } else {
            active_seq = false;
          }
        }
      }
      index = seq_one_length > max_one_length ? j : index;
148
      max_one_length = std::max(seq_one_length, max_one_length);
149 150
    }

151 152
    bool has_seq_one = max_one_length > 1;
    if (has_seq_one) {
153 154 155
      std::swap(in_dims[0], in_dims[index]);
      *swap_index = index;
    }
156
    return has_seq_one;
157 158
  }

159 160
 public:
  explicit DimensionsTransform(const std::vector<const DenseTensor *> &ins,
161
                               const phi::DDim &dims,
162
                               int axis) {
163
    N = std::max(static_cast<int>(ins.size()), 2);
164
    dim_size = dims.size();
165
    out_dims = phi::vectorize<int64_t>(dims);
166 167 168
    in_dims.resize(N);
    if (ins.size() == 1) {
      // when ins.size() = 1, broadcast input to output
169
      in_dims[0] = phi::vectorize<int64_t>(ins[0]->dims());
170 171 172 173
      in_dims[1] = out_dims;
      // Add out_dims to in_dims to avoid errors in dims merging
    } else {
      for (int j = 0; j < N; ++j) {
174
        in_dims[j] = phi::vectorize<int64_t>(ins[j]->dims());
175 176 177 178
      }
    }
    InputDimensionsExtend(N, axis);

179 180 181 182 183
    // To Merge the dimensions of input_tensors while the consequtive
    // equal-dimensions appears. Example below :
    //   in_1.shape = [2, 3, 4, 5]    in_1.shape = [2, 12, 5]
    //   in_2.shape = [1, 3, 4, 5] -> in_2.shape = [1, 12, 5]
    //   in_3.shape = [2, 3, 4, 1]    in_3.shape = [2, 12, 1]
184 185 186 187 188 189 190 191 192
    auto merge_sequential_dims = [](bool &equal,
                                    std::vector<DimVector> &in_dims,
                                    DimVector &out,
                                    int i,
                                    int num) {
      for (int j = 1; j < num; ++j) {
        equal &= (in_dims[0][i] == in_dims[j][i]) ? true : false;
      }
    };
193 194 195 196 197 198 199 200 201 202 203
    MergeFunctor merge_ptr = merge_sequential_dims;
    MergeDimensions<MergeFunctor>(merge_ptr, N);

    // To Merge the dimension of input_tensors while the sequential
    // 1-value-dimensions appears. Example below :
    //   in_1.shape = [2, 1, 1, 5]    in_1.shape = [2,  1, 5]
    //   in_2.shape = [2, 3, 4, 5] -> in_2.shape = [1, 12, 5]
    //   in_3.shape = [2, 3, 4, 1]    in_3.shape = [2, 12, 1]
    // Caution: Once 1-value-dimensions appears, the corresponding
    // shape position of other input tensors must be same with the
    // output tensor`s shape, or incorrect merge may occur.
204 205 206 207 208 209 210 211 212 213 214 215
    auto merge_sequential_one_dims = [](bool &equal,
                                        std::vector<DimVector> &in_dims,
                                        DimVector &out,
                                        int i,
                                        int num) {
      equal = in_dims[0][i] == 1;
      if (equal) {
        for (int j = 1; j < num; ++j) {
          equal &= in_dims[j][i] == out[i];
        }
      }
    };
216 217 218 219
    for (auto i = 0; i < dim_size; ++i) {
      int swap_idx = 0;
      bool has_seq_one = FindSequentialOneDim(&swap_idx);
      if (!has_seq_one) break;
220 221 222
      merge_ptr = merge_sequential_one_dims;
      MergeDimensions<MergeFunctor>(merge_ptr, N);
      std::swap(in_dims[swap_idx], in_dims[0]);
223 224 225 226
    }
  }
};

227
template <typename InT, typename OutT>
228 229 230 231
int GetVecsize(const std::vector<const DenseTensor *> &ins,
               std::vector<DenseTensor *> *outs) {
  int in_vec_size = 4;
  int out_vec_size = 4;
232 233
  if (outs->size() > 1) {
    for (auto i = 1; i < outs->size(); ++i) {
234 235 236 237 238 239 240 241 242 243
      PADDLE_ENFORCE_EQ(
          (*outs)[i]->dims(),
          (*outs)[0]->dims(),
          phi::errors::InvalidArgument(
              "The shape of each output tensor shall be identical yet, but "
              "%d-th output tensor`s shape is not.",
              i));
      out_vec_size = std::min(
          phi::GetVectorizedSize<OutT>((*outs)[i]->data<OutT>()), out_vec_size);
    }
244
  } else {
245 246 247 248 249 250 251 252
    out_vec_size = phi::GetVectorizedSize<OutT>((*outs)[0]->data<OutT>());
  }

  for (auto *in : ins) {
    auto temp_size = phi::GetVectorizedSize<InT>(in->data<InT>());
    in_vec_size = in->dims() == (*outs)[0]->dims()
                      ? std::min(temp_size, in_vec_size)
                      : in_vec_size;
253
  }
254
  return std::min(out_vec_size, in_vec_size);
255 256
}

257
template <typename T, int VecSize, bool IsBoundary = false>
258 259 260 261
__device__ __forceinline__ void LoadData(
    T *dst,
    const _ptr_ T *src,
    uint32_t block_offset,
262
    const kps::details::BroadcastConfig &config,
263 264 265 266 267 268 269
    int numel,
    int num,
    int need_broadcast,
    int read_lens) {
  // numel : whole num of output
  // num: how many data will be deal with in this time
  if (need_broadcast) {
270
    kps::ReadDataBc<T, VecSize, 1, IsBoundary>(
271 272
        dst, src, block_offset, config, numel, read_lens);
  } else {
273
    kps::ReadData<T, VecSize, 1, IsBoundary>(
274 275 276 277
        dst, src + block_offset, num, read_lens);
  }
}

278 279 280 281 282 283 284 285
// Device-side body shared by the full-tile and boundary paths of the
// broadcast kernel. For the `num` output elements starting at
// `block_offset` it:
//   1. fills each input's register buffer with 1, then loads the input into
//      it (via the broadcast mapping when use_broadcast[i] is non-zero);
//   2. applies `func` to the loaded values;
//   3. writes the result(s) to all NumOuts outputs.
// IsBoundary is forwarded to the load/store helpers.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize,
          bool IsBoundary = false>
__device__ void VectorizedBroadcastKernelImpl(
    const phi::Array<const _ptr_ InT *__restrict__, Arity> &ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    const phi::Array<int, Arity> &use_broadcast,
    uint32_t numel,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs,
    int num,
    int block_offset,
    int read_lens,
    Functor func) {
  using ResultT = ConditionalT<OutT, NumOuts>;
  // Forwarded to ElementwisePrimitiveCaller as its last template argument.
  constexpr bool kCallElementwiseAny =
      paddle::platform::FunctionTraits<Functor>::has_pointer_args;

  __simd__ InT in_regs[Arity][VecSize];
  __simd__ ResultT out_regs[VecSize];

#pragma unroll
  for (int arg = 0; arg < Arity; ++arg) {
    kps::Init<InT, VecSize>(in_regs[arg], static_cast<InT>(1.0f), read_lens);
    LoadData<InT, VecSize, IsBoundary>(in_regs[arg],
                                       ins[arg],
                                       block_offset,
                                       configs[arg],
                                       numel,
                                       num,
                                       use_broadcast[arg],
                                       read_lens);
  }

  phi::funcs::ElementwisePrimitiveCaller<InT,
                                         ResultT,
                                         VecSize,
                                         Functor,
                                         Arity,
                                         kCallElementwiseAny>()(
      func, in_regs, out_regs, read_lens);
  phi::funcs::
      ElementwiseWriteDataCallerBc<OutT, VecSize, IsBoundary, NumOuts>()(
          outs, out_regs, block_offset, num, read_lens);
}

template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
329
          int VecSize>
330
__global__ void VectorizedBroadcastKernel(
331 332 333
    phi::Array<const _ptr_ InT *__restrict__, Arity> ins,
    phi::Array<_ptr_ OutT *, NumOuts> outs,
    phi::Array<int, Arity> use_broadcast,
334
    uint32_t numel,
335
    phi::Array<kps::details::BroadcastConfig, Arity> configs,
336 337
    int main_offset,
    int tail_tid,
338
    int read_lens,
339
    Functor func) {
340 341
  int block_offset = BLOCK_ID_X * BLOCK_NUM_X * read_lens;
  int stride = BLOCK_NUM_X * GRID_NUM_X * read_lens;
342

343
#ifdef PADDLE_WITH_XPU_KP
344 345 346 347 348 349 350 351 352 353 354 355
  for (; block_offset < main_offset; block_offset += stride) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false>(ins,
                                         outs,
                                         use_broadcast,
                                         numel,
                                         configs,
356
                                         BLOCK_NUM_X * read_lens,
357
                                         block_offset,
358
                                         read_lens,
359 360 361 362 363 364 365 366 367 368
                                         func);
  }
  int num = numel - block_offset;
  if (num > 0) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
369 370 371 372 373 374 375 376 377
                                  true>(ins,
                                        outs,
                                        use_broadcast,
                                        numel,
                                        configs,
                                        num,
                                        block_offset,
                                        read_lens,
                                        func);
378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393
  }
#else
  if (block_offset < main_offset) {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
                                  false>(ins,
                                         outs,
                                         use_broadcast,
                                         numel,
                                         configs,
                                         BLOCK_NUM_X * VecSize,
                                         block_offset,
394
                                         read_lens,
395 396 397 398 399 400 401 402
                                         func);
  } else {
    VectorizedBroadcastKernelImpl<InT,
                                  OutT,
                                  Functor,
                                  Arity,
                                  NumOuts,
                                  VecSize,
403 404 405 406 407 408 409 410 411
                                  true>(ins,
                                        outs,
                                        use_broadcast,
                                        numel,
                                        configs,
                                        tail_tid,
                                        block_offset,
                                        read_lens,
                                        func);
412 413 414 415 416 417 418 419 420
  }
#endif
}

// Host-side launcher for VectorizedBroadcastKernel. It:
//   - allocates every output tensor;
//   - collects raw input/output pointers;
//   - marks an input as broadcast when its element count differs from the
//     output's;
//   - derives the launch shape:
//       XPU: 8 blocks x 64 threads, read_lens = configs[0].buf_len;
//       GPU: GetGpuLaunchConfig1D, read_lens = VecSize;
//   - splits numel into full tiles (main_offset) and a tail (tail_tid).
// Returns without launching when the output is empty.
template <typename InT,
          typename OutT,
          typename Functor,
          int Arity,
          int NumOuts,
          int VecSize>
void LaunchBroadcastKernel(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    Functor func,
    const phi::Array<kps::details::BroadcastConfig, Arity> &configs) {
  int numel = (*outs)[0]->numel();
  phi::Array<int, Arity> use_broadcast;
  phi::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
  phi::Array<_ptr_ OutT *, NumOuts> outs_data;

  for (int i = 0; i < NumOuts; ++i) {
    outs_data[i] = (_ptr_ OutT *)(ctx.Alloc<OutT>((*outs)[i]));
  }

  // Nothing to compute for an empty output (outputs are still allocated
  // above). On the GPU path the grid would come from
  // GetGpuLaunchConfig1D(ctx, 0, ...). That presumably yields 0 blocks,
  // which is an invalid CUDA launch configuration — NOTE(review): confirm.
  if (numel == 0) {
    return;
  }

  for (int i = 0; i < Arity; ++i) {
    use_broadcast[i] = (ins[i]->numel() != numel);
    ins_data[i] = (const _ptr_ InT *)(ins[i]->data<InT>());
  }

#ifdef PADDLE_WITH_XPU_KP
  const int threads = 64;
  const int blocks = 8;
  int read_lens = configs[0].buf_len;
  auto stream = ctx.x_context()->xpu_stream;
  int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
  int tail_tid = numel % (read_lens * threads);
#else
  auto gpu_config =
      phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
  int read_lens = VecSize;
  auto stream = ctx.stream();
  auto threads = gpu_config.thread_per_block;
  auto blocks = gpu_config.block_per_grid;
  int main_offset = (numel / (read_lens * gpu_config.GetBlockSize())) *
                    read_lens * gpu_config.GetBlockSize();
  int tail_tid = numel % (read_lens * gpu_config.GetBlockSize());
#endif
  VectorizedBroadcastKernel<InT, OutT, Functor, Arity, NumOuts, VecSize>
      <<<blocks, threads, 0, stream>>>(ins_data,
                                       outs_data,
                                       use_broadcast,
                                       numel,
                                       configs,
                                       main_offset,
                                       tail_tid,
                                       read_lens,
                                       func);
}

// Validates the input/output counts, normalizes the broadcast shapes with
// DimensionsTransform, builds one BroadcastConfig per input, picks a vector
// width and dispatches LaunchBroadcastKernel with that width as a template
// argument.
// kArity is ET for functors with pointer arguments, otherwise the functor's
// own arity; at most 3 inputs are supported (and exactly 2 on XPU).
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernelForDifferentVecSize(
    const KPDevice &ctx,
    const std::vector<const DenseTensor *> &ins,
    std::vector<DenseTensor *> *outs,
    int axis,
    Functor func) {
  using Traits = paddle::platform::FunctionTraits<Functor>;
  // Used as a template argument below, so it must be a compile-time constant.
  constexpr int kArity =
      Traits::has_pointer_args ? static_cast<int>(ET) : Traits::arity;
  PADDLE_ENFORCE_EQ(
      ins.size(),
      kArity,
      phi::errors::InvalidArgument("The number of inputs is expected to be "
                                   "equal to the arity of functor. But "
                                   "received: the number of inputs is %d, "
                                   "the arity of functor is %d.",
                                   ins.size(),
                                   kArity));
  PADDLE_ENFORCE_LE(
      kArity,
      3,
      phi::errors::InvalidArgument("Currently only broadcast of ternary is "
                                   "supported "
                                   "and verified, but received %d.",
                                   kArity));
  PADDLE_ENFORCE_EQ(
      outs->size(),
      NumOuts,
      phi::errors::InvalidArgument("Number of outputs shall equal to number "
                                   "of functions, "
                                   "but number of outputs is %d, of "
                                   "functions is %d.",
                                   outs->size(),
                                   NumOuts));
  // mergedim and get vec_size
  const auto merge_dims = DimensionsTransform(ins, (*outs)[0]->dims(), axis);
  phi::Array<kps::details::BroadcastConfig, kArity> configs;

// get vec_size
#ifdef PADDLE_WITH_XPU_KP
  PADDLE_ENFORCE_EQ(
      ins.size(),
      2,
      phi::errors::InvalidArgument(
          "XPU only support inputs is 2, but received %d", ins.size()));
  configs[0] = kps::details::BroadcastConfig(merge_dims.out_dims,
                                             merge_dims.in_dims[0],
                                             merge_dims.in_dims[1],
                                             merge_dims.dim_size);
  configs[1] = kps::details::BroadcastConfig(merge_dims.out_dims,
                                             merge_dims.in_dims[1],
                                             merge_dims.in_dims[0],
                                             merge_dims.dim_size);
  auto type = kps::details::OptType::CanNotOptimize;
  bool is_optimize = configs[0].cmp_type != type;
  int vec_size = is_optimize ? VecSizeL : VecSizeM;
#else
  for (int i = 0; i < kArity; ++i) {
    // get the broadcast config,
    // if data shape is[m, n], then you should set data_dim = {n, m}
    // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
    // Inputs with no elements get no config.
    if (ins[i]->numel()) {
      configs[i] = kps::details::BroadcastConfig(
          merge_dims.out_dims, merge_dims.in_dims[i], merge_dims.dim_size);
    }
  }
  int vec_size = GetVecsize<InT, OutT>(ins, outs);
#endif

  switch (vec_size) {
    case VecSizeL: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeL>(
          ctx, ins, outs, func, configs);
      break;
    }
    case VecSizeM: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeM>(
          ctx, ins, outs, func, configs);
      break;
    }
    case VecSizeS: {
      LaunchBroadcastKernel<InT, OutT, Functor, kArity, NumOuts, VecSizeS>(
          ctx, ins, outs, func, configs);
      break;
    }
    default: {
      PADDLE_THROW(phi::errors::Unimplemented(
          "Unsupported vectorized size: %d!", vec_size));
    }
  }
}

// Entry point of the broadcast elementwise computation on GPU/XPU.
// An `axis` of -1 is replaced by (largest input rank - smallest input rank)
// before the work is forwarded to BroadcastKernelForDifferentVecSize.
template <ElementwiseType ET,
          typename InT,
          typename OutT,
          typename Functor,
          int NumOuts = 1>
void BroadcastKernel(const KPDevice &ctx,
                     const std::vector<const DenseTensor *> &ins,
                     std::vector<DenseTensor *> *outs,
                     int axis,
                     Functor func) {
  // Derive the default axis only when there is at least one input. Taking
  // min/max over an empty range would dereference end(). An empty `ins` is
  // then left to the input-count check in BroadcastKernelForDifferentVecSize
  // (which rejects it for any non-zero functor arity).
  if (axis == -1 && !ins.empty()) {
    std::vector<int> dims_size;
    dims_size.reserve(ins.size());
    for (auto *in : ins) {
      dims_size.emplace_back(in->dims().size());
    }
    const auto rank_range =
        std::minmax_element(dims_size.begin(), dims_size.end());
    axis = *rank_range.second - *rank_range.first;
  }

  BroadcastKernelForDifferentVecSize<ET, InT, OutT, Functor, NumOuts>(
      ctx, ins, outs, axis, func);
}

594 595 596 597 598 599 600 601 602 603 604 605 606 607
// Binary broadcast helper for the GPU context. It allocates `z` as OutType
// on dev_ctx's place, then runs BroadcastKernel with `func` over {x, y}
// (in that order) into z.
template <typename Functor, typename T, typename OutType = T>
void ElementwiseCompute(const GPUContext &dev_ctx,
                        const DenseTensor &x,
                        const DenseTensor &y,
                        int axis,
                        Functor func,
                        DenseTensor *z) {
  z->mutable_data<OutType>(dev_ctx.GetPlace());
  const std::vector<const DenseTensor *> inputs{&x, &y};
  std::vector<DenseTensor *> outputs{z};
  BroadcastKernel<ElementwiseType::kBinary, T, OutType, Functor, 1>(
      dev_ctx, inputs, &outputs, axis, func);
}

608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623
// GPU/XPU variant: allocates `z` as T and computes it from x and y with
// Functor via ElementwiseCompute (broadcasting handled by BroadcastKernel).
// InverseFunctor is unused on this path. It is part of the template
// signature so callers can use the same instantiation as the non-GPU
// variant of this function.
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  dev_ctx.template Alloc<T>(z);
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
}

#else
624

625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644
// Non-GPU variant: allocates `z` as T and computes it from x and y via
// ElementwiseCompute. x and y are always passed in the same order. When x
// has fewer dimensions than y, InverseFunctor is used instead of Functor.
template <typename DeviceContext,
          typename T,
          typename Functor,
          typename InverseFunctor>
void DefaultElementwiseOperator(const DeviceContext &dev_ctx,
                                const DenseTensor &x,
                                const DenseTensor &y,
                                DenseTensor *z,
                                int axis = -1) {
  const auto lhs_rank = x.dims().size();
  const auto rhs_rank = y.dims().size();
  dev_ctx.template Alloc<T>(z);
  if (lhs_rank < rhs_rank) {
    funcs::ElementwiseCompute<InverseFunctor, T>(
        dev_ctx, x, y, axis, InverseFunctor(), z);
    return;
  }
  funcs::ElementwiseCompute<Functor, T>(dev_ctx, x, y, axis, Functor(), z);
}

645 646
#endif

647
}  // namespace funcs
648
}  // namespace phi