random_functor.cpp 15.6 KB
Newer Older
L
Liang Depeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
B
Bowen Chen 已提交
16
#include "oneflow/core/common/global.h"
L
Liang Depeng 已提交
17
#include "oneflow/core/common/optional.h"
B
Bowen Chen 已提交
18
#include "oneflow/core/common/protobuf.h"
L
Liang Depeng 已提交
19 20 21 22 23 24 25 26
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/random_generator.h"
B
Bowen Chen 已提交
27
#include "oneflow/core/framework/nd_sbp.h"
L
Liang Depeng 已提交
28 29 30
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
B
Bowen Chen 已提交
31 32
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
K
Kevin_Xiong 已提交
33 34
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/lazy_mode.h"
L
Liang Depeng 已提交
35
#include "oneflow/user/kernels/bernoulli_kernel.h"
B
Bowen Chen 已提交
36
#include "oneflow/user/kernels/distributions/normal_kernel.h"
B
Bowen Chen 已提交
37
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
K
Kevin_Xiong 已提交
38

L
Liang Depeng 已提交
39 40 41 42 43 44 45 46 47 48 49
namespace oneflow {
namespace one {
namespace functional {

namespace impl {

class BernoulliFunctor {
 public:
  BernoulliFunctor() {
    bernoulli_op_ = CHECK_JUST(one::OpBuilder("bernoulli").Input("in").Output("out").Build());
  }
Z
ZZK 已提交
50
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Symbol<DType>& dtype,
L
Liang Depeng 已提交
51 52
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap bernoulli_attrs;
Z
ZZK 已提交
53
    JUST(bernoulli_attrs.SetAttr<DataType>("dtype", dtype->data_type()));
L
Liang Depeng 已提交
54 55 56 57 58 59 60 61 62 63 64 65 66

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }

    JUST(bernoulli_attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& bernoulli_kernel_state = std::make_shared<BernoulliKernelState>(gen);

    return OpInterpUtil::Dispatch<Tensor>(
L
Li Xinqi 已提交
67
        *bernoulli_op_, {x}, OpExprInterpContext(bernoulli_attrs, bernoulli_kernel_state));
L
Liang Depeng 已提交
68 69 70 71 72
  }

