提交 68fc7c2c 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!4082 Added notation for graph-level seed access interfaces

Merge pull request !4082 from peixu_ren/master
......@@ -22,6 +22,7 @@ from .multitype_ops import _constexpr_utils as const_utils
from ...common import dtype as mstype
from ...common.tensor import Tensor
from ..._checkparam import Validator as validator
from ..._checkparam import check_int_positive
from ..._checkparam import Rel
# set graph-level RNG seed
......@@ -29,11 +30,36 @@ _GRAPH_SEED = 0
@constexpr
def set_seed(seed):
"""
Set the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
random seed.
Args:
seed(Int): the graph-level seed value that to be set.
Examples:
>>> C.set_seed(10)
"""
check_int_positive(seed)
global _GRAPH_SEED
_GRAPH_SEED = seed
@constexpr
def get_seed():
"""
Get the graph-level seed.
Graph-level seed is used as a global variable, that can be used in different ops in case op-level seed is not set.
If op-level seed is 0, use graph-level seed; if op-level seed is also 0, the system would generate a
random seed.
Returns:
Interger. The current graph-level seed.
Examples:
>>> C.get_seed(10)
"""
return _GRAPH_SEED
......@@ -58,7 +84,6 @@ def normal(shape, mean, stddev, seed=0):
>>> shape = (4, 16)
>>> mean = Tensor(1.0, mstype.float32)
>>> stddev = Tensor(1.0, mstype.float32)
>>> C.set_seed(10)
>>> output = C.normal(shape, mean, stddev, seed=5)
"""
mean_dtype = F.dtype(mean)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册