Unverified commit 381492fc, authored by Leo Chen, committed by GitHub

add try finally, test=develop (#24243)

Parent 50330c6c
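The patch is mechanical: in each context manager touched below, cleanup that used to sit on the lines after a bare yield is moved into a finally block, so global state is restored even when the body of the with statement raises. A minimal standalone sketch of the pattern (hypothetical names, not Paddle code):

import contextlib

_mode = "default"


@contextlib.contextmanager
def mode_guard(new_mode):
    global _mode
    old_mode = _mode
    _mode = new_mode
    try:
        # An exception raised in the with-body is re-raised here, at the yield.
        yield
    finally:
        # Runs even if the body raised; without try/finally it would be skipped.
        _mode = old_mode


try:
    with mode_guard("train"):
        raise TypeError("error inside the with-body")
except TypeError:
    pass

assert _mode == "default"  # the guard cleaned up despite the exception

contextlib.contextmanager delivers the with-body's exception into the generator at the yield point, so any cleanup written after a bare yield is skipped when the body fails; wrapping the yield in try/finally is what guarantees the restore runs.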
@@ -427,8 +427,10 @@ class TrainingDecoder(object):
        if self._status != TrainingDecoder.BEFORE_DECODER:
            raise ValueError('decoder.block() can only be invoked once')
        self._status = TrainingDecoder.IN_DECODER
        with self._dynamic_rnn.block():
            yield
        self._status = TrainingDecoder.AFTER_DECODER
        self._state_cell._leave_decoder(self)
@@ -51,7 +51,9 @@ def program_desc_tracing_guard(enable):
    if tracer:
        original_val = tracer._enable_program_desc_tracing
        tracer._enable_program_desc_tracing = enable
    try:
        yield
    finally:
        if tracer:
            tracer._enable_program_desc_tracing = original_val
@@ -136,13 +138,15 @@ def disable_dygraph():
    _functional_dygraph_context_manager = None


@contextlib.contextmanager
@signature_safe_contextmanager
def _switch_tracer_mode_guard_(is_train=True):
    tracer = framework._dygraph_tracer()
    if tracer:
        mode = tracer._train_mode
        tracer._train_mode = is_train
        try:
            yield
        finally:
            tracer._train_mode = mode
    else:
        yield
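Besides the try/finally, several hunks swap @contextlib.contextmanager for Paddle's signature_safe_contextmanager (the removed and added decorator lines appear back to back above). As a rough sketch of what such a signature-preserving wrapper can look like, assuming the third-party decorator package is available (this only approximates paddle/fluid/wrapped_decorator.py, which is not part of this diff):

import contextlib

import decorator


def wrap_decorator(decorator_func):
    # Turn an ordinary decorator into one that preserves the wrapped
    # function's signature (name, arguments, defaults).
    @decorator.decorator
    def __impl__(func, *args, **kwargs):
        wrapped_func = decorator_func(func)
        return wrapped_func(*args, **kwargs)

    return __impl__


# A drop-in replacement for contextlib.contextmanager that keeps the
# decorated generator's original signature visible to introspection tools.
signature_safe_contextmanager = wrap_decorator(contextlib.contextmanager)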
@@ -97,7 +97,9 @@ def scope_guard(scope):
    """
    ex = _switch_scope(scope)
    try:
        yield
    finally:
        _switch_scope(ex)
@@ -487,7 +487,9 @@ def name_scope(prefix=None):
    assert prefix, "namescope prefix can not be empty."
    global _name_scope
    _name_scope = _name_scope.child(prefix)
    try:
        yield
    finally:
        _name_scope = _name_scope.parent()
@@ -3984,13 +3986,15 @@ class Program(object):
        """
        return self.__op_role_var

    @contextlib.contextmanager
    @signature_safe_contextmanager
    def _backward_role_guard(self):
        tmp_role = self._current_role
        OpRole = core.op_proto_and_checker_maker.OpRole
        self._current_role = OpRole.Backward
        try:
            yield
        finally:
            self._current_role = tmp_role

    @signature_safe_contextmanager
@@ -4020,7 +4024,9 @@ class Program(object):
            var.name if isinstance(var, Variable) else var
            for var in param_and_grads
        ]
        try:
            yield
        finally:
            self.__op_role_var = tmp_var
            self._current_role = tmp_role
@@ -4055,7 +4061,9 @@ class Program(object):
        self._current_role = int(OpRole.LRSched) | int(OpRole.Optimize)
        # TODO(typhoonzero): how to set target learning rate var
        self.__op_role_var = []
        try:
            yield
        finally:
            self.__op_role_var = tmp_var
            self._current_role = tmp_role
@@ -5310,7 +5318,9 @@ def program_guard(main_program, startup_program=None):
        check_type(startup_program, 'startup_program', Program,
                   'fluid.program_guard')
        startup_program = switch_startup_program(startup_program)
    try:
        yield
    finally:
        switch_main_program(main_program)
        if startup_program is not None:
            switch_startup_program(startup_program)
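For this hunk in particular, the guarantee the patch adds can be observed from user code: even if the body of fluid.program_guard raises, the previous default programs are switched back. A hedged usage sketch (assumes a fluid-era paddle release is installed):

import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
default_before = fluid.default_main_program()

try:
    with fluid.program_guard(main_prog, startup_prog):
        raise RuntimeError("simulated failure inside the guard")
except RuntimeError:
    pass

# With the try/finally in place, the original default programs are restored
# even though the with-body raised.
assert fluid.default_main_program() is default_before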
@@ -5343,8 +5353,9 @@ def _dygraph_guard(tracer):
    _dygraph_tracer_ = tracer
    core._switch_tracer(tracer)
    try:
        yield
    finally:
        core._switch_tracer(tmp_trace)
        _dygraph_tracer_ = tmp_trace
@@ -5355,8 +5366,9 @@ def _dygraph_place_guard(place):
    tmp_place = _dygraph_current_expected_place_
    _dygraph_current_expected_place_ = place
    try:
        yield
    finally:
        _dygraph_current_expected_place_ = tmp_place
@@ -5437,7 +5449,9 @@ def device_guard(device=None):
            "The Attr(device) should be 'cpu' or 'gpu', and it can also be empty string or None "
            "when there is no need to specify device. But received %s" % device)
    pre_device = switch_device(device)
    try:
        yield
    finally:
        switch_device(pre_device)
@@ -18,7 +18,6 @@ from . import framework
from . import core
from .framework import in_dygraph_mode
import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
from .core import VarDesc
from . import unique_name
from .data_feeder import check_variable_and_dtype, check_type, check_dtype
@@ -33,6 +33,7 @@ from paddle.fluid.executor import Executor, global_scope
from paddle.fluid.evaluator import Evaluator
from paddle.fluid.framework import Program, Parameter, default_main_program, default_startup_program, Variable, \
    program_guard
from .wrapped_decorator import signature_safe_contextmanager
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.log_helper import get_logger
from . import reader
@@ -183,7 +184,7 @@ def _clone_var_in_block_(block, var):
            persistable=True)


@contextlib.contextmanager
@signature_safe_contextmanager
def _load_program_scope(main=None, startup=None, scope=None):
    prog = main if main else paddle.fluid.Program()
    startup_prog = startup if startup else paddle.fluid.Program()
@@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import print_function
from ..wrapped_decorator import signature_safe_contextmanager
import multiprocessing
import os
import six
@@ -98,8 +98,10 @@ def cuda_profiler(output_file, output_mode=None, config=None):
    core.nvprof_init(output_file, output_mode, config_file)
    # Enables profiler collection by the active CUDA profiling tool.
    core.nvprof_start()
    try:
        yield
    # Disables profiler collection.
    finally:
        core.nvprof_stop()
        os.remove(config_file)
@@ -345,5 +347,7 @@ def profiler(state,
        thread0::elementwise_add 8 1.96555 0.191884 0.518004 0.245693 0.196998
    """
    start_profiler(state, tracer_option)
    try:
        yield
    finally:
        stop_profiler(sorted_key, profile_path)
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest

class TestContextManagerRaiseException(unittest.TestCase):
    # When an exception is raised inside a 'with' context, we should safely exit the context.
    def test_func1(self):
        def foo():
            with fluid.dygraph.guard():
                print("raise error in context manager")
                raise TypeError("error")

        self.assertRaises(TypeError, foo)

    def test_func2(self):
        # After test_func1 has run, if fluid.dygraph.guard() in test_func1 exited safely,
        # fluid.in_dygraph_mode() should be False.
        self.assertEqual(fluid.in_dygraph_mode(), False)


if __name__ == '__main__':
    unittest.main()
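Note that test_func2 relies on unittest running test methods in alphabetical order within a TestCase, so test_func1 has already raised inside (and exited) the dygraph guard by the time test_func2 checks that fluid.in_dygraph_mode() is back to False; that check is exactly what the try/finally added to _dygraph_guard and _dygraph_place_guard above guarantees.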
@@ -223,5 +223,7 @@ def guard(new_generator=None):
        new_generator = UniqueNameGenerator(new_generator.decode())
    old_generator, old_para_name_checker = switch(new_generator)
    try:
        yield
    finally:
        switch(old_generator, old_para_name_checker)