diff --git a/mindspore/ops/composite/random_ops.py b/mindspore/ops/composite/random_ops.py index b0c2a55d6f10020c2f0dc69fcb6727216b77fa36..9c1b02e4f9739d3572d05d4a187ab5e09424d85d 100644 --- a/mindspore/ops/composite/random_ops.py +++ b/mindspore/ops/composite/random_ops.py @@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils from ...common import dtype as mstype from ...common.tensor import Tensor from ..._checkparam import Validator as validator +from ..._checkparam import check_int_positive from ..._checkparam import Rel # set graph-level RNG seed @@ -29,11 +30,36 @@ _GRAPH_SEED = 0 @constexpr def set_seed(seed): + """ + Set the graph-level seed. + Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. + If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a + random seed. + + Args: + seed(Int): the graph-level seed value that to be set. + + Examples: + >>> C.set_seed(10) + """ + check_int_positive(seed) global _GRAPH_SEED _GRAPH_SEED = seed @constexpr def get_seed(): + """ + Get the graph-level seed. + Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set. + If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a + random seed. + + Returns: + Interger. The current graph-level seed. + + Examples: + >>> C.get_seed(10) + """ return _GRAPH_SEED @@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0): >>> shape = (4, 16) >>> mean = Tensor(1.0, mstype.float32) >>> stddev = Tensor(1.0, mstype.float32) - >>> C.set_seed(10) >>> output = C.normal(shape, mean, stddev, seed=5) """ mean_dtype = F.dtype(mean)