// diag_embed_op.h
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

// Element-wise functor that copies one input value to its slot in the output
// buffer.
//
// For a flat input index `idx`, the index is unravelled against the extents in
// `dim` (the first extent varies slowest), every resulting coordinate is
// multiplied by the matching entry of `strides`, and `offset` is added to
// obtain the flat output position that receives `input[idx]`.
//
// Constructor arguments:
//   input     - source buffer, read at `idx`.
//   numel     - element count used as the starting divisor when unravelling
//               `idx`; the arithmetic expects it to be the product of `dim`.
//   dim       - array of `dims_size` extents.
//   offset    - constant added to every computed output position.
//   dims_size - number of entries read from both `dim` and `strides`.
//   output    - destination buffer.
//   strides   - array of `dims_size` output steps, one per extent in `dim`.
//
// Only the raw pointers are stored (no ownership is taken), so the arrays they
// refer to must remain valid for as long as the functor is invoked.
template <typename T>
struct DiagEmbedFunctor {
  DiagEmbedFunctor(const T* input, int64_t numel, const int64_t* dim,
                   int64_t offset, int64_t dims_size, T* output,
                   const int64_t* strides)
      : input_(input),
        numel_(numel),
        dim_(dim),
        offset_(offset),
        dims_size_(dims_size),
        output_(output),
        strides_(strides) {}

  // Writes input_[idx] to the output position derived from idx.
  HOSTDEVICE void operator()(size_t idx) const {
    int64_t remaining = idx;     // part of idx not yet turned into coordinates
    int64_t block = numel_;      // elements spanned by one step of this axis
    int64_t target = 0;          // flat output position, before the offset
    for (int64_t axis = 0; axis < dims_size_; ++axis) {
      block /= dim_[axis];
      const int64_t coord = remaining / block;
      target += coord * strides_[axis];
      remaining %= block;
    }
    output_[target + offset_] = input_[idx];
  }

  const T* input_;
  int64_t numel_;
  const int64_t* dim_;
  int64_t offset_;
  int64_t dims_size_;
  T* output_;
  const int64_t* strides_;
};

// Kernel for the diag_embed operator.
//
// Reads the tensor "Input" and the integer attributes "offset", "dim1" and
// "dim2", fills the tensor "Out" with zeros, and then copies every input
// element to a position on the diagonal of the (dim1, dim2) plane of "Out".
// Negative dim1/dim2 values are counted from the end of the output shape.
// A non-negative offset shifts the diagonal along dim2; a negative offset
// shifts it along dim1.
template <typename DeviceContext, typename T>
class DiagEmbedKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<framework::Tensor>("Input");
    auto* out = context.Output<framework::Tensor>("Out");

    const int64_t offset = context.Attr<int>("offset");
    const int64_t dim1 = context.Attr<int>("dim1");
    const int64_t dim2 = context.Attr<int>("dim2");
    auto* input_data = input->data<T>();

    // Allocate the output and clear it; only diagonal slots are written later.
    T* out_data = out->mutable_data<T>(context.GetPlace());
    math::SetConstant<DeviceContext, T> zero_fill;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    zero_fill(dev_ctx, out, static_cast<T>(0.0));

    // Normalise possibly negative axis attributes against the output rank.
    auto out_dims = out->dims();
    const auto out_rank = out_dims.size();
    const int axis1 = dim1 < 0 ? out_rank + dim1 : dim1;
    const int axis2 = dim2 < 0 ? out_rank + dim2 : dim2;
    auto out_stride = framework::stride(out_dims);

    // Length of the selected diagonal: the shifted axis loses |offset|
    // entries, and the result is clamped at zero.
    const int64_t extent1 =
        offset >= 0 ? out_dims[axis1] : out_dims[axis1] + offset;
    const int64_t extent2 =
        offset >= 0 ? out_dims[axis2] - offset : out_dims[axis2];
    const int64_t diag_size =
        std::max<int64_t>(std::min<int64_t>(extent1, extent2), 0);

    // Flat position of the first diagonal element inside one (dim1, dim2)
    // plane; left at zero when the diagonal is empty.
    int64_t storage_offset = 0;
    if (diag_size != 0) {
      storage_offset = offset >= 0 ? offset * out_stride[axis2]
                                   : -(offset * out_stride[axis1]);
    }

    // Strides seen by the functor: the output strides with the two diagonal
    // axes removed (larger index first so the smaller index stays valid),
    // followed by the step that advances one element along the diagonal.
    auto strides = vectorize(out_stride);
    const int upper_axis = std::max(axis1, axis2);
    const int lower_axis = std::min(axis1, axis2);
    strides.erase(strides.begin() + upper_axis);
    strides.erase(strides.begin() + lower_axis);
    strides.push_back(out_stride[axis1] + out_stride[axis2]);
    const auto dims = vectorize(input->dims());

#ifdef __NVCC__
    // When built with NVCC the shape and stride arrays are copied into
    // thrust device vectors and the functor receives their raw pointers.
    thrust::device_vector<int64_t> dims_vec(dims);
    const int64_t* dims_arr = thrust::raw_pointer_cast(dims_vec.data());
    thrust::device_vector<int64_t> strides_vec(strides);
    const int64_t* strides_arr = thrust::raw_pointer_cast(strides_vec.data());
#else
    const int64_t* dims_arr = dims.data();
    const int64_t* strides_arr = strides.data();
#endif

    // Invoke the functor once per input element.
    const auto input_numel = input->numel();
    platform::ForRange<DeviceContext> for_range(dev_ctx, input_numel);
    DiagEmbedFunctor<T> functor(input_data, input_numel, dims_arr,
                                storage_offset, dims.size(), out_data,
                                strides_arr);
    for_range(functor);
  }
};
}  // namespace operators
}  // namespace paddle