未验证 提交 1db57451 编写于 作者: B Bowen Chen 提交者: GitHub

add flow.rand (#5722)

* add flow.rand

* update docstr

* update docstr

* add consistent_rand, add more tests

* update random op

* refine

* refine, add range and int type to uniform_kernel

* refine

* refine

* update doc

* update doc

* Refactor UniformDistribution

* fix
Co-authored-by: Nhjchen2 <chenhoujiangcug@gmail.com>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 3001d335
......@@ -78,6 +78,7 @@ oneflow
ones,
ones_like,
pow,
rand,
randn,
repeat,
reshape,
......
......@@ -950,12 +950,13 @@
signature: "Tensor ReduceSumLike(Tensor in, Tensor like, *,Int32List axis)"
bind_python: True
- name: "split"
signature: "TensorTuple Split(Tensor x, *, Int64 split_size, Int64 dim=0)"
- name: "rand"
signature: "Tensor Rand(*, Shape shape, DataType dtype=None, Device device=None, Generator generator=None)"
bind_python: True
- name: "split_with_size"
signature: "TensorTuple SplitWithSize(Tensor x, *, Int64List split_sizes, Int64 dim=0)"
- name: "consistent_rand"
signature: "Tensor ConsistentRand(*, Shape shape, Placement placement, SbpList sbp_tuple, DataType dtype=None,
Generator generator=None)"
bind_python: True
- name: "randn"
......@@ -968,6 +969,14 @@
SbpList sbp_tuple, DataType dtype=None, Generator generator=None)"
bind_python: True
- name: "split"
signature: "TensorTuple Split(Tensor x, *, Int64 split_size, Int64 dim=0)"
bind_python: True
- name: "split_with_size"
signature: "TensorTuple SplitWithSize(Tensor x, *, Int64List split_sizes, Int64 dim=0)"
bind_python: True
- name: "l2_normalize"
signature: "TensorTuple L2Normalize(Tensor input, Int32 axis, Float epsilon, *)"
bind_python: True
......
......@@ -32,6 +32,9 @@ limitations under the License.
#include "oneflow/core/job/global_for.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/normal_kernel.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
namespace one {
......@@ -68,6 +71,96 @@ class BernoulliFunctor {
std::shared_ptr<OpExpr> bernoulli_op_;
};
class RandFunctor {
public:
RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
Maybe<Tensor> operator()(const Shape& shape, const Optional<DataType>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype.value());
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << "not supported in rand";
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("low", 0));
JUST(attrs.SetAttr<double>("high", 1));
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device.value());
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {},
OpExprInterpContext(attrs, uniform_kernel_state));
}
}
private:
std::shared_ptr<OpExpr> op_;
};
class ConsistentRandFunctor {
public:
ConsistentRandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<DataType>& dtype,
const Optional<one::Generator>& generator) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype.value());
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << "not supported in rand";
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("low", 0));
JUST(attrs.SetAttr<double>("high", 1));
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
const auto& parallel_distribution = JUST(GetNdSbp(sbp_tuple));
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
JUST(attrs.SetAttr<std::string>("nd_sbp", parallel_distribution->DebugString()));
}
return OpInterpUtil::Dispatch<Tensor>(
*op_, {},
OpExprInterpContext(attrs, placement, parallel_distribution, uniform_kernel_state));
}
private:
std::shared_ptr<OpExpr> op_;
};
class RandNFunctor {
public:
RandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
......@@ -162,6 +255,8 @@ class ConsistentRandNFunctor {
ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BernoulliFunctor>("Bernoulli");
m.add_functor<impl::RandFunctor>("Rand");
m.add_functor<impl::ConsistentRandFunctor>("ConsistentRand");
m.add_functor<impl::RandNFunctor>("RandN");
m.add_functor<impl::ConsistentRandNFunctor>("ConsistentRandN");
};
......
......@@ -14,8 +14,8 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTION_NORMAL_DISTRIBUTION_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTION_NORMAL_DISTRIBUTION_H_
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
......@@ -59,8 +59,8 @@ class NormalDistribution<DeviceType::kGPU, T> final {
const T mean_;
const T std_;
};
#endif
#endif // WITH_CUDA
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTION_NORMAL_DISTRIBUTION_H_
\ No newline at end of file
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_DISTRIBUTION_H_
\ No newline at end of file
......@@ -29,7 +29,7 @@ REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)
#ifdef WITH_CUDA
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, double)
#endif
#endif // WITH_CUDA
} // namespace
} // namespace oneflow
\ No newline at end of file
......@@ -68,4 +68,4 @@ class NormalKernel final : public user_op::OpKernel {
} // namespace
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_NORMAL_KERNEL_H_
\ No newline at end of file
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/distributions/uniform_distribution.h"
namespace oneflow {
template<typename T, typename E = void>
class CPUUniformDistributionImpl;
template<typename T>
class CPUUniformDistributionImpl<T, typename std::enable_if<std::is_integral<T>::value>::type> {
public:
CPUUniformDistributionImpl(T low, T high) : random_distribution_(low, high) {}
T operator()(std::mt19937& engine) { return random_distribution_(engine); }
private:
std::uniform_int_distribution<T> random_distribution_;
};
template<typename T>
class CPUUniformDistributionImpl<T,
typename std::enable_if<std::is_floating_point<T>::value>::type> {
public:
CPUUniformDistributionImpl(T low, T high) : random_distribution_(low, high) {}
T operator()(std::mt19937& engine) { return random_distribution_(engine); }
private:
std::uniform_real_distribution<T> random_distribution_;
};
template<typename T>
void UniformDistribution<DeviceType::kCPU, T>::operator()(
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const {
CHECK_GE(elem_cnt, 0);
auto gen = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
CPUUniformDistributionImpl<T> impl(low_, high_);
for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = impl(gen->engine()); }
}
#define INITIATE_CPU_UNIFORM_DISTRIBUTION(T, typeproto) \
template void UniformDistribution<DeviceType::kCPU, T>::operator()( \
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr, \
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_UNIFORM_DISTRIBUTION, INT_DATA_TYPE_SEQ)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/distributions/uniform_distribution.h"
#include "oneflow/core/common/data_type.h"
namespace oneflow {
namespace {
template<typename T>
__device__ T GenUniform(curandState* state, const T low, const T high);
#define INITIATE_GENUNIFORM(T, typeproto) \
template<> \
__device__ T GenUniform<T>(curandState * state, const T low, const T high) { \
return curand_uniform(state) * (high - low) + low; \
}
OF_PP_FOR_EACH_TUPLE(INITIATE_GENUNIFORM, INT_DATA_TYPE_SEQ)
template<>
__device__ float GenUniform<float>(curandState* state, const float low, const float high) {
return curand_uniform(state) * (high - low) + low;
}
template<>
__device__ double GenUniform<double>(curandState* state, const double low, const double high) {
return curand_uniform_double(state) * (high - low) + low;
}
template<typename T>
__global__ void GenerateGpu(curandState* state, const int64_t elem_cnt, T* dptr, const T low,
const T high) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
curandState localState = state[id];
if (id < elem_cnt) { dptr[id] = GenUniform<T>(&localState, low, high); }
state[id] = localState;
}
} // namespace
template<typename T>
void UniformDistribution<DeviceType::kGPU, T>::operator()(
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const {
CHECK_GE(elem_cnt, 0);
auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
int32_t block_num = gen->max_block_num();
int32_t thread_num = gen->max_thread_num();
auto* curand_states = gen->curand_states();
GenerateGpu<T><<<block_num, thread_num, 0, device_ctx->cuda_stream()>>>(curand_states, elem_cnt,
dptr, low_, high_);
}
#define INITIATE_GPU_UNIFORM_DISTRIBUTION(T, typeproto) \
template void UniformDistribution<DeviceType::kGPU, T>::operator()( \
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr, \
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_UNIFORM_DISTRIBUTION, INT_DATA_TYPE_SEQ)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
#ifdef WITH_CUDA
#include <curand.h>
#include <curand_kernel.h>
#endif
namespace oneflow {
template<DeviceType device_type, typename T>
class UniformDistribution;
template<typename T>
class UniformDistribution<DeviceType::kCPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(UniformDistribution);
UniformDistribution(T low, T high) : low_(low), high_(high) {}
~UniformDistribution() = default;
void operator()(DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const;
private:
const T low_;
const T high_;
};
#ifdef WITH_CUDA
template<typename T>
class UniformDistribution<DeviceType::kGPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(UniformDistribution);
UniformDistribution(T low, T high) : low_(low), high_(high) {}
~UniformDistribution() = default;
void operator()(DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const;
private:
const T low_;
const T high_;
};
#endif // WITH_CUDA
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_DISTRIBUTION_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
namespace oneflow {
namespace {
#define REGISTER_UNIFORM_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("uniform").SetCreateFn<UniformKernel<device, dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == device) \
& (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int32_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, int64_t)
#ifdef WITH_CUDA
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, double)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int32_t)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, int64_t)
#endif // WITH_CUDA
} // namespace
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/distributions/uniform_distribution.h"
namespace oneflow {
class UniformKernelState : public user_op::OpKernelState {
public:
explicit UniformKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
namespace {
template<DeviceType device_type, typename T>
class UniformKernel final : public user_op::OpKernel {
public:
UniformKernel() = default;
~UniformKernel() = default;
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<UniformKernelState>(generator);
}
private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const double low = ctx->Attr<double>("low");
const double high = ctx->Attr<double>("high");
int64_t elem_cnt = out->shape().elem_cnt();
T* out_dptr = out->mut_dptr<T>();
auto* uniform_state = dynamic_cast<UniformKernelState*>(state);
CHECK_NOTNULL(uniform_state);
const auto& generator = uniform_state->generator();
CHECK_NOTNULL(generator);
UniformDistribution<device_type, T> distribution(static_cast<T>(low), static_cast<T>(high));
distribution(ctx->device_ctx(), elem_cnt, out_dptr, generator);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
} // namespace
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTIONS_UNIFORM_KERNEL_H_
......@@ -13,8 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNEL_RANDOM_MASK_GENERATOR_H_
#define ONEFLOW_USER_KERNEL_RANDOM_MASK_GENERATOR_H_
#ifndef ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_
#define ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_
#include "oneflow/user/kernels/random_mask_generator.h"
#include "oneflow/core/framework/framework.h"
......@@ -69,4 +69,4 @@ class RandomMaskLikeKernel final : public user_op::OpKernel, public user_op::Cud
} // namespace
} // namespace oneflow
#endif // ONEFLOW_USER_KERNEL_RANDOM_MASK_GENERATOR_H_
#endif // ONEFLOW_USER_KERNELS_RANDOM_MASK_LIKE_KERNEL_H_
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/global.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
Maybe<void> InferUniformParallelDistribution(user_op::InferParallelDistributionFnContext* ctx);
REGISTER_NO_GRAD_USER_OP("uniform")
.Output("out")
.SetOutputBufferNum(1)
.Attr<double>("low", 0)
.Attr<double>("high", 1)
.Attr<DataType>("dtype")
.Attr<Shape>("shape")
.Attr<std::string>("nd_sbp")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
Shape* out_shape = ctx->OutputShape("out", 0);
const Shape& shape = ctx->Attr<Shape>("shape");
DimVector dim_vec;
if (shape.NumAxes() > 0) {
dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
}
*out_shape = Shape(dim_vec);
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
auto dtype = ctx->Attr<DataType>("dtype");
*ctx->OutputDType("out", 0) = dtype;
return Maybe<void>::Ok();
})
.SetParallelDistributionInferFn(&InferUniformParallelDistribution);
Maybe<void> InferUniformParallelDistribution(user_op::InferParallelDistributionFnContext* ctx) {
cfg::ParallelDistribution* out = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
const auto& pb_str = ctx->user_op_conf().attr<std::string>("nd_sbp");
ParallelDistribution pb;
CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb));
out->InitFromProto(pb);
} else {
out->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
}
return Maybe<void>::Ok();
}
} // namespace oneflow
......@@ -260,6 +260,7 @@ from oneflow.nn.modules.negative import negative_op as neg
from oneflow.nn.modules.negative import negative_op as negative
from oneflow.nn.modules.nonzero import nonzero_op as nonzero
from oneflow.nn.modules.random_ops import bernoulli
from oneflow.nn.modules.random_ops import rand_op as rand
from oneflow.nn.modules.random_ops import randn_op as randn
from oneflow.nn.modules.reduce_ops import _max as max
from oneflow.nn.modules.reduce_ops import _mean as mean
......
......@@ -58,7 +58,33 @@ def bernoulli(input, *, generator=None, out=None):
return flow.F.bernoulli(input, flow.float32, generator)
class RandN(Module):
def _rand_op_common_process(
size, device=None, generator=None, placement=None, sbp=None
):
assert size is not None, "shape must not be None!"
assert isinstance(
size, (int, tuple, list, flow.Size)
), "shape should be int or tuple int!"
if isinstance(device, str):
device = flow.device(device)
size = _single(size)
processed_sbp = sbp
if generator is None:
generator = flow.Generator()
if placement is not None:
assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp
if isinstance(processed_sbp, flow.sbp.sbp):
processed_sbp = (processed_sbp,)
else:
for elem in sbp:
assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp
assert len(processed_sbp) == len(placement.hierarchy)
else:
assert sbp is None, "sbp: %s" % sbp
return size, device, generator, placement, processed_sbp
class Rand(Module):
def __init__(
self,
size,
......@@ -71,32 +97,103 @@ class RandN(Module):
requires_grad=False,
) -> None:
super().__init__()
assert size is not None, "shape must not be None!"
assert isinstance(
size, (int, tuple, list, flow.Size)
), "shape should be int or tuple int!"
self.device = device
if isinstance(self.device, str):
self.device = flow.device(self.device)
self.requires_grad = requires_grad
size = _single(size)
if generator is None:
generator = flow.Generator()
self.generator = generator
self.placement = placement
self.sbp = sbp
if placement is not None:
assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp
if isinstance(self.sbp, flow.sbp.sbp):
self.sbp = (self.sbp,)
else:
for elem in sbp:
assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp
assert len(self.sbp) == len(placement.hierarchy)
(
self.size,
self.device,
self.generator,
self.placement,
self.sbp,
) = _rand_op_common_process(size, device, generator, placement, sbp)
self.dtype = dtype
def forward(self):
if self.placement is not None:
res = flow.F.consistent_rand(
self.size, self.placement, self.sbp, self.dtype, self.generator
)
else:
assert sbp is None, "sbp: %s" % sbp
self.size = size
res = flow.F.rand(self.size, self.dtype, self.device, self.generator)
res.requires_grad = self.requires_grad
return res
def rand_op(
*size,
out=None,
generator=None,
dtype: Optional[flow.dtype] = None,
layout=None,
device: Union[flow.device, str, None] = None,
placement: flow.placement = None,
sbp: flow._oneflow_internal.sbp.sbp = None,
requires_grad: bool = False
):
"""
Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1)
The shape of the tensor is defined by the variable argument ``size``.
Args:
size (int... or flow.Size): Defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple or flow.Size.
out (optional): The output tensor.
dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.
layout (optional): The desired layout of returned Tensor.
generator (flow.Generator, optional) – a pseudorandom number generator for sampling
device (flow.device, optional): The desired device of returned local tensor. If None, uses the
current device.
placement (flow.placement, optional): The desired device of returned consistent tensor. If None, will
construct local tensor.
sbp (flow.sbp, optional): The desired sbp of returned consistent tensor. It must be equal with the
numbers of placement.
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> x = flow.rand(3,3)
>>> x.shape
flow.Size([3, 3])
>>> x.is_consistent
False
>>> placement = flow.placement("cpu", {0: [0]})
>>> sbp = flow.sbp.broadcast
>>> x = flow.rand(3, 3, placement=placement, sbp=sbp)
>>> x.is_consistent
True
"""
assert out is None, "out not supported yet"
assert layout is None, "layout not supported yet"
if generator is None:
generator = flow.default_generator()
return Rand(size, generator, dtype, layout, device, placement, sbp, requires_grad)()
class RandN(Module):
def __init__(
self,
size,
generator=None,
dtype=None,
layout=None,
device=None,
placement=None,
sbp=None,
requires_grad=False,
) -> None:
super().__init__()
self.requires_grad = requires_grad
(
self.size,
self.device,
self.generator,
self.placement,
self.sbp,
) = _rand_op_common_process(size, device, generator, placement, sbp)
self.dtype = dtype
def forward(self):
......@@ -133,7 +230,7 @@ def randn_op(
dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.
layout (optional): The desired layout of returned Tensor.
generator (flow.Generator, optional) – a pseudorandom number generator for sampling
device (torch.device, optional): The desired device of returned local tensor. If None, uses the
device (flow.device, optional): The desired device of returned local tensor. If None, uses the
current device.
placement (flow.placement, optional): The desired device of returned consistent tensor. If None, will
construct local tensor.
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow as flow
import oneflow.unittest
from test_util import GenArgList
from automated_test_util import *
def _test_rand(test_case, device, shape):
y1 = flow.rand(*shape, device=flow.device(device))
y2 = flow.rand(*shape, device=flow.device(device))
test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))
test_case.assertTrue(shape == y1.shape)
def _test_different_dtype(test_case, device, shape):
y1 = flow.rand(*shape, dtype=flow.float32, device=flow.device(device))
y2 = flow.rand(*shape, dtype=flow.float64, device=flow.device(device))
test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))
test_case.assertTrue(shape == y1.shape)
with test_case.assertRaises(
oneflow._oneflow_internal.exception.UnimplementedException
):
flow.rand(*shape, dtype=flow.int32, device=flow.device(device))
def _test_backward(test_case, device, shape):
x = flow.rand(*shape, device=flow.device(device), requires_grad=True)
y = x.sum()
y.backward()
test_case.assertTrue(np.array_equal(np.ones(shape), x.grad.numpy()))
def _test_with_generator(test_case, device, shape):
gen = flow.Generator()
gen.manual_seed(0)
y1 = flow.rand(
*shape, dtype=flow.float32, device=flow.device(device), generator=gen
)
y1_np = y1.numpy()
gen.manual_seed(0)
y2 = flow.rand(
*shape, dtype=flow.float32, device=flow.device(device), generator=gen
)
test_case.assertTrue(np.allclose(y1_np, y2.numpy(), atol=1e-4, rtol=1e-4))
@flow.unittest.skip_unless_1n1d()
class TestConstantModule(flow.unittest.TestCase):
def test_consistent_naive(test_case):
placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow.rand(16, 16, placement=placement, sbp=sbp)
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_cast(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_rand,
_test_different_dtype,
_test_backward,
_test_with_generator,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 4)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
......@@ -25,18 +25,19 @@ from test_util import GenArgList
from automated_test_util import *
def _test_rand(test_case, device, shape):
def _test_randn(test_case, device, shape):
y1 = flow.randn(*shape, device=flow.device(device))
y2 = flow.randn(*shape, device=flow.device(device))
test_case.assertTrue(not np.array_equal(y1, y2))
test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))
test_case.assertTrue(shape == y1.shape)
def _test_different_dtype(test_case, device, shape):
y1 = flow.randn(*shape, dtype=flow.float32, device=flow.device(device))
y2 = flow.randn(*shape, dtype=flow.float64, device=flow.device(device))
test_case.assertTrue(not np.array_equal(y1, y2))
test_case.assertTrue(not np.array_equal(y1.numpy(), y2.numpy()))
test_case.assertTrue(shape == y1.shape)
with test_case.assertRaises(
oneflow._oneflow_internal.exception.UnimplementedException
):
......@@ -76,13 +77,13 @@ class TestConstantModule(flow.unittest.TestCase):
def test_cast(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_rand,
_test_randn,
_test_different_dtype,
_test_backward,
_test_with_generator,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 0, 4)]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 4)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册