/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/global.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/normal_kernel.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"

namespace oneflow {
namespace one {
namespace functional {

namespace impl {

/// Dispatches the "bernoulli" user op on `x`.
///
/// @param x         tensor fed to the op's "in" input.
/// @param dtype     data type written to the op's "dtype" attribute.
/// @param generator generator whose current seed becomes the "seed" attribute;
///                  the default auto generator is used when it is empty.
class BernoulliFunctor {
 public:
  BernoulliFunctor() {
    bernoulli_op_ = CHECK_JUST(one::OpBuilder("bernoulli").Input("in").Output("out").Build());
  }
  Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const Symbol<DType>& dtype,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap bernoulli_attrs;
    JUST(bernoulli_attrs.SetAttr<DataType>("dtype", dtype->data_type()));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(bernoulli_attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    // The generator itself is also handed to the op through a kernel state built from it.
    const auto& bernoulli_kernel_state = std::make_shared<BernoulliKernelState>(gen);
    return OpInterpUtil::Dispatch<Tensor>(
        *bernoulli_op_, {x}, OpExprInterpContext(bernoulli_attrs, bernoulli_kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> bernoulli_op_;
};

/// Builds a tensor of `shape` with the "uniform" op using low = 0 and high = 1.
///
/// `dtype` defaults to float; only float and double are accepted. When `device` is given it is
/// passed in the interpreter context, otherwise no device is passed.
class RandFunctor {
 public:
  RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,
                           const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kFloat;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << "Only support float and double in rand().";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<double>("low", 0));
    JUST(attrs.SetAttr<double>("high", 1));
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {},
                                            OpExprInterpContext(attrs, uniform_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

/// Consistent (placement + SBP) variant of RandFunctor.
///
/// When the MultiClient global flag is false, the NdSbp built from `sbp_tuple` is also written
/// to the "nd_sbp" attribute as a single DebugString.
class ConsistentRandFunctor {
 public:
  ConsistentRandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kFloat;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << "Only support float and double in rand().";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<double>("low", 0));
    JUST(attrs.SetAttr<double>("high", 1));
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
    const auto& parallel_distribution = JUST(GetNdSbp(sbp_tuple));
    if (!JUST(*Global<Optional<bool>, MultiClient>::Get())) {
      JUST(attrs.SetAttr<std::string>("nd_sbp", parallel_distribution->DebugString()));
    }
    return OpInterpUtil::Dispatch<Tensor>(
        *op_, {},
        OpExprInterpContext(attrs, placement, parallel_distribution, uniform_kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

/// Builds a tensor of `shape` with the "normal" op using mean = 0 and std = 1.
///
/// `dtype` defaults to float; only float and double are accepted. When `device` is given it is
/// passed in the interpreter context, otherwise no device is passed.
class RandNFunctor {
 public:
  RandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
  Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,
                           const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kFloat;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << "Only support float and double in randn().";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<double>("mean", 0));
    JUST(attrs.SetAttr<double>("std", 1));
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);
    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *op_, {}, OpExprInterpContext(attrs, device_symbol, normal_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {},
                                            OpExprInterpContext(attrs, normal_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

/// Consistent (placement + SBP) variant of RandNFunctor.
///
/// When the MultiClient global flag is false, the NdSbp built from `sbp_tuple` is also written
/// to the "nd_sbp" attribute as a single DebugString.
class ConsistentRandNFunctor {
 public:
  ConsistentRandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
  Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kFloat;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << "Only support float and double in randn().";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<double>("mean", 0));
    JUST(attrs.SetAttr<double>("std", 1));
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);
    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    if (!JUST(*Global<Optional<bool>, MultiClient>::Get())) {
      JUST(attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
    }
    return OpInterpUtil::Dispatch<Tensor>(
        *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, normal_kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

/// Builds a tensor of `shape` with the "uniform" op, bounded by `low` and `high`.
///
/// `dtype` defaults to int64; int64, float and double are accepted. The op receives `high - 1`
/// as its "high" attribute. NOTE(review): presumably this is because the integer path of the
/// uniform kernel treats "high" as inclusive. Confirm against the uniform kernel.
class RandIntFunctor {
 public:
  RandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
  Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kInt64;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      // int64 is the default dtype, so it must also be accepted when passed explicitly.
      if (dtype_val != DataType::kInt64 && dtype_val != DataType::kFloat
          && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << "Only support int64, float and double in randint(), but got dtype "
                           << dtype_val << ".";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<double>("low", low));
    JUST(attrs.SetAttr<double>("high", high - 1));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*op_, {},
                                            OpExprInterpContext(attrs, uniform_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

/// Consistent (placement + SBP) variant of RandIntFunctor.
///
/// Under LazyMode, each SBP in `sbp_tuple` is stringified into the "nd_sbp" attribute.
/// NOTE(review): ConsistentRandFunctor writes "nd_sbp" on the same "uniform" op as a single
/// std::string (and under a different condition). Confirm which attribute type the op
/// definition expects.
class ConsistentRandIntFunctor {
 public:
  ConsistentRandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
  Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
                           const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<Symbol<DType>>& dtype,
                           const Optional<one::Generator>& generator) const {
    DataType dtype_val = DataType::kInt64;
    if (dtype.has_value()) {
      dtype_val = JUST(dtype.value())->data_type();
      // int64 is the default dtype, so it must also be accepted when passed explicitly.
      if (dtype_val != DataType::kInt64 && dtype_val != DataType::kFloat
          && dtype_val != DataType::kDouble) {
        OF_UNIMPLEMENTED() << "Only support int64, float and double in randint(), but got dtype "
                           << dtype_val << ".";
      }
    }

    MutableAttrMap attrs;
    JUST(attrs.SetAttr<Shape>("shape", shape));
    JUST(attrs.SetAttr<double>("low", low));
    JUST(attrs.SetAttr<double>("high", high - 1));
    JUST(attrs.SetAttr<DataType>("dtype", dtype_val));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
    if (LazyMode::is_enabled()) {
      std::vector<std::string> nd_sbp_str_list;
      nd_sbp_str_list.reserve(sbp_tuple.size());
      for (const auto& sbp : sbp_tuple) { nd_sbp_str_list.emplace_back(SbpParallelToString(*sbp)); }
      JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp_str_list));
    }
    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    return OpInterpUtil::Dispatch<Tensor>(
        *op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> op_;
};

/// Dispatches the "randperm" op with attribute "n".
///
/// When `device` is given it is passed in the interpreter context, otherwise no device is passed.
class RandPermFunctor {
 public:
  RandPermFunctor() {
    randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build());
  }
  Maybe<Tensor> operator()(const int32_t n, const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<int32_t>("n", n));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& randperm_kernel_state = std::make_shared<UniformKernelState>(gen);
    if (device.has_value()) {
      Symbol<Device> device_symbol = JUST(device.value());
      return OpInterpUtil::Dispatch<Tensor>(
          *randperm_op_, {}, OpExprInterpContext(attrs, device_symbol, randperm_kernel_state));
    } else {
      return OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {},
                                            OpExprInterpContext(attrs, randperm_kernel_state));
    }
  }

 private:
  std::shared_ptr<OpExpr> randperm_op_;
};

/// Consistent (placement + SBP) variant of RandPermFunctor.
///
/// Under LazyMode, each SBP in `sbp_tuple` is stringified into the "nd_sbp" attribute.
class ConsistentRandPermFunctor {
 public:
  ConsistentRandPermFunctor() {
    randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build());
  }
  Maybe<Tensor> operator()(const int32_t n, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<int32_t>("n", n));

    std::shared_ptr<one::Generator> gen;
    if (!generator) {
      gen = JUST(one::DefaultAutoGenerator());
    } else {
      gen = JUST(generator.value());
    }
    JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));

    const auto& randperm_kernel_state = std::make_shared<UniformKernelState>(gen);
    if (LazyMode::is_enabled()) {
      std::vector<std::string> nd_sbp_str_list;
      nd_sbp_str_list.reserve(sbp_tuple.size());
      for (const auto& sbp : sbp_tuple) { nd_sbp_str_list.emplace_back(SbpParallelToString(*sbp)); }
      JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp_str_list));
    }
    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    return OpInterpUtil::Dispatch<Tensor>(
        *randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, randperm_kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> randperm_op_;
};

}  // namespace impl

ONEFLOW_FUNCTION_LIBRARY(m) {
  m.add_functor<impl::BernoulliFunctor>("Bernoulli");
  m.add_functor<impl::RandPermFunctor>("RandPerm");
  m.add_functor<impl::ConsistentRandPermFunctor>("ConsistentRandPerm");
  m.add_functor<impl::RandFunctor>("Rand");
  m.add_functor<impl::ConsistentRandFunctor>("ConsistentRand");
  m.add_functor<impl::RandNFunctor>("RandN");
  m.add_functor<impl::ConsistentRandNFunctor>("ConsistentRandN");
  m.add_functor<impl::RandIntFunctor>("RandInt");
  m.add_functor<impl::ConsistentRandIntFunctor>("ConsistentRandInt");
};

}  // namespace functional
}  // namespace one
}  // namespace oneflow