未验证 提交 8eefb6c4 编写于 作者: B Bowen Chen 提交者: GitHub

add flow.randn (#5736)

* add flow.randn

* auto format by CI

* refine

* add mean and std as attr to normal_op/kernel/distribution

* refine

* refine

* keep module, fix docstring to pass CI

* refine
Co-authored-by: Noneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: NYao Chi <later@usopp.net>
Co-authored-by: Noneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 70bf1daa
......@@ -76,6 +76,7 @@ oneflow
ones,
ones_like,
pow,
randn,
repeat,
reshape,
reciprocal,
......
......@@ -66,6 +66,13 @@ struct OpExprInterpContext {
: attrs(attrs_arg),
parallel_desc(parallel_desc_arg),
parallel_distribution(parallel_distribution_arg) {}
OpExprInterpContext(const AttrMap& attrs_arg, Symbol<ParallelDesc> parallel_desc_arg,
Symbol<cfg::ParallelDistribution> parallel_distribution_arg,
std::shared_ptr<user_op::OpKernelState> state_arg)
: attrs(attrs_arg),
parallel_desc(parallel_desc_arg),
parallel_distribution(parallel_distribution_arg),
state(state_arg) {}
AttrMap attrs;
Optional<Symbol<Device>> device; // for local op
......
......@@ -933,3 +933,13 @@
- name: "reduce_sum_like"
signature: "Tensor ReduceSumLike(Tensor in, Tensor like, *,Int32List axis)"
bind_python: True
- name: "randn"
signature: "Tensor RandN(*, Shape shape, DataType dtype=None, Device device=None,
Generator generator=None)"
bind_python: True
- name: "consistent_randn"
signature: "Tensor ConsistentRandN(*, Shape shape, Placement placement,
SbpList sbp_tuple, DataType dtype=None, Generator generator=None)"
bind_python: True
......@@ -13,7 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/global.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
......@@ -22,10 +24,14 @@ limitations under the License.
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/normal_kernel.h"
namespace oneflow {
namespace one {
......@@ -62,9 +68,104 @@ class BernoulliFunctor {
std::shared_ptr<OpExpr> bernoulli_op_;
};
class RandNFunctor {
public:
RandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
Maybe<Tensor> operator()(const Shape& shape, const Optional<DataType>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype.value());
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << "not supported in randn";
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("mean", 0));
JUST(attrs.SetAttr<double>("std", 1));
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device.value());
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, device_symbol, normal_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {},
OpExprInterpContext(attrs, normal_kernel_state));
}
}
private:
std::shared_ptr<OpExpr> op_;
};
class ConsistentRandNFunctor {
public:
ConsistentRandNFunctor() { op_ = CHECK_JUST(one::OpBuilder("normal").Output("out").Build()); }
Maybe<Tensor> operator()(const Shape& shape, const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<DataType>& dtype,
const Optional<one::Generator>& generator) const {
DataType dtype_val = DataType::kFloat;
if (dtype.has_value()) {
dtype_val = JUST(dtype.value());
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << "not supported in randn";
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<double>("mean", 0));
JUST(attrs.SetAttr<double>("std", 1));
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& normal_kernel_state = std::make_shared<NormalKernelState>(gen);
const auto& parallel_distribution = JUST(GetNdSbp(sbp_tuple));
if (!JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
JUST(attrs.SetAttr<std::string>("nd_sbp", parallel_distribution->DebugString()));
}
return OpInterpUtil::Dispatch<Tensor>(
*op_, {},
OpExprInterpContext(attrs, placement, parallel_distribution, normal_kernel_state));
}
private:
std::shared_ptr<OpExpr> op_;
};
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::BernoulliFunctor>("Bernoulli"); };
ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BernoulliFunctor>("Bernoulli");
m.add_functor<impl::RandNFunctor>("RandN");
m.add_functor<impl::ConsistentRandNFunctor>("ConsistentRandN");
};
} // namespace functional
} // namespace one
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/distributions/normal_distribution.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
template<typename T>
void NormalDistribution<DeviceType::kCPU, T>::operator()(
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const {
CHECK_GE(elem_cnt, 0);
auto gen = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
std::normal_distribution<T> random_distribution(mean_, std_);
for (int64_t i = 0; i < elem_cnt; ++i) { dptr[i] = random_distribution(gen->engine()); }
}
#define INITIATE_CPU_NORMAL_DISTRIBUTION(T, typeproto) \
template void NormalDistribution<DeviceType::kCPU, T>::operator()( \
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr, \
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_CPU_NORMAL_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/distributions/normal_distribution.h"
#include "oneflow/core/common/data_type.h"
namespace oneflow {
namespace {
template<typename T>
__device__ T GenNormal(curandState* state, const T mean, const T std);
template<>
__device__ float GenNormal<float>(curandState* state, const float mean, const float std) {
return (curand_normal(state) + mean) / std;
}
template<>
__device__ double GenNormal<double>(curandState* state, const double mean, const double std) {
return (curand_normal_double(state) + mean) / std;
}
template<typename T>
__global__ void GenerateGpu(curandState* state, const int64_t elem_cnt, T* dptr, const T mean,
const T std) {
const int id = blockIdx.x * blockDim.x + threadIdx.x;
curandState localState = state[id];
if (id < elem_cnt) { dptr[id] = GenNormal<T>(&localState, mean, std); }
state[id] = localState;
}
} // namespace
template<typename T>
void NormalDistribution<DeviceType::kGPU, T>::operator()(
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const {
CHECK_GE(elem_cnt, 0);
auto gen = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
int32_t block_num = gen->max_block_num();
int32_t thread_num = gen->max_thread_num();
auto* curand_states = gen->curand_states();
GenerateGpu<T><<<block_num, thread_num, 0, device_ctx->cuda_stream()>>>(curand_states, elem_cnt,
dptr, mean_, std_);
}
#define INITIATE_GPU_NORMAL_DISTRIBUTION(T, typeproto) \
template void NormalDistribution<DeviceType::kGPU, T>::operator()( \
DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr, \
const std::shared_ptr<one::Generator>& generator) const;
OF_PP_FOR_EACH_TUPLE(INITIATE_GPU_NORMAL_DISTRIBUTION, FLOATING_DATA_TYPE_SEQ)
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTION_NORMAL_DISTRIBUTION_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTION_NORMAL_DISTRIBUTION_H_
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
#ifdef WITH_CUDA
#include <curand.h>
#include <curand_kernel.h>
#endif
namespace oneflow {
template<DeviceType device_type, typename T>
class NormalDistribution;
template<typename T>
class NormalDistribution<DeviceType::kCPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(NormalDistribution);
NormalDistribution(T mean, T std) : mean_(mean), std_(std) {}
~NormalDistribution() = default;
void operator()(DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const;
private:
const T mean_;
const T std_;
};
#ifdef WITH_CUDA
template<typename T>
class NormalDistribution<DeviceType::kGPU, T> final {
public:
OF_DISALLOW_COPY_AND_MOVE(NormalDistribution);
NormalDistribution(T mean, T std) : mean_(mean), std_(std) {}
~NormalDistribution() = default;
void operator()(DeviceCtx* device_ctx, const int64_t elem_cnt, T* dptr,
const std::shared_ptr<one::Generator>& generator) const;
private:
const T mean_;
const T std_;
};
#endif
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_DISTRIBUTION_NORMAL_DISTRIBUTION_H_
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/user/kernels/distributions/normal_kernel.h"
namespace oneflow {
namespace {
#define REGISTER_UNIFORM_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("normal").SetCreateFn<NormalKernel<device, dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == device) \
& (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kCPU, double)
#ifdef WITH_CUDA
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, float)
REGISTER_UNIFORM_KERNEL(DeviceType::kGPU, double)
#endif
} // namespace
} // namespace oneflow
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_
#define ONEFLOW_USER_KERNELS_DISTRIBUTIONS_NORMAL_KERNEL_H_
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/distributions/normal_distribution.h"
namespace oneflow {
class NormalKernelState : public user_op::OpKernelState {
public:
explicit NormalKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
namespace {
template<DeviceType device_type, typename T>
class NormalKernel final : public user_op::OpKernel {
public:
NormalKernel() = default;
~NormalKernel() = default;
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<NormalKernelState>(generator);
}
private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
const double mean = ctx->Attr<double>("mean");
const double std = ctx->Attr<double>("std");
int64_t elem_cnt = out->shape().elem_cnt();
T* out_dptr = out->mut_dptr<T>();
auto* normal_state = dynamic_cast<NormalKernelState*>(state);
CHECK_NOTNULL(normal_state);
const auto& generator = normal_state->generator();
CHECK_NOTNULL(generator);
NormalDistribution<device_type, T> distribution(static_cast<T>(mean), static_cast<T>(std));
distribution(ctx->device_ctx(), elem_cnt, out_dptr, generator);
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
} // namespace
} // namespace oneflow
#endif // ONEFLOW_USER_KERNELS_NORMAL_KERNEL_H_
\ No newline at end of file
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/global.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
Maybe<void> InferNormalParallelDistribution(user_op::InferParallelDistributionFnContext* ctx);
REGISTER_NO_GRAD_USER_OP("normal")
.Output("out")
.SetOutputBufferNum(1)
.Attr<double>("mean", 0)
.Attr<double>("std", 1)
.Attr<DataType>("dtype")
.Attr<Shape>("shape")
.Attr<std::string>("nd_sbp")
.SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
Shape* out_shape = ctx->OutputShape("out", 0);
const Shape& shape = ctx->Attr<Shape>("shape");
*out_shape = shape;
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
ctx->NewBuilder().Broadcast(ctx->inputs()).Broadcast(ctx->outputs()).Build();
return Maybe<void>::Ok();
})
.SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
auto dtype = ctx->Attr<DataType>("dtype");
*ctx->OutputDType("out", 0) = dtype;
return Maybe<void>::Ok();
})
.SetParallelDistributionInferFn(&InferNormalParallelDistribution);
Maybe<void> InferNormalParallelDistribution(user_op::InferParallelDistributionFnContext* ctx) {
cfg::ParallelDistribution* out = ctx->ParallelDistribution4ArgNameAndIndex("out", 0);
if (JUST(*Global<Maybe<bool>, MultiClient>::Get())) {
const auto& pb_str = ctx->user_op_conf().attr<std::string>("nd_sbp");
ParallelDistribution pb;
CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb));
out->InitFromProto(pb);
} else {
out->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
}
return Maybe<void>::Ok();
}
} // namespace oneflow
\ No newline at end of file
......@@ -291,6 +291,7 @@ from oneflow.nn.modules.ne import ne_op as not_equal
from oneflow.nn.modules.negative import negative_op as neg
from oneflow.nn.modules.negative import negative_op as negative
from oneflow.nn.modules.random_ops import bernoulli
from oneflow.nn.modules.random_ops import randn_op as randn
from oneflow.nn.modules.reduce_ops import _max as max
from oneflow.nn.modules.reduce_ops import _mean as mean
from oneflow.nn.modules.reduce_ops import _min as min
......
......@@ -13,11 +13,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import random
import sys
from typing import Optional, Union
import oneflow as flow
from oneflow.nn.module import Module
from oneflow.nn.modules.utils import _single
def bernoulli(input, *, generator=None, out=None):
......@@ -58,6 +58,115 @@ def bernoulli(input, *, generator=None, out=None):
return flow.F.bernoulli(input, flow.float32, generator)
class RandN(Module):
def __init__(
self,
size,
generator=None,
dtype=None,
layout=None,
device=None,
placement=None,
sbp=None,
requires_grad=False,
) -> None:
super().__init__()
assert size is not None, "shape must not be None!"
assert isinstance(
size, (int, tuple, list, flow.Size)
), "shape should be int or tuple int!"
self.device = device
if isinstance(self.device, str):
self.device = flow.device(self.device)
self.requires_grad = requires_grad
size = _single(size)
if generator is None:
generator = flow.Generator()
self.generator = generator
self.placement = placement
self.sbp = sbp
if placement is not None:
assert isinstance(sbp, (flow.sbp.sbp, tuple, list)), "sbp: %s" % sbp
if isinstance(self.sbp, flow.sbp.sbp):
self.sbp = (self.sbp,)
else:
for elem in sbp:
assert isinstance(elem, flow.sbp.sbp), "sbp: %s" % sbp
assert len(self.sbp) == len(placement.hierarchy)
else:
assert sbp is None, "sbp: %s" % sbp
self.size = size
self.dtype = dtype
def forward(self):
if self.placement is not None:
res = flow.F.consistent_randn(
self.size, self.placement, self.sbp, self.dtype, self.generator
)
else:
res = flow.F.randn(self.size, self.dtype, self.device, self.generator)
res.requires_grad = self.requires_grad
return res
def randn_op(
*size,
out=None,
generator=None,
dtype: Optional[flow.dtype] = None,
layout=None,
device: Union[flow.device, str, None] = None,
placement: flow.placement = None,
sbp: flow._oneflow_internal.sbp.sbp = None,
requires_grad: bool = False
):
"""
Returns a tensor filled with random numbers from a normal distribution with mean 0 and variance 1 (also called the standard normal distribution).
The shape of the tensor is defined by the variable argument ``size``.
Args:
size (int... or flow.Size): Defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple or flow.Size.
out (optional): The output tensor.
dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.float32``.
layout (optional): The desired layout of returned Tensor.
generator (flow.Generator, optional) – a pseudorandom number generator for sampling
device (torch.device, optional): The desired device of returned local tensor. If None, uses the
current device.
placement (flow.placement, optional): The desired device of returned consistent tensor. If None, will
construct local tensor.
sbp (flow.sbp, optional): The desired sbp of returned consistent tensor. It must be equal with the
numbers of placement.
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> x = flow.randn(3,3)
>>> x.shape
flow.Size([3, 3])
>>> x.is_consistent
False
>>> placement = flow.placement("cpu", {0:[0]})
>>> sbp = flow.sbp.broadcast
>>> x = flow.randn(3,3,placement=placement,sbp=sbp)
>>> x.is_consistent
True
"""
assert out is None, "out not supported yet"
assert layout is None, "layout not supported yet"
if generator is None:
generator = flow.default_generator()
return RandN(
size, generator, dtype, layout, device, placement, sbp, requires_grad
)()
if __name__ == "__main__":
import doctest
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow as flow
import oneflow.unittest
from test_util import GenArgList
from automated_test_util import *
def _test_rand(test_case, device, shape):
y1 = flow.randn(*shape, device=flow.device(device))
y2 = flow.randn(*shape, device=flow.device(device))
test_case.assertTrue(not np.array_equal(y1, y2))
test_case.assertTrue(shape == y1.shape)
def _test_different_dtype(test_case, device, shape):
y1 = flow.randn(*shape, dtype=flow.float32, device=flow.device(device))
y2 = flow.randn(*shape, dtype=flow.float64, device=flow.device(device))
test_case.assertTrue(not np.array_equal(y1, y2))
test_case.assertTrue(shape == y1.shape)
with test_case.assertRaises(
oneflow._oneflow_internal.exception.UnimplementedException
):
flow.randn(*shape, dtype=flow.int32, device=flow.device(device))
def _test_backward(test_case, device, shape):
x = flow.randn(*shape, device=flow.device(device), requires_grad=True)
y = x.sum()
y.backward()
test_case.assertTrue(np.array_equal(np.ones(shape), x.grad.numpy()))
def _test_with_generator(test_case, device, shape):
gen = flow.Generator()
gen.manual_seed(0)
y1 = flow.randn(
*shape, dtype=flow.float32, device=flow.device(device), generator=gen
)
y1_np = y1.numpy()
gen.manual_seed(0)
y2 = flow.randn(
*shape, dtype=flow.float32, device=flow.device(device), generator=gen
)
test_case.assertTrue(np.array_equal(y1_np, y2.numpy()))
@flow.unittest.skip_unless_1n1d()
class TestConstantModule(flow.unittest.TestCase):
def test_consistent_naive(test_case):
placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow.randn(16, 16, placement=placement, sbp=sbp)
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_cast(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_rand,
_test_different_dtype,
_test_backward,
_test_with_generator,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5), (2, 0, 4)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册