未验证 提交 17ff46b6 编写于 作者: 0 0YuanZhang0 提交者: GitHub

Cherry-pick Manual seed op into paddle.framework (#23537)

add manual_seed op into paddle.framework.random
上级 8b914906
...@@ -36,6 +36,7 @@ batch = batch.batch ...@@ -36,6 +36,7 @@ batch = batch.batch
import paddle.sysconfig import paddle.sysconfig
import paddle.tensor import paddle.tensor
import paddle.nn import paddle.nn
import paddle.framework
import paddle.imperative import paddle.imperative
# TODO: define alias in tensor and framework directory # TODO: define alias in tensor and framework directory
...@@ -198,7 +199,7 @@ from .tensor.search import nonzero #DEFINE_ALIAS ...@@ -198,7 +199,7 @@ from .tensor.search import nonzero #DEFINE_ALIAS
from .tensor.search import sort #DEFINE_ALIAS from .tensor.search import sort #DEFINE_ALIAS
# from .framework.framework import set_default_dtype #DEFINE_ALIAS # from .framework.framework import set_default_dtype #DEFINE_ALIAS
# from .framework.framework import get_default_dtype #DEFINE_ALIAS # from .framework.framework import get_default_dtype #DEFINE_ALIAS
# from .framework.random import manual_seed #DEFINE_ALIAS from .framework.random import manual_seed #DEFINE_ALIAS
# from .framework import append_backward #DEFINE_ALIAS # from .framework import append_backward #DEFINE_ALIAS
# from .framework import gradients #DEFINE_ALIAS # from .framework import gradients #DEFINE_ALIAS
# from .framework import Executor #DEFINE_ALIAS # from .framework import Executor #DEFINE_ALIAS
......
...@@ -66,6 +66,8 @@ _dygraph_tracer_ = None ...@@ -66,6 +66,8 @@ _dygraph_tracer_ = None
_dygraph_current_expected_place_ = None _dygraph_current_expected_place_ = None
_current_device = None _current_device = None
global_prog_seed = 0
def require_version(min_version, max_version=None): def require_version(min_version, max_version=None):
""" """
...@@ -3653,7 +3655,8 @@ class Program(object): ...@@ -3653,7 +3655,8 @@ class Program(object):
self.desc = core.ProgramDesc() self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
self._seed = 0 global global_prog_seed
self._seed = global_prog_seed
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self.__op_role_var = [] self.__op_role_var = []
...@@ -3692,6 +3695,33 @@ class Program(object): ...@@ -3692,6 +3695,33 @@ class Program(object):
# appending gradients times # appending gradients times
self._appending_grad_times = 0 self._appending_grad_times = 0
def global_seed(self, seed=0):
"""
Set global seed for Program
Returns:
None.
Examples:
.. code-block:: python
import paddle.fluid as fluid
prog = fluid.default_main_program()
print(prog.random_seed)
## 0
## the default random seed is 0
prog.global_seed(102)
prog1 = fluid.default_main_program()
print(prog1.random_seed)
## 102
## the random seed is 102
"""
global global_prog_seed
global_prog_seed = seed
self._seed = global_prog_seed
@property @property
def _op_role(self): def _op_role(self):
""" """
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from paddle.framework import manual_seed
from paddle.fluid.framework import Program, default_main_program, default_startup_program
class TestManualSeed(unittest.TestCase):
def test_manual_seed(self):
local_program = Program()
local_main_prog = default_main_program()
local_start_prog = default_startup_program()
self.assertEqual(0, local_program.random_seed)
self.assertEqual(0, local_main_prog.random_seed)
self.assertEqual(0, local_start_prog.random_seed)
manual_seed(102)
global_program1 = Program()
global_program2 = Program()
global_main_prog = default_main_program()
global_start_prog = default_startup_program()
self.assertEqual(102, global_program1.random_seed)
self.assertEqual(102, global_program2.random_seed)
self.assertEqual(102, global_main_prog.random_seed)
self.assertEqual(102, global_start_prog.random_seed)
if __name__ == '__main__':
unittest.main()
...@@ -38,3 +38,6 @@ ...@@ -38,3 +38,6 @@
# 'WeightNormParamAttr', # 'WeightNormParamAttr',
# 'Model', # 'Model',
# 'Sequential'] # 'Sequential']
from . import random
from .random import manual_seed
...@@ -13,4 +13,28 @@ ...@@ -13,4 +13,28 @@
# limitations under the License. # limitations under the License.
# TODO: define random api # TODO: define random api
# __all__ = ['manual_seed'] import paddle.fluid as fluid
__all__ = ['manual_seed']
def manual_seed(seed):
"""
Set global manual seed for program
Args:
manual_seed(int): random seed for program
Returns:
None.
Examples:
.. code-block:: python
from paddle.framework import manual_seed
manual_seed(102)
"""
fluid.default_main_program().random_seed = seed
fluid.default_startup_program().random_seed = seed
program = fluid.Program()
program.global_seed(seed)
...@@ -141,6 +141,7 @@ packages=['paddle', ...@@ -141,6 +141,7 @@ packages=['paddle',
'paddle.distributed', 'paddle.distributed',
'paddle.fluid', 'paddle.fluid',
'paddle.tensor', 'paddle.tensor',
'paddle.framework',
'paddle.fluid.dygraph', 'paddle.fluid.dygraph',
'paddle.tensor', 'paddle.tensor',
'paddle.fluid.dygraph.dygraph_to_static', 'paddle.fluid.dygraph.dygraph_to_static',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册