Unverified Commit bb3b4321 authored by Bowen Chen, committed by GitHub

Refine rand functor (#6167)

* align random op datatype with torch

* refine

* refine

* refine randint

* refine rand

* refine

* support randint with only high

* fix randperm

* refine

* refine

* fix bernoulli kernel

* fix bernoulli kernel

* add JUST
Co-authored-by: Houjiang Chen <chenhoujiangcug@gmail.com>
Parent 7e46f9bc
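
For orientation, a minimal sketch of the refined functional API this series targets (a hypothetical session; it assumes an OneFlow build that includes this commit):

import oneflow as flow

# randint now supports the torch-style call with only `high`; low defaults to 0
x = flow._C.randint(10, (2, 3))  # values in [0, 10), default dtype int64

# the random functors now take requires_grad directly, matching torch
y = flow._C.rand((2, 3), requires_grad=True)

# randperm now defaults to int64, matching torch
p = flow._C.randperm(5)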
......@@ -59,6 +59,17 @@ struct IsIntegral : std::integral_constant<bool, false> {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, INT_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_INTEGRAL
// Type Trait: IsUnsignedIntegral
template<typename T>
struct IsUnsignedIntegral : std::integral_constant<bool, false> {};
#define SPECIALIZE_TRUE_INTEGRAL(type_cpp, type_proto) \
template<> \
struct IsUnsignedIntegral<type_cpp> : std::integral_constant<bool, true> {};
OF_PP_FOR_EACH_TUPLE(SPECIALIZE_TRUE_INTEGRAL, UNSIGNED_INT_DATA_TYPE_SEQ);
#undef SPECIALIZE_TRUE_INTEGRAL
// Type Trait: GetDataType
template<typename T, typename T2 = void>
......
......@@ -319,6 +319,7 @@
"Tensor (Tensor x, Tensor moving_mean=None, Tensor moving_variance=None,
Tensor gamma, Tensor beta, Int32 axis=1, Float epsilon=1e-5,
Float momentum=0.9, Bool is_training=False) => Normalization"
bind_python: True
- name: "normalization_grad"
......@@ -1175,39 +1176,42 @@
- name: "rand"
signature: [
"Tensor (Shape shape, *, DataType dtype=None, Device device=None,
Generator generator=None) => Rand",
]
bind_python: True
- name: "consistent_rand"
signature: [
"Tensor (Shape shape, *, Placement placement, SbpList sbp, DataType dtype=None,
Generator generator=None) => ConsistentRand",
]
"Tensor (Shape size, *, DataType dtype=None, Device device=None,
Generator generator=None, Bool requires_grad=False) => Rand",
"Tensor (Shape size, *, Placement placement, SbpList sbp, DataType dtype=None,
Generator generator=None, Bool requires_grad=False) => ConsistentRand",
]
bind_python: True
- name: "randn"
signature: [
"Tensor (Shape shape, *, DataType dtype=None, Device device=None,
Generator generator=None) => RandN",
]
"Tensor (Shape size, *, DataType dtype=None, Device device=None,
Generator generator=None, Bool requires_grad=False) => RandN",
"Tensor (Shape size, *, Placement placement, SbpList sbp, DataType dtype=None,
Generator generator=None, Bool requires_grad=False) => ConsistentRandN",
]
bind_python: True
- name: "consistent_randn"
- name: "randint"
signature: [
"Tensor (Shape shape, *, Placement placement, SbpList sbp, DataType dtype=None,
Generator generator=None) => ConsistentRandN",
]
"Tensor (Int64 low, Int64 high, Shape size, *, DataType dtype=None,
Device device=None, Generator generator=None, Bool requires_grad=False)=> RandInt",
"Tensor (Int64 high, Shape size, *, DataType dtype=None,
Device device=None, Generator generator=None, Bool requires_grad=False)=> RandInt",
"Tensor (Int64 low, Int64 high, Shape size, *, Placement placement, SbpList sbp_tuple,
DataType dtype=None, Generator generator=None, Bool requires_grad=False)=> ConsistentRandInt",
"Tensor (Int64 high, Shape size, *, Placement placement, SbpList sbp_tuple,
DataType dtype=None, Generator generator=None, Bool requires_grad=False)=> ConsistentRandInt",
]
bind_python: True
- name: "randint"
signature: "Tensor (Int64 low, Int64 high, Shape shape, DataType dtype=None, Device device=None, Generator generator=None)=> RandInt"
- name: "randperm"
signature: [
"Tensor (Int32 n, *, Generator generator=None, DataType dtype=kInt64, Device device=None, Bool requires_grad=False) => RandPerm",
"Tensor (Int32 n, *, Placement placement, SbpList sbp, Generator generator=None, DataType dtype=kInt64, Bool requires_grad=False) => ConsistentRandPerm",
]
bind_python: True
- name: "consistent_randint"
signature: "Tensor (Int64 low, Int64 high, Shape shape, Placement placement, SbpList sbp_tuple, DataType dtype=None, Generator generator=None)=> ConsistentRandInt"
bind_python: True
- name: "unfold"
signature:
......@@ -1243,20 +1247,6 @@
signature: "Tensor (Tensor dy, Tensor y, Tensor square_x_sum, Int32 axis, Float epsilon) => L2NormalizeGrad"
bind_python: False
- name: "randperm"
signature:
[
"Tensor (Int32 n, *, Device device=None, Generator generator=None) => RandPerm",
]
bind_python: True
- name: "consistent_randperm"
signature:
[
"Tensor (Int32 n, *, Placement placement, SbpList sbp, Generator generator=None) => ConsistentRandPerm",
]
bind_python: True
- name: "fused_self_attention"
signature: "TensorTuple (Tensor hidden_states, Int64 head_size=8, Float alpha=1.0) => FusedSelfAttention"
bind_python: True
......
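With this change the local and consistent variants are overloads of one exported name: passing `device` (or nothing) selects the local functor, while passing `placement`/`sbp` selects the consistent one. A hedged sketch of the dispatch (placement construction follows the tests in this commit):

import oneflow as flow

placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)

local = flow._C.rand((2, 3), device=flow.device("cpu"))          # => Rand
consistent = flow._C.rand((2, 3), placement=placement, sbp=sbp)  # => ConsistentRand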
......@@ -703,7 +703,7 @@ class NormalizationFunctor {
const std::shared_ptr<one::Tensor>& gamma,
const std::shared_ptr<one::Tensor>& beta, const int32_t& axis,
const float& epsilon, const float& momentum,
const bool& is_training) const {
const bool& training) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("axis", axis));
JUST(attrs.SetAttr<float>("epsilon", epsilon));
......@@ -712,7 +712,7 @@ class NormalizationFunctor {
CHECK_OR_RETURN((moving_mean && moving_variance) || (!moving_mean && !moving_variance))
<< "Both moving_mean and moving_variance should be None or Tensor.";
if (!is_training) {
if (!training) {
CHECK_OR_RETURN(moving_mean && moving_variance)
<< "Must have moving_mean and moving_variance in eval mode.";
return OpInterpUtil::Dispatch<one::Tensor>(
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/functional.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
......@@ -32,9 +33,7 @@ limitations under the License.
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/normal_kernel.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/user/kernels/distributions/common.h"
namespace oneflow {
namespace one {
......@@ -61,21 +60,23 @@ class BernoulliFunctor {
JUST(bernoulli_attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& bernoulli_kernel_state = std::make_shared<BernoulliKernelState>(gen);
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
return OpInterpUtil::Dispatch<Tensor>(
*bernoulli_op_, {x}, OpExprInterpContext(bernoulli_attrs, bernoulli_kernel_state));
return OpInterpUtil::Dispatch<Tensor>(*bernoulli_op_, {x},
OpExprInterpContext(bernoulli_attrs, distribution_state));
}
private:
std::shared_ptr<OpExpr> bernoulli_op_;
};
class RandFunctor {
public:
RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
......@@ -100,16 +101,11 @@ class RandFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device);
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {},
OpExprInterpContext(attrs, uniform_kernel_state));
}
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
......@@ -122,7 +118,8 @@ class ConsistentRandFunctor {
Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
......@@ -146,14 +143,16 @@ class ConsistentRandFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
JUST(attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
}
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state));
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
......@@ -165,14 +164,12 @@ class RandNFunctor {
RandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
Maybe<Tensor> operator()(const Shape& shape, const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << "Only support float and double in randn().";
}
if (dtype) { dtype_val = JUST(dtype)->data_type(); }
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << "Only support float and double in randn().";
}
MutableAttrMap attrs;
......@@ -191,16 +188,13 @@ class RandNFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device);
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, device_symbol, normal_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {},
OpExprInterpContext(attrs, normal_kernel_state));
}
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
......@@ -213,14 +207,12 @@ class ConsistentRandNFunctor {
Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << "Only support float and double in randn().";
}
if (dtype) { dtype_val = JUST(dtype)->data_type(); }
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << "Only support float and double in randn().";
}
MutableAttrMap attrs;
......@@ -238,40 +230,38 @@ class ConsistentRandNFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
JUST(attrs.SetAttr<std::string>("nd_sbp", nd_sbp->DebugString()));
}
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, normal_kernel_state));
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
std::shared_ptr<OpExpr> op_;
};
class RandIntFunctor {
public:
RandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
RandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform_int").Output("out").Build()); }
Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
DataType dtype_val = DataType::kInt64;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << "not supported in randn";
}
}
if (dtype) { dtype_val = JUST(dtype)->data_type(); }
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<double>("low", low));
JUST(attrs.SetAttr<double>("high", high - 1));
JUST(attrs.SetAttr<int64_t>("low", low));
JUST(attrs.SetAttr<int64_t>("high", high));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
......@@ -282,42 +272,50 @@ class RandIntFunctor {
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device);
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {},
OpExprInterpContext(attrs, uniform_kernel_state));
}
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
std::shared_ptr<OpExpr> op_;
};
class RandInt2Functor {
public:
Maybe<Tensor> operator()(const int64_t high, const Shape& shape,
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
return RandInt(/*low*/ 0, high, shape, dtype, device, generator, requires_grad);
}
};
class ConsistentRandIntFunctor {
public:
ConsistentRandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
ConsistentRandIntFunctor() {
op_ = CHECK_JUST(one::OpBuilder("uniform_int").Output("out").Build());
}
Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
DataType dtype_val = DataType::kInt64;
if (dtype.has_value()) {
dtype_val = JUST(dtype)->data_type();
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << "not supported in randn";
}
}
if (dtype) { dtype_val = JUST(dtype)->data_type(); }
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<double>("low", low));
JUST(attrs.SetAttr<double>("high", high - 1));
JUST(attrs.SetAttr<int64_t>("low", low));
JUST(attrs.SetAttr<int64_t>("high", high));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
......@@ -328,7 +326,7 @@ class ConsistentRandIntFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
if (LazyMode::is_enabled()) {
std::vector<std::string> nd_sbp(sbp_tuple.size());
......@@ -341,19 +339,36 @@ class ConsistentRandIntFunctor {
}
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state));
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
std::shared_ptr<OpExpr> op_;
};
class ConsistentRandInt2Functor {
public:
Maybe<Tensor> operator()(const int64_t high, const Shape& shape,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator,
const bool& requires_grad) const {
return ConsistentRandInt(/*low*/ 0, high, shape, placement, sbp_tuple, dtype, generator,
requires_grad);
}
};
class RandPermFunctor {
public:
RandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); }
Maybe<Tensor> operator()(const int32_t n, const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
Maybe<Tensor> operator()(const int32_t n, const Optional<one::Generator>& generator,
const Symbol<DType>& dtype, const Optional<Symbol<Device>>& device,
const bool& requires_grad) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("n", n));
std::shared_ptr<one::Generator> gen;
......@@ -365,15 +380,14 @@ class RandPermFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& randperm_kernel_state = std::make_shared<UniformKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device);
return OpInterpUtil::Dispatch<Tensor>(
*randperm_op_, {}, OpExprInterpContext(attrs, device_symbol, randperm_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {},
OpExprInterpContext(attrs, randperm_kernel_state));
}
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
OpExprInterpContext ctx(attrs, distribution_state);
if (device) { ctx.device = JUST(device); }
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {}, ctx));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
......@@ -387,7 +401,8 @@ class ConsistentRandPermFunctor {
}
Maybe<Tensor> operator()(const int32_t n, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<one::Generator>& generator) const {
const Optional<one::Generator>& generator, const Symbol<DType>& dtype,
const bool& requires_grad) const {
MutableAttrMap attrs;
JUST(attrs.SetAttr<int32_t>("n", n));
......@@ -400,7 +415,7 @@ class ConsistentRandPermFunctor {
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
const auto& distribution_state = std::make_shared<DistributionKernelState>(gen);
if (LazyMode::is_enabled()) {
std::vector<std::string> nd_sbp(sbp_tuple.size());
......@@ -412,8 +427,11 @@ class ConsistentRandPermFunctor {
JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp));
}
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(
*randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state));
auto result = JUST(OpInterpUtil::Dispatch<Tensor>(
*randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, distribution_state)));
JUST(result->set_requires_grad(requires_grad));
return result;
}
private:
......@@ -421,16 +439,18 @@ class ConsistentRandPermFunctor {
};
} // namespace impl
using namespace impl;
ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BernoulliFunctor>("Bernoulli");
m.add_functor<impl::RandPermFunctor>("RandPerm");
m.add_functor<impl::ConsistentRandPermFunctor>("ConsistentRandPerm");
m.add_functor<impl::RandFunctor>("Rand");
m.add_functor<impl::ConsistentRandFunctor>("ConsistentRand");
m.add_functor<impl::RandNFunctor>("RandN");
m.add_functor<impl::ConsistentRandNFunctor>("ConsistentRandN");
m.add_functor<impl::RandIntFunctor>("RandInt");
m.add_functor<impl::ConsistentRandIntFunctor>("ConsistentRandInt");
m.add_functor<BernoulliFunctor>("Bernoulli");
m.add_functor<RandPermFunctor>("RandPerm");
m.add_functor<ConsistentRandPermFunctor>("ConsistentRandPerm");
m.add_functor<RandFunctor>("Rand");
m.add_functor<ConsistentRandFunctor>("ConsistentRand");
m.add_functor<RandNFunctor>("RandN");
m.add_functor<ConsistentRandNFunctor>("ConsistentRandN");
m.add_functor<RandIntFunctor, RandInt2Functor>("RandInt");
m.add_functor<ConsistentRandIntFunctor, ConsistentRandInt2Functor>("ConsistentRandInt");
};
} // namespace functional
......
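`RandInt2Functor` and `ConsistentRandInt2Functor` simply forward `low = 0` to the full functors, which is what lets the two-argument torch form work in both local and consistent mode. A hedged usage sketch mirroring the tests later in this diff:

import oneflow as flow

placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow._C.randint(16, (10, 1), placement=placement, sbp_tuple=sbp, dtype=flow.int32)  # low defaults to 0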
......@@ -14,7 +14,7 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/op_kernel_state_wrapper.h"
#include "oneflow/user/kernels/random_seed_util.h"
#include "oneflow/user/kernels/random_mask_generator.h"
......@@ -31,7 +31,7 @@ class BernoulliKerenl final : public user_op::OpKernel {
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeGenerator(kCPU));
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<BernoulliKernelState>(generator);
return std::make_shared<DistributionKernelState>(generator);
}
private:
......@@ -44,9 +44,9 @@ class BernoulliKerenl final : public user_op::OpKernel {
CHECK_EQ(GetDataType<K>(), out_blob->data_type());
CHECK_EQ(in_blob->shape().elem_cnt(), out_blob->shape().elem_cnt());
auto* bernoulli_kernel_state = dynamic_cast<BernoulliKernelState*>(state);
CHECK_NOTNULL(bernoulli_kernel_state);
const auto& generator = bernoulli_kernel_state->generator();
auto* kernel_state = dynamic_cast<DistributionKernelState*>(state);
CHECK_NOTNULL(kernel_state);
const auto& generator = kernel_state->generator();
CHECK_NOTNULL(generator);
const auto& cpu_generator = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/random_generator.h"
namespace oneflow {
class DistributionKernelState : public user_op::OpKernelState {
public:
explicit DistributionKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
// FIXME: refine warning message
#define CHECK_OUT_OF_BOUNDS(var, name, min, max, dtype) \
CHECK(var >= min && var <= max) << name << " is out of bounds for " << dtype;
#define WARN_OUT_OF_BOUNDS(var, name, digits, dtype) \
if (var < -(1LL << digits) || var > (1LL << digits)) { \
LOG(WARNING) << name << " is out of bounds [-(2^" << digits << "), 2^" << digits << "]. " \
<< "Due to precision limitations " << dtype \
<< " can support discrete uniform distribution only within this range. " \
<< "This warning will become an error in later version release."; \
}
template<typename scalar_t>
void check_from_to_in_range(int64_t from, int64_t to_inc) {
if (IsFloating<scalar_t>::value) {
const auto min = static_cast<double>(std::numeric_limits<scalar_t>::lowest());
const auto max = static_cast<double>(std::numeric_limits<scalar_t>::max());
CHECK_OUT_OF_BOUNDS(from, "from", min, max, GetDataType<scalar_t>::value);
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, GetDataType<scalar_t>::value);
constexpr auto digits = std::numeric_limits<scalar_t>::digits;
WARN_OUT_OF_BOUNDS(from, "from", digits, GetDataType<scalar_t>::value);
WARN_OUT_OF_BOUNDS(to_inc, "to - 1", digits, GetDataType<scalar_t>::value);
} else if (IsIntegral<scalar_t>::value || IsUnsignedIntegral<scalar_t>::value) {
const auto min = static_cast<int64_t>(std::numeric_limits<scalar_t>::lowest());
const auto max = static_cast<int64_t>(std::numeric_limits<scalar_t>::max());
CHECK_OUT_OF_BOUNDS(from, "from", min, max, GetDataType<scalar_t>::value);
CHECK_OUT_OF_BOUNDS(to_inc, "to - 1", min, max, GetDataType<scalar_t>::value);
} else {
UNIMPLEMENTED()
<< "check_from_to_in_range handles only integral, floating-point and boolean types";
}
}
} // namespace oneflow
#endif  // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_COMMON_H_
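The 2^digits warning threshold marks where a floating-point type stops representing consecutive integers exactly, so a discrete uniform draw beyond it would silently skip values. A quick numpy illustration (not part of this change) for float32, whose significand has 24 bits:

import numpy as np

print(np.float32(2**24))      # 16777216.0
print(np.float32(2**24 + 1))  # 16777216.0 -- 2**24 + 1 rounds back down, so gaps begin here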
......@@ -18,21 +18,11 @@ limitations under the License.
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/distributions/normal_distribution.h"
namespace oneflow {
class NormalKernelState : public user_op::OpKernelState {
public:
explicit NormalKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
namespace {
template<DeviceType device_type, typename T>
......@@ -45,7 +35,7 @@ class NormalKernel final : public user_op::OpKernel {
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeGenerator(device_type));
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<NormalKernelState>(generator);
return std::make_shared<DistributionKernelState>(generator);
}
private:
......@@ -55,9 +45,9 @@ class NormalKernel final : public user_op::OpKernel {
const double std = ctx->Attr<double>("std");
int64_t elem_cnt = out->shape().elem_cnt();
T* out_dptr = out->mut_dptr<T>();
auto* normal_state = dynamic_cast<NormalKernelState*>(state);
CHECK_NOTNULL(normal_state);
const auto& generator = normal_state->generator();
auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
CHECK_NOTNULL(distribution_state);
const auto& generator = distribution_state->generator();
CHECK_NOTNULL(generator);
NormalDistribution<device_type, T> distribution(static_cast<T>(mean), static_cast<T>(std));
distribution(ctx->device_ctx(), elem_cnt, out_dptr, generator);
......
......@@ -21,17 +21,6 @@ namespace oneflow {
template<typename T, typename E = void>
class CPUUniformDistributionImpl;
template<typename T>
class CPUUniformDistributionImpl<T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
CPUUniformDistributionImpl(T low, T high) : random_distribution_(low, high) {}
T operator()(std::mt19937& engine) { return random_distribution_(engine); }
private:
std::uniform_int_distribution<T> random_distribution_;
};
template<typename T>
class CPUUniformDistributionImpl<T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
......@@ -60,6 +49,5 @@ void UniformDistribution<DeviceType::kCPU, T>::operator()(
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, INT_DATA_TYPE_SEQ)
} // namespace oneflow
......@@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/distributions/uniform_distribution.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/user/kernels/distributions/uniform_distribution.h"
namespace oneflow {
......@@ -23,22 +23,20 @@ namespace {
template<typename T>
__device__ T GenUniform(curandState* state, const T low, const T high);
#define INITIATE_GENUNIFORM(T, typeproto) \
template<> \
__device__ T GenUniform<T>(curandState * state, const T low, const T high) { \
return curand_uniform(state) * (high - low) + low; \
}
OF_PP_FOR_EACH_TUPLE(INITIATE_GENUNIFORM, INT_DATA_TYPE_SEQ)
template<>
__device__ float GenUniform<float>(curandState* state, const float low, const float high) {
return curand_uniform(state) * (high - low) + low;
auto rand_num = curand_uniform(state);
// curand_uniform generates (0.0, 1.0], but we want [0.0, 1.0) here
if (rand_num == 1.0) { rand_num = 0.0; }
return rand_num * (high - low) + low;
}
template<>
__device__ double GenUniform<double>(curandState* state, const double low, const double high) {
return curand_uniform_double(state) * (high - low) + low;
auto rand_num = curand_uniform_double(state);
// curand_uniform_double generates (0.0, 1.0], but we want [0.0, 1.0) here
if (rand_num == 1.0) { rand_num = 0.0; }
return rand_num * (high - low) + low;
}
template<typename T>
......@@ -71,6 +69,5 @@ void UniformDistribution<DeviceType::kGPU, T>::operator()(
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_DISTRIBUTION, INT_DATA_TYPE_SEQ)
} // namespace oneflow
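The remap matters because curand_uniform returns values in (0.0, 1.0]: a draw of exactly 1.0 would otherwise scale to `high` itself, breaking the half-open [low, high) contract (and, after truncation in the integer path, yielding an out-of-range value). In Python terms the fold is just:

def fold_to_half_open(u, low, high):
    # curand gives u in (0.0, 1.0]; map 1.0 back to 0.0 so the result stays in [low, high)
    if u == 1.0:
        u = 0.0
    return u * (high - low) + low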
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/user/kernels/distributions/uniform_int_distribution.h"
namespace oneflow {
template<typename T>
class CPUUniformIntDistributionImpl {
public:
CPUUniformIntDistributionImpl(int64_t low, int64_t high) : random_distribution_(low, high) {}
T operator()(std::mt19937& engine) { return static_cast<T>(random_distribution_(engine)); }
private:
std::uniform_int_distribution<int64_t> random_distribution_;
};
template<typename T>
void UniformIntDistribution<DeviceType::kCPU, T>::operator()(
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const {
CHECK_GE(elem_cnt, 0);
auto gen = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
// std::uniform_int_distribution generates [low, high], but we want [low, high) here
CPUUniformIntDistributionImpl<T> impl(low_, high_ - 1);
for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = impl(gen->engine()); }
}
#define INITIATE_CPU_UNIFORM_INT_DISTRIBUTION(T, typeproto) \
template void UniformIntDistribution<DeviceType::kCPU, T>::operator()( \
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr, \
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, INT_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_INT_DISTRIBUTION, UNSIGNED_INT_DATA_TYPE_SEQ)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/dtype.h"
#include "oneflow/user/kernels/distributions/uniform_int_distribution.h"
namespace oneflow {
namespace {
__device__ int64_t GenUniformInt(curandState* state, const int64_t low, const int64_t high) {
auto rand_num = curand_uniform(state);
// curand_uniform generates (0.0, 1.0], but we want [0.0, 1.0) here
if (rand_num == 1.0) { rand_num = 0.0; }
return static_cast<int64_t>(rand_num * (high - low) + low);
}
template<typename T>
__global__ void GenerateGpu(curandState* state, const int64_t elem_cnt, T* dptr, const int64_t low,
const int64_t high) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
curandState localState = state[id];
if (id < elem_cnt) { dptr[id] = static_cast<T>(GenUniformInt(&localState, low, high)); }
state[id] = localState;
}
} // namespace
template<typename T>
void UniformIntDistribution<DeviceType::kGPU, T>::operator()(
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const {
CHECK_GE(elem_cnt, 0);
auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
int32_t block_num = gen->max_block_num();
int32_t thread_num = gen->max_thread_num();
auto* curand_states = gen->curand_states();
GenerateGpu<T><<<block_num, thread_num, 0, device_ctx->cuda_stream()>>>(curand_states, elem_cnt,
dptr, low_, high_);
}
#define INITIATE_GPU_UNIFORM_INT_DISTRIBUTION(T, typeproto) \
template void UniformIntDistribution<DeviceType::kGPU, T>::operator()( \
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr, \
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_INT_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_INT_DISTRIBUTION, INT_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_INT_DISTRIBUTION, UNSIGNED_INT_DATA_TYPE_SEQ)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
#ifdef WITH_CUDA
#include <curand.h>
#include <curand_kernel.h>
#endif
namespace oneflow {
template<DeviceType device_type, typename T>
class UniformIntDistribution;
template<typename T>
class UniformIntDistribution<DeviceType::kCPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(UniformIntDistribution);
UniformIntDistribution(int64_t low, int64_t high) : low_(low), high_(high) {}
~UniformIntDistribution() = default;
void operator()(DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const;
private:
const int64_t low_;
const int64_t high_;
};
#ifdef WITH_CUDA
template<typename T>
class UniformIntDistribution<DeviceType::kGPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(UniformIntDistribution);
UniformIntDistribution(int64_t low, int64_t high) : low_(low), high_(high) {}
~UniformIntDistribution() = default;
void operator()(DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const;
private:
const int64_t low_;
const int64_t high_;
};
#endif // WITH_CUDA
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_DISTRIBUTION_H_
......@@ -13,25 +13,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_
#define ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_
#include "oneflow/user/kernels/random_mask_generator.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/distributions/uniform_int_kernel.h"
namespace oneflow {
class BernoulliKernelState : public user_op::OpKernelState {
public:
explicit BernoulliKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
namespace {
#define REGISTER_UNIFORM_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("uniform_int") \
.SetCreateFn<UniformIntKernel<device, dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceTag() == device) \
& (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, uint8_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int8_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int32_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int64_t)
#ifdef WITH_CUDA
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, double)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, uint8_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int8_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int32_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int64_t)
#endif // WITH_CUDA
} // namespace
} // namespace oneflow
#endif // ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/distributions/uniform_int_distribution.h"
namespace oneflow {
namespace {
// The following algorithm is adopted from pytorch:
// The purpose of `update_from` and `update_to` is to find the closest valid int64_t number that can
// be used as actual `from`. The current implementation of `random_` uses uint64_t arithmetics and
// casts the result to the target dtype(scalar_t). This casting can result in generating numbers
// that happen to be greater or equal to `to` value. For instance:
//
// auto actual = torch::empty({3, 3}, torch::half);
// actual.random_(0, 65504);
//
// If random's uint64_t arithmetics produces 65503 as a random value after casting to torch::half it
// becomes 65504 and violates the requirement that random value must be less than `to`. To resolve
// this issue `update_from` and `update_to` moves `from` to the right and `to` to the left to the
// next closest value that won't go outside [from, to) after casting to the target dtype. For `to` =
// 65504 it moves left for (1 << (log2(to) - 11 + 1)) = 32 and becomes 65472, which is previous
// available number for torch::half dtype.
template<typename scalar_t>
int64_t update_from(int64_t from) {
const auto from_plus_1 = static_cast<int64_t>(static_cast<scalar_t>(from + 1));
if (from_plus_1 < from) {
int64_t from_ = std::abs(from + 1);
int n = 0;
while (from_ >>= 1) ++n;
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
from = from_plus_1 + (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
}
return from;
}
template<typename scalar_t>
int64_t update_to(int64_t to) {
const auto to_minus_1 = static_cast<int64_t>(static_cast<scalar_t>(to - 1));
if (to_minus_1 >= to) {
int64_t to_ = std::abs(to - 1);
int n = 0;
while (to_ >>= 1) ++n;
// NOLINTNEXTLINE(clang-analyzer-core.UndefinedBinaryOperatorResult)
to = to_minus_1 - (1LL << (n - std::numeric_limits<scalar_t>::digits + 1));
}
return to;
}
template<DeviceType device_type, typename T>
class UniformIntKernel final : public user_op::OpKernel {
public:
UniformIntKernel() = default;
~UniformIntKernel() = default;
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<DistributionKernelState>(generator);
}
private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
int64_t from = ctx->Attr<int64_t>("low");
int64_t to = ctx->Attr<int64_t>("high");
CHECK_LE(from, to) << "uniform kernel expects 'low' to be less than 'high', but got from="
<< from << " >= to=",
to;
if (IsFloating<T>::value) {
from = update_from<T>(from);
to = update_to<T>(to);
CHECK_LE(from, to) << "uniform kernel expects 'low' casted to dtype to be less than 'high'"
" casted to dtype, but got from="
<< from << " >= to=",
to;
}
check_from_to_in_range<T>(from, to - 1);
int64_t elem_cnt = out->shape().elem_cnt();
T* out_dptr = out->mut_dptr<T>();
auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
CHECK_NOTNULL(distribution_state);
const auto& generator = distribution_state->generator();
CHECK_NOTNULL(generator);
UniformIntDistribution<device_type, T> distribution(from, to);
distribution(ctx->device_ctx(), elem_cnt, out_dptr, generator);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
} // namespace
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_INT_KERNEL_H_
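To make the float16 example in the comment above concrete: 65503 is not representable in half and rounds up to 65504, so a raw draw of 65503 cast to half would equal `to` and violate `value < to`. `update_to` therefore steps `to` down by 1 << (n - digits + 1) = 1 << (15 - 11 + 1) = 32, landing on 65472, the closest representable value below. Checked with numpy:

import numpy as np

print(np.float16(65503))  # 65504.0 -- rounds up to `to`, which would be out of range
print(np.float16(65472))  # 65472.0 -- exactly representable, so the adjusted bound is safe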
......@@ -25,13 +25,9 @@ namespace {
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int32_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int64_t)
#ifdef WITH_CUDA
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, double)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int32_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int64_t)
#endif // WITH_CUDA
} // namespace
......
......@@ -17,21 +17,11 @@ limitations under the License.
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/distributions/common.h"
#include "oneflow/user/kernels/distributions/uniform_distribution.h"
namespace oneflow {
class UniformKernelState : public user_op::OpKernelState {
public:
explicit UniformKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
namespace {
template<DeviceType device_type, typename T>
......@@ -44,21 +34,22 @@ class UniformKernel final : public user_op::OpKernel {
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeGenerator(device_type));
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<UniformKernelState>(generator);
return std::make_shared<DistributionKernelState>(generator);
}
private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const double low = ctx->Attr<double>("low");
const double high = ctx->Attr<double>("high");
const double from = ctx->Attr<double>("from");
const double to = ctx->Attr<double>("to");
check_from_to_in_range<T>(from, to);
int64_t elem_cnt = out->shape().elem_cnt();
T* out_dptr = out->mut_dptr<T>();
auto* uniform_state = dynamic_cast<UniformKernelState*>(state);
CHECK_NOTNULL(uniform_state);
const auto& generator = uniform_state->generator();
auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
CHECK_NOTNULL(distribution_state);
const auto& generator = distribution_state->generator();
CHECK_NOTNULL(generator);
UniformDistribution<device_type, T> distribution(static_cast<T>(low), static_cast<T>(high));
UniformDistribution<device_type, T> distribution(static_cast<T>(from), static_cast<T>(to));
distribution(ctx->device_ctx(), elem_cnt, out_dptr, generator);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
......
......@@ -19,7 +19,7 @@ limitations under the License.
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/user/kernels/range_kernel_util.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/user/kernels/distributions/common.h"
namespace oneflow {
class CpuRandPermKernel final : public user_op::OpKernel {
......@@ -30,7 +30,7 @@ class CpuRandPermKernel final : public user_op::OpKernel {
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeGenerator(kCPU));
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<UniformKernelState>(generator);
return std::make_shared<DistributionKernelState>(generator);
}
private:
......@@ -39,9 +39,9 @@ class CpuRandPermKernel final : public user_op::OpKernel {
int32_t* output = out->mut_dptr<int32_t>();
const int32_t n = ctx->Attr<int32_t>("n");
if (n == 0) { return; }
auto* randperm_kernel_state = dynamic_cast<UniformKernelState*>(state);
CHECK_NOTNULL(randperm_kernel_state);
const auto& generator = randperm_kernel_state->generator();
auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
CHECK_NOTNULL(distribution_state);
const auto& generator = distribution_state->generator();
const auto& cpu_generator = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
CHECK_NOTNULL(generator);
user_op::RangeFunctor<DeviceType::kCPU, int32_t>()(ctx->device_ctx(), 0, 1, n, output);
......
......@@ -13,16 +13,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/op_kernel_state_wrapper.h"
#include <curand.h>
#include <curand_kernel.h>
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/user/kernels/op_kernel_state_wrapper.h"
#include "oneflow/user/kernels/range_kernel_util.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/user/kernels/radix_sort.cuh"
#include <curand.h>
#include <curand_kernel.h>
#include "oneflow/user/kernels/distributions/common.h"
namespace oneflow {
__global__ void GeneKeysAndValues(const int32_t n, int32_t* values, int32_t* keys,
curandState* state) {
......@@ -40,7 +41,7 @@ class GpuRandPermKernel final : public user_op::OpKernel {
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeGenerator(kGPU));
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<UniformKernelState>(generator);
return std::make_shared<DistributionKernelState>(generator);
}
private:
......@@ -52,9 +53,9 @@ class GpuRandPermKernel final : public user_op::OpKernel {
if (n == 0) { return; }
user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
auto* randperm_kernel_state = dynamic_cast<UniformKernelState*>(state);
CHECK_NOTNULL(randperm_kernel_state);
const auto& generator = randperm_kernel_state->generator();
auto* distribution_state = dynamic_cast<DistributionKernelState*>(state);
CHECK_NOTNULL(distribution_state);
const auto& generator = distribution_state->generator();
const auto& gpu_generator = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
CHECK_NOTNULL(generator);
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
namespace oneflow {
REGISTER_NO_GRAD_USER_OP("uniform_int")
.Output("out")
.SetOutputBufferNum(1)
.Attr<int64_t>("from", 0)
.Attr<int64_t>("to", 1)
.Attr<DataType>("dtype")
.Attr<Shape>("shape")
.Attr<std::string>("nd_sbp")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
Shape* out_shape = ctx->OutputShape("out", 0);
const Shape& shape = ctx->Attr<Shape>("shape");
DimVector dim_vec;
if (shape.NumAxes() > 0) {
dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
}
*out_shape = Shape(dim_vec);
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
auto dtype = ctx->Attr<DataType>("dtype");
*ctx->OutputDType("out", 0) = dtype;
return Maybe<void>::Ok();
})
.SetNdSbpInferFn([](user_op::InferNdSbpFnContext* ctx) -> Maybe<void> {
cfg::SbpParallel default_sbp;
default_sbp.mutable_broadcast_parallel();
return user_op::InferNdSbp4SrcOp(ctx, default_sbp);
});
} // namespace oneflow
......@@ -20,8 +20,8 @@ namespace oneflow {
REGISTER_NO_GRAD_USER_OP("uniform")
.Output("out")
.SetOutputBufferNum(1)
.Attr<double>("low", 0)
.Attr<double>("high", 1)
.Attr<double>("from", 0)
.Attr<double>("to", 1)
.Attr<DataType>("dtype")
.Attr<Shape>("shape")
.Attr<std::string>("nd_sbp")
......
......@@ -71,7 +71,7 @@ class Rand(Module):
def forward(self):
if self.placement is not None:
res = flow._C.consistent_rand(
res = flow._C.rand(
self.size,
placement=self.placement,
sbp=self.sbp,
......@@ -169,12 +169,13 @@ class RandN(Module):
def forward(self):
if self.placement is not None:
res = flow._C.consistent_randn(
res = flow._C.randn(
self.size,
placement=self.placement,
sbp=self.sbp,
dtype=self.dtype,
generator=self.generator,
requires_grad=self.requires_grad,
)
else:
res = flow._C.randn(
......@@ -182,8 +183,8 @@ class RandN(Module):
dtype=self.dtype,
device=self.device,
generator=self.generator,
requires_grad=self.requires_grad,
)
res.requires_grad = self.requires_grad
return res
......@@ -276,26 +277,27 @@ class RandInt(Module):
def forward(self):
if self.placement is not None:
res = flow._C.consistent_randint(
res = flow._C.randint(
self.low,
self.high,
shape=self.size,
size=self.size,
placement=self.placement,
sbp_tuple=self.sbp,
dtype=self.dtype,
generator=self.generator,
requires_grad=self.requires_grad,
)
else:
res = flow._C.randint(
self.low,
self.high,
shape=self.size,
size=self.size,
dtype=self.dtype,
device=self.device,
generator=self.generator,
requires_grad=self.requires_grad,
)
res.requires_grad = self.requires_grad
return res.to(dtype=self.dtype)
return res
def randint_op(
......@@ -381,12 +383,20 @@ class RandPerm(Module):
def forward(self, out=None):
if self.placement is not None:
res = flow._C.consistent_randperm(
self.n, placement=self.placement, sbp=self.sbp, generator=self.generator
res = flow._C.randperm(
self.n,
placement=self.placement,
sbp=self.sbp,
generator=self.generator,
requires_grad=self.requires_grad,
)
else:
res = flow._C.randperm(self.n, device=self.device, generator=self.generator)
res.requires_grad = self.requires_grad
res = flow._C.randperm(
self.n,
device=self.device,
generator=self.generator,
requires_grad=self.requires_grad,
)
return res.to(dtype=self.dtype)
......
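Since the modules now delegate to the unified functors, requires_grad is set inside the C++ functor rather than patched onto the result afterwards in Python. A hedged example of the updated randperm path:

import oneflow as flow

p = flow.randperm(10, dtype=flow.int32)  # dtype now defaults to int64 and is overridable
g = flow.randperm(10, placement=flow.placement("cpu", {0: [0]}), sbp=(flow.sbp.broadcast,))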
......@@ -78,7 +78,6 @@ def _eager_consistent_tensor_to(input, device_type, dtype):
if device_type == input.placement.device_type and dtype != input.dtype:
return flow._C.cast(input, dtype=dtype)
device = flow.device(device_type)
placement = flow._oneflow_internal._ReplacePlacementDeviceTag(
input.placement, device_type
......
......@@ -34,6 +34,15 @@ def _test_rand(test_case, device, shape):
test_case.assertTrue(shape == y1.shape)
def _test_0d_rand(test_case, device, shape):
y1 = flow.rand(*shape, device=flow.device(device))
y2 = flow.rand(*shape, device=flow.device(device))
test_case.assertTrue(
np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)
) # 0d is [] and []
test_case.assertTrue(shape == y1.shape)
def _test_different_dtype(test_case, device, shape):
y1 = flow.rand(*shape, dtype=flow.float32, device=flow.device(device))
y2 = flow.rand(*shape, dtype=flow.float64, device=flow.device(device))
......@@ -75,7 +84,15 @@ class TestConstantModule(flow.unittest.TestCase):
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_cast(test_case):
def test_0d_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_0d_rand]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
def test_cases(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_rand,
......
......@@ -24,14 +24,14 @@ import oneflow.unittest
from test_util import GenArgList
def _test_rand(test_case, device, shape, low, high):
def _test_randint(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
y2 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertFalse(np.all(y1.numpy() == y2.numpy()))
test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))
test_case.assertTrue(shape == y1.shape)
def _test_0d_rand(test_case, device, shape, low, high):
def _test_0d_randint(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
y2 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertTrue(
......@@ -41,10 +41,17 @@ def _test_0d_rand(test_case, device, shape, low, high):
def _test_different_dtype(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, dtype=flow.float32, device=flow.device(device))
y2 = flow.randint(low, high, shape, dtype=flow.float64, device=flow.device(device))
test_case.assertTrue(not np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))
test_case.assertTrue(shape == y1.shape)
for dtype in [
flow.uint8,
flow.int8,
flow.int32,
flow.int64,
flow.float32,
flow.float64,
]:
y = flow.randint(low, high, shape, dtype=dtype, device=flow.device(device))
test_case.assertTrue(y.dtype == dtype)
test_case.assertTrue(y.shape == shape)
def _test_with_generator(test_case, device, shape, low, high):
......@@ -61,9 +68,9 @@ def _test_with_generator(test_case, device, shape, low, high):
def _test_high(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
y2 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertFalse(np.all(y1.numpy() == y2.numpy()))
y1 = flow._C.randint(high, shape, device=flow.device(device))
y2 = flow._C.randint(high, shape, device=flow.device(device))
test_case.assertFalse(np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))
test_case.assertTrue(shape == y1.shape)
......@@ -73,7 +80,6 @@ def _test_0rank(test_case, device, shape, low, high):
@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestRandint(flow.unittest.TestCase):
def test_consistent_naive(test_case):
placement = flow.placement("cpu", {0: [0]})
......@@ -82,10 +88,25 @@ class TestRandint(flow.unittest.TestCase):
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_consistent_different_types(test_case):
for dtype in [
flow.int8,
flow.int32,
flow.int64,
flow.float32,
flow.float64,
]:
placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow.randint(0, 16, (10, 1), placement=placement, sbp=sbp, dtype=dtype)
test_case.assertEqual(x.dtype, dtype)
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_rand,
_test_randint,
_test_different_dtype,
_test_with_generator,
]
......@@ -98,7 +119,7 @@ class TestRandint(flow.unittest.TestCase):
def test_0d_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_0d_rand]
arg_dict["test_fun"] = [_test_0d_randint]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)]
arg_dict["low"] = [i for i in range(10)]
......
......@@ -36,6 +36,7 @@ def _test_randperm_with_generator(test_case, N, device, dtype):
def _test_randperm_backward(test_case, N, device, dtype):
dtype = flow.float32 # fix dtype here as reduce_sum doesn't support all dtypes yet
x = flow.randperm(N, device=device, dtype=dtype)
x.requires_grad = True
y = x.sum()
......@@ -52,6 +53,29 @@ def _test_randperm_randomness(test_case, N, device, dtype):
@flow.unittest.skip_unless_1n1d()
class Testrandperm(flow.unittest.TestCase):
def test_consistent_naive(test_case):
placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow.randperm(10, placement=placement, sbp=sbp)
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_consistent_different_types(test_case):
for dtype in [
flow.uint8,
flow.int8,
flow.int32,
flow.int64,
flow.float32,
flow.float64,
]:
placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow.randperm(10, placement=placement, sbp=sbp, dtype=dtype)
test_case.assertEqual(x.dtype, dtype)
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_randperm(test_case):
arg_dict = OrderedDict()
arg_dict["test_functions"] = [
......@@ -60,7 +84,14 @@ class Testrandperm(flow.unittest.TestCase):
]
arg_dict["N"] = [i for i in range(10, 100, 5)]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["dtype"] = [flow.int32, flow.int64, flow.float32, flow.float64]
arg_dict["dtype"] = [
flow.uint8,
flow.int8,
flow.int32,
flow.int64,
flow.float32,
flow.float64,
]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
......