generator.py 1.9 KB
Newer Older
Y
yaoxuefeng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is definition of generator class, which is for managing the state of the algorithm that produces pseudo random numbers."""

from . import core

# Names exported by ``from <this module> import *``.
__all__ = ['Generator']

# Seed that Generator.__init__ applies to every newly created generator.
default_rng_seed_val = 34342423252


class Generator(object):
    """Python-side wrapper around ``core.Generator``.

    The constructor validates the requested device, creates the wrapped
    ``core.Generator`` and seeds it with ``default_rng_seed_val``. Every other
    method simply forwards to the wrapped object stored in ``self.generator``.

    Attributes:
        device: The device name that was passed to the constructor.
        generator: The wrapped ``core.Generator`` instance (only created when
            ``device`` is "CPU").
    """

    def __init__(self, device="CPU"):
        """Create a generator for ``device`` and apply the default seed.

        Args:
            device: Device name. Only the exact string "CPU" is accepted.

        Raises:
            ValueError: If ``device`` is anything other than "CPU".
        """
        self.device = device
        if self.device == "CPU":
            self.generator = core.Generator()
            self.generator.manual_seed(default_rng_seed_val)
        else:
            # Wrap the argument in a 1-tuple: with a bare ``% device`` a tuple
            # argument would be unpacked as the format arguments and fail with
            # a TypeError instead of the intended ValueError.
            raise ValueError(
                "generator class with device %s does not exist, currently only support generator with device 'CPU' "
                % (device, ))

    def get_state(self):
        """Return the result of ``get_state()`` on the wrapped generator."""
        return self.generator.get_state()

    def set_state(self, state):
        """Forward ``state`` to ``set_state()`` on the wrapped generator.

        Args:
            state: Value handed through unchanged; presumably something
                previously obtained from ``get_state()`` -- confirm against
                ``core.Generator``.
        """
        self.generator.set_state(state)

    def manual_seed(self, seed):
        """Forward ``seed`` to ``manual_seed()`` on the wrapped generator.

        Args:
            seed: Seed value handed through unchanged.
        """
        self.generator.manual_seed(seed)

    def seed(self):
        """Return the result of ``seed()`` on the wrapped generator.

        NOTE(review): whether this also re-seeds the generator is decided by
        ``core.Generator`` and is not visible from this file.
        """
        return self.generator.seed()

    def initial_seed(self):
        """Return the result of ``initial_seed()`` on the wrapped generator."""
        return self.generator.initial_seed()

    def random(self):
        """Return the result of ``random()`` on the wrapped generator."""
        return self.generator.random()

    def get_cpu_engine(self):
        """Return the result of ``get_cpu_engine()`` on the wrapped generator."""
        return self.generator.get_cpu_engine()

    def set_cpu_engine(self, cpu_engine):
        """Forward ``cpu_engine`` to ``set_cpu_engine()`` on the wrapped generator.

        Args:
            cpu_engine: Engine object handed through unchanged.
        """
        self.generator.set_cpu_engine(cpu_engine)