未验证 提交 381492fc 编写于 作者: L Leo Chen 提交者: GitHub

add try finally, test=develop (#24243)

上级 50330c6c
...@@ -427,8 +427,10 @@ class TrainingDecoder(object): ...@@ -427,8 +427,10 @@ class TrainingDecoder(object):
if self._status != TrainingDecoder.BEFORE_DECODER: if self._status != TrainingDecoder.BEFORE_DECODER:
raise ValueError('decoder.block() can only be invoked once') raise ValueError('decoder.block() can only be invoked once')
self._status = TrainingDecoder.IN_DECODER self._status = TrainingDecoder.IN_DECODER
with self._dynamic_rnn.block(): with self._dynamic_rnn.block():
yield yield
self._status = TrainingDecoder.AFTER_DECODER self._status = TrainingDecoder.AFTER_DECODER
self._state_cell._leave_decoder(self) self._state_cell._leave_decoder(self)
......
...@@ -51,9 +51,11 @@ def program_desc_tracing_guard(enable): ...@@ -51,9 +51,11 @@ def program_desc_tracing_guard(enable):
if tracer: if tracer:
original_val = tracer._enable_program_desc_tracing original_val = tracer._enable_program_desc_tracing
tracer._enable_program_desc_tracing = enable tracer._enable_program_desc_tracing = enable
yield try:
if tracer: yield
tracer._enable_program_desc_tracing = original_val finally:
if tracer:
tracer._enable_program_desc_tracing = original_val
_functional_dygraph_context_manager = None _functional_dygraph_context_manager = None
...@@ -136,14 +138,16 @@ def disable_dygraph(): ...@@ -136,14 +138,16 @@ def disable_dygraph():
_functional_dygraph_context_manager = None _functional_dygraph_context_manager = None
@contextlib.contextmanager @signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True): def _switch_tracer_mode_guard_(is_train=True):
tracer = framework._dygraph_tracer() tracer = framework._dygraph_tracer()
if tracer: if tracer:
mode = tracer._train_mode mode = tracer._train_mode
tracer._train_mode = is_train tracer._train_mode = is_train
yield try:
tracer._train_mode = mode yield
finally:
tracer._train_mode = mode
else: else:
yield yield
......
...@@ -97,8 +97,10 @@ def scope_guard(scope): ...@@ -97,8 +97,10 @@ def scope_guard(scope):
""" """
ex = _switch_scope(scope) ex = _switch_scope(scope)
yield try:
_switch_scope(ex) yield
finally:
_switch_scope(ex)
def as_numpy(tensor): def as_numpy(tensor):
......
...@@ -487,8 +487,10 @@ def name_scope(prefix=None): ...@@ -487,8 +487,10 @@ def name_scope(prefix=None):
assert prefix, "namescope prefix can not be empty." assert prefix, "namescope prefix can not be empty."
global _name_scope global _name_scope
_name_scope = _name_scope.child(prefix) _name_scope = _name_scope.child(prefix)
yield try:
_name_scope = _name_scope.parent() yield
finally:
_name_scope = _name_scope.parent()
def _full_name_scope(): def _full_name_scope():
...@@ -3984,14 +3986,16 @@ class Program(object): ...@@ -3984,14 +3986,16 @@ class Program(object):
""" """
return self.__op_role_var return self.__op_role_var
@contextlib.contextmanager @signature_safe_contextmanager
def _backward_role_guard(self): def _backward_role_guard(self):
tmp_role = self._current_role tmp_role = self._current_role
OpRole = core.op_proto_and_checker_maker.OpRole OpRole = core.op_proto_and_checker_maker.OpRole
self._current_role = OpRole.Backward self._current_role = OpRole.Backward
yield try:
self._current_role = tmp_role yield
finally:
self._current_role = tmp_role
@signature_safe_contextmanager @signature_safe_contextmanager
def _optimized_guard(self, param_and_grads): def _optimized_guard(self, param_and_grads):
...@@ -4020,9 +4024,11 @@ class Program(object): ...@@ -4020,9 +4024,11 @@ class Program(object):
var.name if isinstance(var, Variable) else var var.name if isinstance(var, Variable) else var
for var in param_and_grads for var in param_and_grads
] ]
yield try:
self.__op_role_var = tmp_var yield
self._current_role = tmp_role finally:
self.__op_role_var = tmp_var
self._current_role = tmp_role
@signature_safe_contextmanager @signature_safe_contextmanager
def _lr_schedule_guard(self, is_with_opt=False): def _lr_schedule_guard(self, is_with_opt=False):
...@@ -4055,9 +4061,11 @@ class Program(object): ...@@ -4055,9 +4061,11 @@ class Program(object):
self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize) self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
# TODO(typhoonzero): how to set target learning rate var # TODO(typhoonzero): how to set target learning rate var
self.__op_role_var = [] self.__op_role_var = []
yield try:
self.__op_role_var = tmp_var yield
self._current_role = tmp_role finally:
self.__op_role_var = tmp_var
self._current_role = tmp_role
def __str__(self): def __str__(self):
""" """
...@@ -5310,10 +5318,12 @@ def program_guard(main_program, startup_program=None): ...@@ -5310,10 +5318,12 @@ def program_guard(main_program, startup_program=None):
check_type(startup_program, 'startup_program', Program, check_type(startup_program, 'startup_program', Program,
'fluid.program_guard') 'fluid.program_guard')
startup_program = switch_startup_program(startup_program) startup_program = switch_startup_program(startup_program)
yield try:
switch_main_program(main_program) yield
if startup_program is not None: finally:
switch_startup_program(startup_program) switch_main_program(main_program)
if startup_program is not None:
switch_startup_program(startup_program)
def _get_var(name, program=None): def _get_var(name, program=None):
...@@ -5343,10 +5353,11 @@ def _dygraph_guard(tracer): ...@@ -5343,10 +5353,11 @@ def _dygraph_guard(tracer):
_dygraph_tracer_ = tracer _dygraph_tracer_ = tracer
core._switch_tracer(tracer) core._switch_tracer(tracer)
yield try:
yield
core._switch_tracer(tmp_trace) finally:
_dygraph_tracer_ = tmp_trace core._switch_tracer(tmp_trace)
_dygraph_tracer_ = tmp_trace
@signature_safe_contextmanager @signature_safe_contextmanager
...@@ -5355,9 +5366,10 @@ def _dygraph_place_guard(place): ...@@ -5355,9 +5366,10 @@ def _dygraph_place_guard(place):
tmp_place = _dygraph_current_expected_place_ tmp_place = _dygraph_current_expected_place_
_dygraph_current_expected_place_ = place _dygraph_current_expected_place_ = place
yield try:
yield
_dygraph_current_expected_place_ = tmp_place finally:
_dygraph_current_expected_place_ = tmp_place
def load_op_library(lib_filename): def load_op_library(lib_filename):
...@@ -5437,8 +5449,10 @@ def device_guard(device=None): ...@@ -5437,8 +5449,10 @@ def device_guard(device=None):
"The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None " "The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None "
"when there is no need to specify device. But received %s" % device) "when there is no need to specify device. But received %s" % device)
pre_device = switch_device(device) pre_device = switch_device(device)
yield try:
switch_device(pre_device) yield
finally:
switch_device(pre_device)
def set_flags(flags): def set_flags(flags):
......
...@@ -18,7 +18,6 @@ from . import framework ...@@ -18,7 +18,6 @@ from . import framework
from . import core from . import core
from .framework import in_dygraph_mode from .framework import in_dygraph_mode
import numpy as np import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
from .core import VarDesc from .core import VarDesc
from . import unique_name from . import unique_name
from .data_feeder import check_variable_and_dtype, check_type, check_dtype from .data_feeder import check_variable_and_dtype, check_type, check_dtype
......
...@@ -33,6 +33,7 @@ from paddle.fluid.executor import Executor, global_scope ...@@ -33,6 +33,7 @@ from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \ from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \
program_guard program_guard
from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.compiler import CompiledProgram from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger from paddle.fluid.log_helper import get_logger
from . import reader from . import reader
...@@ -183,7 +184,7 @@ def _clone_var_in_block_(block, var): ...@@ -183,7 +184,7 @@ def _clone_var_in_block_(block, var):
persistable=True) persistable=True)
@contextlib.contextmanager @signature_safe_contextmanager
def _load_program_scope(main=None, startup=None, scope=None): def _load_program_scope(main=None, startup=None, scope=None):
prog = main if main else paddle.fluid.Program() prog = main if main else paddle.fluid.Program()
startup_prog = startup if startup else paddle.fluid.Program() startup_prog = startup if startup else paddle.fluid.Program()
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
from ..wrapped_decorator import signature_safe_contextmanager
import multiprocessing import multiprocessing
import os import os
import six import six
......
...@@ -98,10 +98,12 @@ def cuda_profiler(output_file, output_mode=None, config=None): ...@@ -98,10 +98,12 @@ def cuda_profiler(output_file, output_mode=None, config=None):
core.nvprof_init(output_file, output_mode, config_file) core.nvprof_init(output_file, output_mode, config_file)
# Enables profiler collection by the active CUDA profiling tool. # Enables profiler collection by the active CUDA profiling tool.
core.nvprof_start() core.nvprof_start()
yield try:
yield
# Disables profiler collection. # Disables profiler collection.
core.nvprof_stop() finally:
os.remove(config_file) core.nvprof_stop()
os.remove(config_file)
def reset_profiler(): def reset_profiler():
...@@ -345,5 +347,7 @@ def profiler(state, ...@@ -345,5 +347,7 @@ def profiler(state,
thread0::elementwise_add 8 1.96555 0.191884 0.518004 0.245693 0.196998 thread0::elementwise_add 8 1.96555 0.191884 0.518004 0.245693 0.196998
""" """
start_profiler(state, tracer_option) start_profiler(state, tracer_option)
yield try:
stop_profiler(sorted_key, profile_path) yield
finally:
stop_profiler(sorted_key, profile_path)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
class TestContextManagerRaiseException(unittest.TestCase):
    """Verify that Paddle context managers exit safely when the body raises.

    The guards under test wrap their ``yield`` in ``try/finally``, so an
    exception thrown inside the ``with`` block must still trigger cleanup.
    """

    def test_func1(self):
        # The guard's __exit__ must run even though the with-body raises;
        # the TypeError itself must propagate out to the caller.
        def raising_body():
            with fluid.dygraph.guard():
                print("raise error in context manager")
                raise TypeError("error")

        self.assertRaises(TypeError, raising_body)

    def test_func2(self):
        # Runs after test_func1 (unittest executes methods alphabetically).
        # If the guard in test_func1 exited cleanly, dygraph mode is off again.
        self.assertEqual(fluid.in_dygraph_mode(), False)
if __name__ == '__main__':
    # Run the test cases above when this file is executed directly.
    unittest.main()
...@@ -223,5 +223,7 @@ def guard(new_generator=None): ...@@ -223,5 +223,7 @@ def guard(new_generator=None):
new_generator = UniqueNameGenerator(new_generator.decode()) new_generator = UniqueNameGenerator(new_generator.decode())
old_generator, old_para_name_checker = switch(new_generator) old_generator, old_para_name_checker = switch(new_generator)
yield try:
switch(old_generator, old_para_name_checker) yield
finally:
switch(old_generator, old_para_name_checker)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册