elementwise_npu.h 4.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once

#include <algorithm>
#include <cstdlib>
#include <vector>

#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"

namespace paddle {
namespace operators {
using Tensor = framework::Tensor;

// Broadcasts `src` to the shape `dst_dims` on an NPU and copies the result
// into `transformed_src`.
//
// `src` is aligned with `dst_dims` starting at `axis`: source dim i maps to
// target dim (i + axis). The broadcast is built in three steps:
//   1. each source dim of size 1 whose target dim is > 1 is tiled up to the
//      target size with the "TileWithAxis" op;
//   2. the leading target dims [0, axis) are materialized with "ExpandD"
//      when their product is > 1; otherwise the tensor is only reshaped;
//   3. the trailing target dims [axis + rank(src), rank(dst)) are
//      materialized, when their product is > 1, by appending a size-1 axis
//      and tiling it with "TileWithAxis".
//
// Params:
//   dev_ctx          NPU device context; its stream runs every NPU op.
//   src              tensor to broadcast; only read, never written.
//   axis             offset of src's first dim inside `dst_dims`.
//   dst_dims         target shape.
//   transformed_src  receives the broadcast result via framework::TensorCopy.
//
// Fails with an InvalidArgument error (PADDLE_ENFORCE) when `src` cannot be
// placed at `axis` inside `dst_dims`, or when a source dim is neither 1 nor
// equal to its target dim.
template <typename T>
void NpuBroadcast(const platform::NPUDeviceContext& dev_ctx, const Tensor* src,
                  int axis, const framework::DDim& dst_dims,
                  Tensor* transformed_src) {
  auto stream = dev_ctx.stream();

  auto src_dims = src->dims();
  const int src_rank = src_dims.size();
  const int dst_rank = dst_dims.size();

  // Validate before launching any NPU op: dst_dims[i + axis] is indexed
  // below, so src must fit inside dst when placed at `axis`, and every
  // source dim must be broadcastable to its target dim.
  PADDLE_ENFORCE_GE(
      axis, 0,
      platform::errors::InvalidArgument(
          "Axis should be greater than or equal to 0, but received axis is %d.",
          axis));
  PADDLE_ENFORCE_GE(
      dst_rank, axis + src_rank,
      platform::errors::InvalidArgument(
          "The target rank (%d) should be greater than or equal to axis (%d) "
          "plus the source rank (%d).",
          dst_rank, axis, src_rank));
  for (int i = 0; i < src_rank; ++i) {
    PADDLE_ENFORCE_EQ(
        src_dims[i] == 1 || src_dims[i] == dst_dims[i + axis], true,
        platform::errors::InvalidArgument(
            "Dimension %d of the source (%d) cannot be broadcast to target "
            "dimension %d (%d); it must be 1 or equal to the target size.",
            i, src_dims[i], i + axis, dst_dims[i + axis]));
  }

  // 1. Tile every size-1 source dim whose target dim is larger.
  Tensor tmp_src;
  tmp_src.ShareDataWith(*src);
  tmp_src.Resize(src_dims);
  for (int i = 0; i < src_rank; ++i) {
    if (src_dims[i] == 1 && dst_dims[i + axis] > 1) {
      Tensor tmp_tensor;
      auto tmp_tensor_dims = tmp_src.dims();
      tmp_tensor_dims[i] = dst_dims[i + axis];
      tmp_tensor.mutable_data<T>(tmp_tensor_dims, dev_ctx.GetPlace());
      const auto& runner =
          NpuOpRunner("TileWithAxis", {tmp_src}, {tmp_tensor},
                      {{"axis", static_cast<int64_t>(i)},
                       {"tiles", static_cast<int64_t>(dst_dims[i + axis])}});
      runner.Run(stream);
      tmp_src.ShareDataWith(tmp_tensor);
      tmp_src.Resize(tmp_tensor_dims);
    }
  }

  // 2. Materialize the leading target dims [0, axis).
  const auto prefix_dims = phi::slice_ddim(dst_dims, 0, axis + src_rank);
  auto prev = phi::product(phi::slice_ddim(dst_dims, 0, axis));
  if (prev > 1) {
    Tensor tmp_tensor;
    tmp_tensor.mutable_data<T>(prefix_dims, dev_ctx.GetPlace());
    const auto& runner =
        NpuOpRunner("ExpandD", {tmp_src}, {tmp_tensor},
                    {{"shape", phi::vectorize<int64_t>(prefix_dims)}});
    runner.Run(stream);
    tmp_src.ShareDataWith(tmp_tensor);
  }
  // In both branches the tensor now has shape dst_dims[0, axis + rank(src)).
  tmp_src.Resize(prefix_dims);

  // 3. Materialize the trailing target dims [axis + rank(src), rank(dst)).
  auto post =
      phi::product(phi::slice_ddim(dst_dims, axis + src_rank, dst_rank));
  if (post > 1) {
    // Append a size-1 axis and tile it `post` times; the output buffer is
    // allocated with dst_dims and is reshaped to exactly that below.
    auto src_dims_vec = phi::vectorize<int>(tmp_src.dims());
    src_dims_vec.push_back(1);
    tmp_src.Resize(phi::make_ddim(src_dims_vec));

    Tensor tmp_tensor;
    tmp_tensor.mutable_data<T>(dst_dims, dev_ctx.GetPlace());
    const auto& runner =
        NpuOpRunner("TileWithAxis", {tmp_src}, {tmp_tensor},
                    {{"axis", static_cast<int64_t>(axis + src_rank)},
                     {"tiles", static_cast<int64_t>(post)}});
    runner.Run(stream);
    tmp_src.ShareDataWith(tmp_tensor);
  }
  tmp_src.Resize(dst_dims);
  framework::TensorCopy(tmp_src, dev_ctx.GetPlace(), transformed_src);
}

// Broadcasts `x` and `y` on an NPU to their common shape and writes the
// results into `transformed_x` and `transformed_y`.
//
// The lower-rank input is aligned with the higher-rank one starting at
// `axis`; when the ranks are equal, x is treated as the larger input.
// `axis == -1` is replaced by |rank(x) - rank(y)|, which aligns the trailing
// dims. Each target dim is the max of the aligned input dims, and both inputs
// are then expanded to that shape with NpuBroadcast.
//
// Fails with an InvalidArgument error (PADDLE_ENFORCE) when the resolved
// axis is negative, is not less than the larger rank, or would place the
// smaller input past the end of the larger one.
template <typename T>
void NpuElementWiseOpBroadcast(const platform::NPUDeviceContext& dev_ctx,
                               const Tensor* x, const Tensor* y, int axis,
                               Tensor* transformed_x, Tensor* transformed_y) {
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  const bool is_xsize_larger = x_dims.size() >= y_dims.size();
  const int max_dim = std::max(x_dims.size(), y_dims.size());
  const int min_dim = std::min(x_dims.size(), y_dims.size());

  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);

