未验证 提交 2a38f123 编写于 作者: 0 0YuanZhang0 提交者: GitHub

ADD Manual seed op into paddle.framework (#23537)

* test=develop
Co-authored-by: Nwuxing03 <wuxing03@baidu.com>
上级 7f0b2c74
......@@ -36,6 +36,7 @@ batch = batch.batch
import paddle.sysconfig
import paddle.tensor
import paddle.nn
import paddle.framework
import paddle.imperative
# TODO: define alias in tensor and framework directory
......@@ -199,7 +200,7 @@ from .tensor.search import nonzero #DEFINE_ALIAS
from .tensor.search import sort #DEFINE_ALIAS
# from .framework.framework import set_default_dtype #DEFINE_ALIAS
# from .framework.framework import get_default_dtype #DEFINE_ALIAS
# from .framework.random import manual_seed #DEFINE_ALIAS
from .framework.random import manual_seed #DEFINE_ALIAS
# from .framework import append_backward #DEFINE_ALIAS
# from .framework import gradients #DEFINE_ALIAS
# from .framework import Executor #DEFINE_ALIAS
......
......@@ -66,6 +66,8 @@ _dygraph_tracer_ = None
_dygraph_current_expected_place_ = None
_current_device = None
global_prog_seed = 0
def require_version(min_version, max_version=None):
"""
......@@ -3653,7 +3655,8 @@ class Program(object):
self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
self._seed = 0
global global_prog_seed
self._seed = global_prog_seed
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self.__op_role_var = []
......@@ -3692,6 +3695,33 @@ class Program(object):
# appending gradients times
self._appending_grad_times = 0
def global_seed(self, seed=0):
"""
Set global seed for Program
Returns:
None.
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.default_main_program()
print(prog.random_seed)
## 0
## the default random seed is 0
prog.global_seed(102)
prog1 = fluid.default_main_program()
print(prog1.random_seed)
## 102
## the random seed is 102
"""
global global_prog_seed
global_prog_seed = seed
self._seed = global_prog_seed
@property
def _op_role(self):
"""
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from paddle.framework import manual_seed
from paddle.fluid.framework import Program, default_main_program, default_startup_program
class TestManualSeed(unittest.TestCase):
def test_manual_seed(self):
local_program = Program()
local_main_prog = default_main_program()
local_start_prog = default_startup_program()
self.assertEqual(0, local_program.random_seed)
self.assertEqual(0, local_main_prog.random_seed)
self.assertEqual(0, local_start_prog.random_seed)
manual_seed(102)
global_program1 = Program()
global_program2 = Program()
global_main_prog = default_main_program()
global_start_prog = default_startup_program()
self.assertEqual(102, global_program1.random_seed)
self.assertEqual(102, global_program2.random_seed)
self.assertEqual(102, global_main_prog.random_seed)
self.assertEqual(102, global_start_prog.random_seed)
if __name__ == '__main__':
unittest.main()
......@@ -38,3 +38,6 @@
# 'WeightNormParamAttr',
# 'Model',
# 'Sequential']
from . import random
from .random import manual_seed
......@@ -13,4 +13,28 @@
# limitations under the License.
# TODO: define random api
# __all__ = ['manual_seed']
import paddle.fluid as fluid
__all__ = ['manual_seed']
def manual_seed(seed):
"""
Set global manual seed for program
Args:
manual_seed(int): random seed for program
Returns:
None.
Examples:
.. code-block:: python
from paddle.framework import manual_seed
manual_seed(102)
"""
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
program = fluid.Program()
program.global_seed(seed)
......@@ -141,6 +141,7 @@ packages=['paddle',
'paddle.distributed',
'paddle.fluid',
'paddle.tensor',
'paddle.framework',
'paddle.fluid.dygraph',
'paddle.tensor',
'paddle.fluid.dygraph.dygraph_to_static',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册