uniform_random_op.h
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
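
// Helpers for the uniform_random operator: extracting the target output shape
// from ShapeTensor / ShapeTensorList inputs and, for CUDA/HIP builds, the GPU
// implementation of uniform sampling.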

#pragma once
#include <algorithm>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
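// Thrust and the phi distribution/index helpers below are only needed when
// compiling the CUDA or HIP (ROCm) build.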
#if defined(__NVCC__) || defined(__HIPCC__)
#include <thrust/random.h>

#include "paddle/fluid/framework/generator.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/distribution_helper.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
#endif

namespace paddle {
namespace operators {
using Tensor = phi::DenseTensor;

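// Reads the target output shape from a single ShapeTensor holding int32 or
// int64 values. GPU tensors are first copied to the CPU so their contents can
// be read on the host.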
inline std::vector<int64_t> GetNewDataFromShapeTensor(
    const phi::DenseTensor* new_data_tensor) {
  if (framework::TransToProtoVarType(new_data_tensor->dtype()) ==
      framework::proto::VarType::INT64) {
    auto* new_data = new_data_tensor->data<int64_t>();
    phi::DenseTensor cpu_starts_tensor;
    if (platform::is_gpu_place(new_data_tensor->place())) {
      paddle::framework::TensorCopySync(
          *new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
      new_data = cpu_starts_tensor.data<int64_t>();
    }
    std::vector<int64_t> vec_new_data(new_data,
                                      new_data + new_data_tensor->numel());
    return vec_new_data;
  } else if (framework::TransToProtoVarType(new_data_tensor->dtype()) ==
             framework::proto::VarType::INT32) {
    auto* new_data = new_data_tensor->data<int32_t>();
    std::vector<int64_t> vec_new_data;
    phi::DenseTensor cpu_starts_tensor;
    if (platform::is_gpu_place(new_data_tensor->place())) {
      paddle::framework::TensorCopySync(
          *new_data_tensor, platform::CPUPlace(), &cpu_starts_tensor);
      new_data = cpu_starts_tensor.data<int32_t>();
    }
    for (int i = 0; i < new_data_tensor->numel(); ++i) {
      vec_new_data.push_back(static_cast<int64_t>(*(new_data + i)));
    }
    return vec_new_data;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Expected dtype of ShapeTensor to be int32 or int64, but got "
        "unsupported dtype: %s.",
        new_data_tensor->dtype()));
  }
}

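// Reads the target output shape from a list of ShapeTensor inputs, one
// dimension per tensor. Each tensor must have shape [1] and hold an int32 or
// int64 value; GPU tensors are copied to the CPU before being read.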
inline std::vector<int64_t> GetNewDataFromShapeTensorList(
    const std::vector<const phi::DenseTensor*>& list_new_shape_tensor) {
  std::vector<int64_t> vec_new_shape;
  vec_new_shape.reserve(list_new_shape_tensor.size());
  for (size_t i = 0; i < list_new_shape_tensor.size(); ++i) {
    auto tensor = list_new_shape_tensor[i];
    PADDLE_ENFORCE_EQ(
        tensor->dims(),
        phi::make_ddim({1}),
        platform::errors::InvalidArgument(
            "Shape of dim tensor in uniform_random_op should be [1], "
            "but received tensor's dims=%s.",
            tensor->dims()));

    if (framework::TransToProtoVarType(tensor->dtype()) ==
        framework::proto::VarType::INT32) {
      if (platform::is_gpu_place(tensor->place())) {
        phi::DenseTensor temp;
        paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
        vec_new_shape.push_back(static_cast<int64_t>(*temp.data<int32_t>()));
      } else {
        vec_new_shape.push_back(static_cast<int64_t>(*tensor->data<int32_t>()));
      }
    } else if (framework::TransToProtoVarType(tensor->dtype()) ==
               framework::proto::VarType::INT64) {
      if (platform::is_gpu_place(tensor->place())) {
        phi::DenseTensor temp;
        paddle::framework::TensorCopySync(*tensor, platform::CPUPlace(), &temp);
        vec_new_shape.push_back(*temp.data<int64_t>());
      } else {
        vec_new_shape.push_back(*tensor->data<int64_t>());
      }
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Expected dtype of the %d-th ShapeTensorList tensor to be int32 or "
          "int64, but got unsupported dtype: %s.",
          i,
          paddle::framework::DataTypeToString(
              framework::TransToProtoVarType(tensor->dtype()))));
    }
  }

  return vec_new_shape;
}

#if defined(__NVCC__) || defined(__HIPCC__)
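
// UniformGenerator maps a flat element index n to a sample drawn uniformly
// from [min, max). For every index it re-seeds a thrust::minstd_rand engine
// with the fixed per-op seed and discards n values, so results are
// deterministic for a given seed. Indices that fall on the configured
// "diagonal" (every diag_step_ + 1 elements, at most diag_num_ of them) are
// overwritten with diag_val_.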

template <typename T>
struct UniformGenerator {
  T min_, max_;
  unsigned int seed_;
  T diag_val_;
  unsigned int diag_num_;
  unsigned int diag_step_;
  __host__ __device__ UniformGenerator(
      T min, T max, int seed, int diag_num, int diag_step, T diag_val)
      : min_(min),
        max_(max),
        seed_(seed),
        diag_val_(diag_val),
        diag_num_(diag_num),
        diag_step_(diag_step) {}

  __host__ __device__ T operator()(const unsigned int n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<T> dist(min_, max_);
    rng.discard(n);
    T out = dist(rng);
    unsigned int remainder = n % (diag_step_ + 1);
    if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
      out = diag_val_;
    }
    return out;
  }
};
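
// Fills `tensor` with uniform random values in [min, max) on the GPU. When
// the "seed" attribute is 0, the global generator path
// (phi::funcs::distribution_and_transform) is used; otherwise the fixed
// per-op seed drives UniformGenerator through phi::IndexKernel, where
// diag_num / diag_step / diag_val can overwrite a strided set of elements
// with a constant value.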

template <typename T>
void UniformRandom(const framework::ExecutionContext& context,
                   phi::DenseTensor* tensor) {
  int64_t size = tensor->numel();
  auto& dev_cxt = context.template device_context<phi::GPUContext>();
  T* data = tensor->mutable_data<T>(dev_cxt.GetPlace());
  if (size <= 0) return;
  unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));

  T min = static_cast<T>(context.Attr<float>("min"));
  T max = static_cast<T>(context.Attr<float>("max"));
  unsigned int diag_num =
      static_cast<unsigned int>(context.Attr<int>("diag_num"));
  unsigned int diag_step =
      static_cast<unsigned int>(context.Attr<int>("diag_step"));
  T diag_val = static_cast<T>(context.Attr<float>("diag_val"));

  if (seed == 0) {
    // Use global Generator seed
    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
    phi::funcs::uniform_distribution<MT> dist;
    phi::funcs::uniform_real_transform<MT> trans(min, max);
    phi::funcs::distribution_and_transform<T>(dev_cxt, tensor, dist, trans);
  } else {
    // Use OP seed
    auto func =
        UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val);
    phi::IndexKernel<T, UniformGenerator<T>>(dev_cxt, tensor, func);
  }
}
#endif
}  // namespace operators
}  // namespace paddle