未验证 提交 381492fc 编写于 作者: L Leo Chen 提交者: GitHub

add try finally, test=develop (#24243)

上级 50330c6c
......@@ -427,8 +427,10 @@ class TrainingDecoder(object):
if self._status != TrainingDecoder.BEFORE_DECODER:
raise ValueError('decoder.block() can only be invoked once')
self._status = TrainingDecoder.IN_DECODER
with self._dynamic_rnn.block():
yield
self._status = TrainingDecoder.AFTER_DECODER
self._state_cell._leave_decoder(self)
......
......@@ -51,9 +51,11 @@ def program_desc_tracing_guard(enable):
if tracer:
original_val = tracer._enable_program_desc_tracing
tracer._enable_program_desc_tracing = enable
yield
if tracer:
tracer._enable_program_desc_tracing = original_val
try:
yield
finally:
if tracer:
tracer._enable_program_desc_tracing = original_val
_functional_dygraph_context_manager = None
......@@ -136,14 +138,16 @@ def disable_dygraph():
_functional_dygraph_context_manager = None
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    """Temporarily switch the dygraph tracer's train/eval mode.

    Args:
        is_train (bool): mode to use inside the ``with`` block.
            Defaults to True.

    The saved mode is restored on exit via ``try/finally``, so an
    exception raised inside the ``with`` body cannot leave the tracer
    stuck in the wrong mode. When no tracer is active the guard is a
    no-op.
    """
    tracer = framework._dygraph_tracer()
    if tracer:
        mode = tracer._train_mode
        tracer._train_mode = is_train
        try:
            yield
        finally:
            # Restore the previous mode even if the body raised.
            tracer._train_mode = mode
    else:
        # No active tracer: nothing to switch, nothing to restore.
        yield
......
......@@ -97,8 +97,10 @@ def scope_guard(scope):
"""
ex = _switch_scope(scope)
yield
_switch_scope(ex)
try:
yield
finally:
_switch_scope(ex)
def as_numpy(tensor):
......
......@@ -487,8 +487,10 @@ def name_scope(prefix=None):
assert prefix, "namescope prefix can not be empty."
global _name_scope
_name_scope = _name_scope.child(prefix)
yield
_name_scope = _name_scope.parent()
try:
yield
finally:
_name_scope = _name_scope.parent()
def _full_name_scope():
......@@ -3984,14 +3986,16 @@ class Program(object):
"""
return self.__op_role_var
@signature_safe_contextmanager
def _backward_role_guard(self):
    """Set the program's current op role to ``OpRole.Backward`` for
    ops created inside the ``with`` block.

    The previous role is restored on exit via ``try/finally``, so an
    exception raised inside the ``with`` body cannot leave the program
    with a stale role.
    """
    tmp_role = self._current_role

    OpRole = core.op_proto_and_checker_maker.OpRole
    self._current_role = OpRole.Backward
    try:
        yield
    finally:
        # Restore the saved role even if the body raised.
        self._current_role = tmp_role
@signature_safe_contextmanager
def _optimized_guard(self, param_and_grads):
......@@ -4020,9 +4024,11 @@ class Program(object):
var.name if isinstance(var, Variable) else var
for var in param_and_grads
]
yield
self.__op_role_var = tmp_var
self._current_role = tmp_role
try:
yield
finally:
self.__op_role_var = tmp_var
self._current_role = tmp_role
@signature_safe_contextmanager
def _lr_schedule_guard(self, is_with_opt=False):
......@@ -4055,9 +4061,11 @@ class Program(object):
self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
# TODO(typhoonzero): how to set target learning rate var
self.__op_role_var = []
yield
self.__op_role_var = tmp_var
self._current_role = tmp_role
try:
yield
finally:
self.__op_role_var = tmp_var
self._current_role = tmp_role
def __str__(self):
"""
......@@ -5310,10 +5318,12 @@ def program_guard(main_program, startup_program=None):
check_type(startup_program, 'startup_program', Program,
'fluid.program_guard')
startup_program = switch_startup_program(startup_program)
yield
switch_main_program(main_program)
if startup_program is not None:
switch_startup_program(startup_program)
try:
yield
finally:
switch_main_program(main_program)
if startup_program is not None:
switch_startup_program(startup_program)
def _get_var(name, program=None):
......@@ -5343,10 +5353,11 @@ def _dygraph_guard(tracer):
_dygraph_tracer_ = tracer
core._switch_tracer(tracer)
yield
core._switch_tracer(tmp_trace)
_dygraph_tracer_ = tmp_trace
try:
yield
finally:
core._switch_tracer(tmp_trace)
_dygraph_tracer_ = tmp_trace
@signature_safe_contextmanager
......@@ -5355,9 +5366,10 @@ def _dygraph_place_guard(place):
tmp_place = _dygraph_current_expected_place_
_dygraph_current_expected_place_ = place
yield
_dygraph_current_expected_place_ = tmp_place
try:
yield
finally:
_dygraph_current_expected_place_ = tmp_place
def load_op_library(lib_filename):
......@@ -5437,8 +5449,10 @@ def device_guard(device=None):
"The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None "
"when there is no need to specify device. But received %s" % device)
pre_device = switch_device(device)
yield
switch_device(pre_device)
try:
yield
finally:
switch_device(pre_device)
def set_flags(flags):
......
......@@ -18,7 +18,6 @@ from . import framework
from . import core
from .framework import in_dygraph_mode
import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
from .core import VarDesc
from . import unique_name
from .data_feeder import check_variable_and_dtype, check_type, check_dtype
......
......@@ -33,6 +33,7 @@ from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \
program_guard
from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger
from . import reader
......@@ -183,7 +184,7 @@ def _clone_var_in_block_(block, var):
persistable=True)
@contextlib.contextmanager
@signature_safe_contextmanager
def _load_program_scope(main=None, startup=None, scope=None):
prog = main if main else paddle.fluid.Program()
startup_prog = startup if startup else paddle.fluid.Program()
......
......@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import print_function
from ..wrapped_decorator import signature_safe_contextmanager
import multiprocessing
import os
import six
......
......@@ -98,10 +98,12 @@ def cuda_profiler(output_file, output_mode=None, config=None):
core.nvprof_init(output_file, output_mode, config_file)
# Enables profiler collection by the active CUDA profiling tool.
core.nvprof_start()
yield
try:
yield
# Disables profiler collection.
core.nvprof_stop()
os.remove(config_file)
finally:
core.nvprof_stop()
os.remove(config_file)
def reset_profiler():
......@@ -345,5 +347,7 @@ def profiler(state,
thread0::elementwise_add 8 1.96555 0.191884 0.518004 0.245693 0.196998
"""
start_profiler(state, tracer_option)
yield
stop_profiler(sorted_key, profile_path)
try:
yield
finally:
stop_profiler(sorted_key, profile_path)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
class TestContextManagerRaiseException(unittest.TestCase):
    # Regression test: an exception raised inside a `with` block must not
    # prevent the context manager from restoring global state on exit.

    def test_func1(self):
        with self.assertRaises(TypeError):
            with fluid.dygraph.guard():
                print("raise error in context manager")
                raise TypeError("error")

    def test_func2(self):
        # If fluid.dygraph.guard() in test_func1 exited cleanly despite
        # the exception, dygraph mode must have been switched off again.
        self.assertEqual(fluid.in_dygraph_mode(), False)
# Run the regression tests when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
......@@ -223,5 +223,7 @@ def guard(new_generator=None):
new_generator = UniqueNameGenerator(new_generator.decode())
old_generator, old_para_name_checker = switch(new_generator)
yield
switch(old_generator, old_para_name_checker)
try:
yield
finally:
switch(old_generator, old_para_name_checker)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册