// elementwise_npu.h
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
19
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
20 21 22

namespace paddle {
namespace operators {
23
// Shorthand used by the helpers below for their temporary tensors.
using Tensor = phi::DenseTensor;
24 25

template <typename T>
26
void NpuBroadcast(const platform::NPUDeviceContext& dev_ctx,
27
                  const phi::DenseTensor* src,
28 29
                  int axis,
                  const framework::DDim& dst_dims,
30
                  phi::DenseTensor* transformed_src) {
31 32 33 34 35 36 37 38 39 40 41 42 43 44
  auto stream = dev_ctx.stream();

  // 1. expand the axis with dim 1
  auto src_dims = src->dims();
  Tensor tmp_src;
  tmp_src.ShareDataWith(*src);
  tmp_src.Resize(src_dims);
  for (int i = 0; i < src_dims.size(); ++i) {
    if (src_dims[i] == 1 && dst_dims[i + axis] > 1) {
      Tensor tmp_tensor;
      auto tmp_tensor_dims = tmp_src.dims();
      tmp_tensor_dims[i] = dst_dims[i + axis];
      tmp_tensor.mutable_data<T>(tmp_tensor_dims, dev_ctx.GetPlace());
      const auto& runner =
45 46 47
          NpuOpRunner("TileWithAxis",
                      {tmp_src},
                      {tmp_tensor},
48 49 50 51 52 53 54 55 56
                      {{"axis", static_cast<int64_t>(i)},
                       {"tiles", static_cast<int64_t>(dst_dims[i + axis])}});
      runner.Run(stream);
      tmp_src.ShareDataWith(tmp_tensor);
      tmp_src.Resize(tmp_tensor_dims);
    }
  }

  // 2.expand the ahead axis
57
  auto prev = phi::product(phi::slice_ddim(dst_dims, 0, axis));
58 59
  if (prev > 1) {
    Tensor tmp_tensor;
60
    auto tmp_tensor_dims = phi::slice_ddim(dst_dims, 0, axis + src_dims.size());
61
    tmp_tensor.mutable_data<T>(tmp_tensor_dims, dev_ctx.GetPlace());
62
    const auto& runner =
63 64 65
        NpuOpRunner("ExpandD",
                    {tmp_src},
                    {tmp_tensor},
66
                    {{"shape", phi::vectorize<int64_t>(tmp_tensor_dims)}});
67 68 69 70
    runner.Run(stream);
    tmp_src.ShareDataWith(tmp_tensor);
    tmp_src.Resize(tmp_tensor_dims);
  } else {
71
    tmp_src.Resize(phi::slice_ddim(dst_dims, 0, axis + src_dims.size()));
72 73 74
  }

  // 3.expand the tail axis
75 76
  auto post = phi::product(
      phi::slice_ddim(dst_dims, axis + src_dims.size(), dst_dims.size()));
77
  if (post > 1) {
78
    auto src_dims_vec = phi::vectorize<int>(tmp_src.dims());
79
    src_dims_vec.push_back(1);
80
    tmp_src.Resize(phi::make_ddim(src_dims_vec));
81 82 83 84

    Tensor tmp_tensor;
    tmp_tensor.mutable_data<T>(dst_dims, dev_ctx.GetPlace());
    const auto& runner =
85 86 87
        NpuOpRunner("TileWithAxis",
                    {tmp_src},
                    {tmp_tensor},
88 89 90 91 92 93 94 95 96 97 98
                    {{"axis", static_cast<int64_t>(axis + src_dims.size())},
                     {"tiles", static_cast<int64_t>(post)}});
    runner.Run(stream);
    tmp_src.ShareDataWith(tmp_tensor);
  }
  tmp_src.Resize(dst_dims);
  framework::TensorCopy(tmp_src, dev_ctx.GetPlace(), transformed_src);
}

// Broadcasts `x` and `y` to a common shape and stores the expanded tensors in
// `transformed_x` and `transformed_y`.
//
// The operand with the higher rank determines the rank of the result
// (`max_dim`); the other operand is aligned so that its first dimension
// corresponds to result dimension `axis`. Passing axis == -1 selects
// |rank(x) - rank(y)|. Each result extent is the maximum of the extents the
// two operands contribute to that position; the actual expansion is delegated
// to NpuBroadcast<T>.
//
// Raises InvalidArgument (through PADDLE_ENFORCE_*) when
//   - axis < 0 after the -1 substitution,
//   - axis > max_dim, or
//   - the lower-rank operand would extend past the last result dimension,
//     i.e. axis + rank(lower-rank operand) > max_dim. Without this check the
//     loops below would write outside `dst_dims_vec`.
//
// NOTE(review): extents that differ and are both greater than 1 are not
// rejected here; presumably shape compatibility is validated by the caller.
template <typename T>
void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
                               const phi::DenseTensor* x,
                               const phi::DenseTensor* y,
                               int axis,
                               phi::DenseTensor* transformed_x,
                               phi::DenseTensor* transformed_y) {
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  bool is_xsize_larger = true;
  int max_dim = x_dims.size();
  int min_dim = y_dims.size();
  std::vector<int> dst_dims_vec = phi::vectorize<int>(x_dims);

  if (x_dims.size() < y_dims.size()) {
    is_xsize_larger = false;
    max_dim = y_dims.size();
    min_dim = x_dims.size();
    dst_dims_vec = phi::vectorize<int>(y_dims);
  }

  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);

  // Validate the axis before it is used to index into dst_dims_vec.
  PADDLE_ENFORCE_GE(
      axis,
      0,
      platform::errors::InvalidArgument(
          "Axis should be great than or equal to 0, but received axis is %d.",
          axis));
  PADDLE_ENFORCE_LE(
      axis,
      max_dim,
      platform::errors::InvalidArgument(
          "Axis should be less than or equal to %d, but received axis is %d.",
          max_dim,
          axis));
  // The lower-rank operand occupies result dimensions [axis, axis + min_dim);
  // that range has to lie inside the result shape.
  PADDLE_ENFORCE_LE(
      axis + min_dim,
      max_dim,
      platform::errors::InvalidArgument(
          "The operand with rank %d aligned at axis %d exceeds the broadcast "
          "rank %d.",
          min_dim,
          axis,
          max_dim));

  // The higher-rank operand starts at result dimension 0, the other at `axis`.
  int x_axis = is_xsize_larger ? 0 : axis;
  int y_axis = is_xsize_larger ? axis : 0;

  for (int i = 0; i < x_dims.size(); ++i) {
    dst_dims_vec[i + x_axis] =
        std::max(dst_dims_vec[i + x_axis], static_cast<int>(x_dims[i]));
  }
  for (int i = 0; i < y_dims.size(); ++i) {
    dst_dims_vec[i + y_axis] =
        std::max(dst_dims_vec[i + y_axis], static_cast<int>(y_dims[i]));
  }

  auto dst_dims = phi::make_ddim(dst_dims_vec);
  NpuBroadcast<T>(dev_ctx, x, x_axis, dst_dims, transformed_x);
  NpuBroadcast<T>(dev_ctx, y, y_axis, dst_dims, transformed_y);
}

}  // namespace operators
}  // namespace paddle