 private:
  std::shared_ptr<OpExpr> bernoulli_op_;
};
B
Bowen Chen 已提交
73 74 75
class RandFunctor {
 public:
  RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
Z
ZZK 已提交
76
  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,
B
Bowen Chen 已提交
77 78 79 80
                           const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kFloat;
    if (dtype.has_value()) {
Z
ZZK 已提交
81
      dtype_val = JUST(dtype.value())->data_type();
B
Bowen Chen 已提交
82
      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
83
        OF_UNIMPLEMENTED() << "Only support float and double in rand().";
B
Bowen Chen 已提交
84 85 86 87 88 89 90 91 92 93
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<double>("low", 0));
    JUST(attrs.SetAttr<double>("high", 1));
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
94

B
Bowen Chen 已提交
95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }

    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);

    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {},
                                            OpExprInterpContext(attrs, uniform_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

// Functor behind "ConsistentRand": placement/sbp variant of "Rand". Dispatches the "uniform"
// op (low = 0, high = 1) on `placement` with the NdSbp built from `sbp_tuple`. `dtype`
// defaults to float; only float and double are accepted. When the global MultiClient flag is
// false, the NdSbp's DebugString is additionally recorded as the op's "nd_sbp" attr.
class ConsistentRandFunctor {
 public:
  ConsistentRandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<one::Generator>& generator) const {
    DataType out_dtype = DataType::kFloat;
    if (dtype.has_value()) {
      out_dtype = JUST(dtype.value())->data_type();
      const bool is_floating = out_dtype == DataType::kFloat || out_dtype == DataType::kDouble;
      if (!is_floating) { OF_UNIMPLEMENTED() << "Only support float and double in rand()."; }
    }

    MutableAttrMap uniform_attrs;
    JUST(uniform_attrs.SetAttr<double>("low", 0));
    JUST(uniform_attrs.SetAttr<double>("high", 1));
    JUST(uniform_attrs.SetAttr<Shape>("shape", shape));
    JUST(uniform_attrs.SetAttr<DataType>("dtype", out_dtype));

    std::shared_ptr<one::Generator> rng;
    if (generator.has_value()) {
      rng = JUST(generator.value());
    } else {
      rng = JUST(one::DefaultAutoGenerator());
    }
    JUST(uniform_attrs.SetAttr<int64_t>("seed", rng->current_seed()));

    const auto kernel_state = std::make_shared<UniformKernelState>(rng);

    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    const bool multi_client = JUST(*Global<Maybe<bool>, MultiClient>::Get());
    if (!multi_client) {
      JUST(uniform_attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
    }
    return OpInterpUtil::Dispatch<Tensor>(
        *op_, {}, OpExprInterpContext(uniform_attrs, placement, nd_sbp, kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

B
Bowen Chen 已提交
164 165 166
class RandNFunctor {
 public:
  RandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
Z
ZZK 已提交
167
  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,
B
Bowen Chen 已提交
168 169 170 171
                           const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kFloat;
    if (dtype.has_value()) {
Z
ZZK 已提交
172 173
      dtype_val = JUST(dtype.value())->data_type();

B
Bowen Chen 已提交
174
      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
175
        OF_UNIMPLEMENTED() << "Only support float and double in randn().";
B
Bowen Chen 已提交
176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<double>("mean", 0));
    JUST(attrs.SetAttr<double>("std", 1));
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;

    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }

    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);

    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *op_, {}, OpExprInterpContext(attrs, device_symbol, normal_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {},
                                            OpExprInterpContext(attrs, normal_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

// Functor behind "ConsistentRandN": placement/sbp variant of "RandN". Dispatches the "normal"
// op (mean = 0, std = 1) on `placement` with the NdSbp built from `sbp_tuple`. `dtype`
// defaults to float; only float and double are accepted. When the global MultiClient flag is
// false, the NdSbp's DebugString is additionally recorded as the op's "nd_sbp" attr.
class ConsistentRandNFunctor {
 public:
  ConsistentRandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<one::Generator>& generator) const {
    DataType out_dtype = DataType::kFloat;
    if (dtype.has_value()) {
      out_dtype = JUST(dtype.value())->data_type();
      const bool is_floating = out_dtype == DataType::kFloat || out_dtype == DataType::kDouble;
      if (!is_floating) { OF_UNIMPLEMENTED() << "Only support float and double in randn()."; }
    }

    MutableAttrMap normal_attrs;
    JUST(normal_attrs.SetAttr<double>("mean", 0));
    JUST(normal_attrs.SetAttr<double>("std", 1));
    JUST(normal_attrs.SetAttr<Shape>("shape", shape));
    JUST(normal_attrs.SetAttr<DataType>("dtype", out_dtype));

    std::shared_ptr<one::Generator> rng;
    if (generator.has_value()) {
      rng = JUST(generator.value());
    } else {
      rng = JUST(one::DefaultAutoGenerator());
    }
    JUST(normal_attrs.SetAttr<int64_t>("seed", rng->current_seed()));

    const auto kernel_state = std::make_shared<NormalKernelState>(rng);

    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    const bool multi_client = JUST(*Global<Maybe<bool>, MultiClient>::Get());
    if (!multi_client) {
      JUST(normal_attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
    }
    return OpInterpUtil::Dispatch<Tensor>(
        *op_, {}, OpExprInterpContext(normal_attrs, placement, nd_sbp, kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
K
Kevin_Xiong 已提交
255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
// Functor behind "RandInt": creates a tensor of `shape` by dispatching the "uniform" op with
// low = `low` and high = `high - 1`. `dtype` defaults to int64; int64, float and double are
// accepted. The "seed" attr and the UniformKernelState both come from `generator`, or from
// the default auto generator when none is supplied. If `device` is set the op is dispatched
// on that device; otherwise no device is passed to the interpreter context.
// NOTE(review): `high - 1` presumably converts the half-open [low, high) contract into an
// inclusive upper bound for the integer kernel path — confirm against the uniform kernel,
// in particular for the float/double dtypes.
class RandIntFunctor {
 public:
  RandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }

  Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kInt64;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      // int64 is what this functor dispatches when no dtype is given, so an explicit int64
      // request must be accepted as well (previously it was rejected).
      if (dtype_val != DataType::kInt64 && dtype_val != DataType::kFloat
          && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << dtype_val << " not supported in randint";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<double>("low", low));
    JUST(attrs.SetAttr<double>("high", high - 1));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {},
                                            OpExprInterpContext(attrs, uniform_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

// Functor behind "ConsistentRandInt": placement/sbp variant of "RandInt". Dispatches the
// "uniform" op with low = `low` and high = `high - 1` on `placement`, using the NdSbp built
// from `sbp_tuple`. `dtype` defaults to int64; int64, float and double are accepted.
// In lazy mode each sbp entry is also recorded, as a string, in the op's "nd_sbp" attr.
// NOTE(review): ConsistentRandFunctor in this file writes the same op's "nd_sbp" attr as a
// single std::string under a MultiClient check instead; confirm which attr type the
// "uniform" op declares.
class ConsistentRandIntFunctor {
 public:
  ConsistentRandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
  Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
                           const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kInt64;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      // int64 is what this functor dispatches when no dtype is given, so an explicit int64
      // request must be accepted as well (previously it was rejected).
      if (dtype_val != DataType::kInt64 && dtype_val != DataType::kFloat
          && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << dtype_val << " not supported in randint";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<double>("low", low));
    JUST(attrs.SetAttr<double>("high", high - 1));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);

    if (LazyMode::is_enabled()) {
      // Range-for avoids the signed/unsigned comparison of the previous index loop.
      std::vector<std::string> nd_sbp_strs;
      nd_sbp_strs.reserve(sbp_tuple.size());
      for (const auto& sbp : sbp_tuple) { nd_sbp_strs.push_back(SbpParallelToString(*sbp)); }
      JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp_strs));
    }
    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));

    return OpInterpUtil::Dispatch<Tensor>(
        *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> op_;
};
B
Bowen Chen 已提交
352

353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422
// Functor behind "RandPerm": dispatches the "randperm" op with attr "n" set to `n`.
// The "seed" attr is the current seed of `generator` (or of the default auto generator), and
// that generator is handed to the kernel wrapped in a UniformKernelState. If `device` is set
// the op is dispatched on that device; otherwise no device is passed to the context.
// NOTE(review): a UniformKernelState is reused for the randperm kernel — confirm that is the
// state type the randperm kernel expects.
class RandPermFunctor {
 public:
  RandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); }
  Maybe<Tensor> operator()(const int32_t n, const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap randperm_attrs;
    JUST(randperm_attrs.SetAttr<int32_t>("n", n));

    std::shared_ptr<one::Generator> rng;
    if (generator.has_value()) {
      rng = JUST(generator.value());
    } else {
      rng = JUST(one::DefaultAutoGenerator());
    }
    JUST(randperm_attrs.SetAttr<int64_t>("seed", rng->current_seed()));

    const auto kernel_state = std::make_shared<UniformKernelState>(rng);

    // No explicit device: let the interpreter context decide.
    if (!device.has_value()) {
      return OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {},
                                            OpExprInterpContext(randperm_attrs, kernel_state));
    }
    const Symbol<Device> target_device = JUST(device.value());
    return OpInterpUtil::Dispatch<Tensor>(
        *randperm_op_, {}, OpExprInterpContext(randperm_attrs, target_device, kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> randperm_op_;
};

// Functor behind "ConsistentRandPerm": placement/sbp variant of "RandPerm". Dispatches the
// "randperm" op (attr "n") on `placement` with the NdSbp built from `sbp_tuple`. In lazy mode
// each sbp entry is also recorded, as a string, in the op's "nd_sbp" attr.
class ConsistentRandPermFunctor {
 public:
  ConsistentRandPermFunctor() {
    randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build());
  }
  Maybe<Tensor> operator()(const int32_t n, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap randperm_attrs;
    JUST(randperm_attrs.SetAttr<int32_t>("n", n));

    std::shared_ptr<one::Generator> rng;
    if (generator.has_value()) {
      rng = JUST(generator.value());
    } else {
      rng = JUST(one::DefaultAutoGenerator());
    }
    JUST(randperm_attrs.SetAttr<int64_t>("seed", rng->current_seed()));

    const auto kernel_state = std::make_shared<UniformKernelState>(rng);

    if (LazyMode::is_enabled()) {
      std::vector<std::string> sbp_strs;
      sbp_strs.reserve(sbp_tuple.size());
      for (const auto& sbp : sbp_tuple) { sbp_strs.push_back(SbpParallelToString(*sbp)); }
      JUST(randperm_attrs.SetAttr<std::vector<std::string>>("nd_sbp", sbp_strs));
    }
    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    return OpInterpUtil::Dispatch<Tensor>(
        *randperm_op_, {}, OpExprInterpContext(randperm_attrs, placement, nd_sbp, kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> randperm_op_;
};
L
Liang Depeng 已提交
423 424
}  // namespace impl

B
Bowen Chen 已提交
425 426
// Registers this file's random-number functors with the functional library under the names
// below. Each eager functor sits next to its "Consistent*" (placement + sbp) counterpart;
// Bernoulli has no consistent variant in this file. Registration order is unchanged.
ONEFLOW_FUNCTION_LIBRARY(m) {
  // "bernoulli" op applied to an input tensor.
  m.add_functor<impl::BernoulliFunctor>("Bernoulli");

  // "randperm" op.
  m.add_functor<impl::RandPermFunctor>("RandPerm");
  m.add_functor<impl::ConsistentRandPermFunctor>("ConsistentRandPerm");

  // "uniform" op with low = 0, high = 1.
  m.add_functor<impl::RandFunctor>("Rand");
  m.add_functor<impl::ConsistentRandFunctor>("ConsistentRand");

  // "normal" op with mean = 0, std = 1.
  m.add_functor<impl::RandNFunctor>("RandN");
  m.add_functor<impl::ConsistentRandNFunctor>("ConsistentRandN");

  // "uniform" op with low = `low`, high = `high - 1`.
  m.add_functor<impl::RandIntFunctor>("RandInt");
  m.add_functor<impl::ConsistentRandIntFunctor>("ConsistentRandInt");
};
L
Liang Depeng 已提交
436 437 438 439

}  // namespace functional
}  // namespace one
}  // namespace oneflow