  PADDLE_ENFORCE_GE(
      axis, 0,
      platform::errors::InvalidArgument(
          "Axis should be great than or equal to 0, but received axis is %d.",
          axis));
  PADDLE_ENFORCE_LT(axis, max_dim,
                    platform::errors::InvalidArgument(
                        "Axis should be less than %d, but received axis is %d.",
                        max_dim, axis));
  // The smaller input is written into dst_dims_vec at [axis, axis + min_dim),
  // so it must end within the larger input's rank; otherwise the merge below
  // would index past the end of the vector.
  PADDLE_ENFORCE_GE(
      max_dim, axis + min_dim,
      platform::errors::InvalidArgument(
          "Axis (%d) plus the smaller input rank (%d) should be less than or "
          "equal to the larger input rank (%d).",
          axis, min_dim, max_dim));

  const int x_axis = is_xsize_larger ? 0 : axis;
  const int y_axis = is_xsize_larger ? axis : 0;

  // Start from the larger input's shape and take the elementwise max with
  // each input's dims placed at its offset.
  std::vector<int> dst_dims_vec =
      phi::vectorize<int>(is_xsize_larger ? x_dims : y_dims);
  auto merge_dims = [&dst_dims_vec](const framework::DDim& dims, int offset) {
    for (int i = 0; i < dims.size(); ++i) {
      dst_dims_vec[i + offset] =
          std::max(dst_dims_vec[i + offset], static_cast<int>(dims[i]));
    }
  };
  merge_dims(x_dims, x_axis);
  merge_dims(y_dims, y_axis);

  auto dst_dims = phi::make_ddim(dst_dims_vec);
  NpuBroadcast<T>(dev_ctx, x, x_axis, dst_dims, transformed_x);
  NpuBroadcast<T>(dev_ctx, y, y_axis, dst_dims, transformed_y);
}

}  // namespace operators
}  // namespace paddle