From ed45ecc626c07d50c0a1128b92797a4247d69aa9 Mon Sep 17 00:00:00 2001
From: qizhaoaoe <10208099+qizhaoaoe@users.noreply.github.com>
Date: Tue, 25 Apr 2023 14:25:03 +0800
Subject: [PATCH] [fluid clean] remove Print. (#51778)

* fluid clean: remove print/switch from fluid to static
* remove Switch in static.__init__
* fix conflicts.
* replace Switch by case.
* fix piecewise_lr decay.
* fix typo
* fix conflicts.
* fix lr dtype
* keep Switch in paddle.static.nn.control_flow and fix piecewise_lr.
* fix conflicts.
* keep Switch in the fluid.
* fix Switch doc
* fix example in Switch doc
* fix Switch doc.
* fix static/__init__.
---
 python/paddle/fluid/layers/control_flow.py    | 100 +-----------------
 .../fluid/layers/learning_rate_scheduler.py   |  37 ++++---
 python/paddle/fluid/optimizer.py              |  69 ++++++------
 .../dist_fleet_heter_pipeline_ctr.py          |   2 +-
 .../fluid/tests/unittests/test_switch.py      |  34 +++---
 .../paddle/jit/dy2static/convert_operators.py |   4 +-
 python/paddle/nn/layer/rnn.py                 |   2 -
 python/paddle/reader/decorator.py             |   3 +-
 python/paddle/static/__init__.py              |   2 +-
 python/paddle/static/nn/control_flow.py       | 100 +++++++++++++++++-
 .../test_multi_precision_fp16_train.py        |  10 +-
 test/ipu/test_print_op_ipu.py                 |   2 +-
 12 files changed, 185 insertions(+), 180 deletions(-)

diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py
index a0ad94df79d..6d402df9f3c 100755
--- a/python/paddle/fluid/layers/control_flow.py
+++ b/python/paddle/fluid/layers/control_flow.py
@@ -50,7 +50,6 @@ from paddle import _C_ops, _legacy_C_ops
 __all__ = [
     'Switch',
     'StaticRNN',
-    'Print',
     'while_loop',
 ]
 
@@ -141,104 +140,6 @@ def select_input(inputs, mask):
     return out
 
 
-@static_only
-def Print(
-    input,
-    first_n=-1,
-    message=None,
-    summarize=20,
-    print_tensor_name=True,
-    print_tensor_type=True,
-    print_tensor_shape=True,
-    print_tensor_layout=True,
-    print_tensor_lod=True,
-    print_phase='both',
-):
-    '''
-    :api_attr: Static Graph
-
-    **Print operator**
-
-    This creates a print op that will print when a tensor is accessed.
-
-    Wraps the tensor passed in so that whenever that a tensor is accessed,
-    the message `message` is printed, along with the current value of the
-    tensor `t`.
-
-    Args:
-        input (Tensor): A Tensor to print.
-        first_n (int, optional): Only log `first_n` number of times. Default: -1.
-        message (str, optional): A string message to print as a prefix. Default: None.
-        summarize (int, optional): Number of elements in the tensor to be print. If
-            it's value is -1, then all elements in the tensor will be print.
-        print_tensor_name (bool, optional): Print the tensor name. Default: True.
-        print_tensor_type (bool, optional): Print the tensor type. Defaultt: True.
-        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
-        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
-        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
-        print_phase (str, optional): Which phase to displace, including 'forward',
-            'backward' and 'both'. Default: 'both'. If set to 'backward', will
-            only print the gradients of input tensor; If set to 'both', will
-            both print the input tensor itself and the gradients of input tensor.
-
-    Returns:
-        Tensor: Output tensor.
-
-    NOTES:
-        The input and output are two different Tensor, and in the
-        following process, you should use the output Tensor but not the input,
-        otherwise, the print layer doesn't have backward.
-
-    Examples:
-        .. code-block:: python
-
-            import paddle
-
-            paddle.enable_static()
-
-            x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
-            out = paddle.static.Print(x, message="The content of input layer:")
-
-            main_program = paddle.static.default_main_program()
-            exe = paddle.static.Executor(place=paddle.CPUPlace())
-            res = exe.run(main_program, fetch_list=[out])
-            # Variable: fill_constant_1.tmp_0
-            #  - message: The content of input layer:
-            #  - lod: {}
-            #  - place: CPUPlace
-            #  - shape: [2, 3]
-            #  - layout: NCHW
-            #  - dtype: long
-            #  - data: [3 3 3 3 3 3]
-    '''
-    check_variable_and_dtype(
-        input,
-        'input',
-        ['float32', 'float64', 'int32', 'int64', 'bool'],
-        'fluid.layers.Print',
-    )
-
-    helper = LayerHelper('print' + "_" + input.name, **locals())
-    output = helper.create_variable_for_type_inference(input.dtype)
-    helper.append_op(
-        type='print',
-        inputs={'In': input},
-        outputs={'Out': output},
-        attrs={
-            'first_n': first_n,
-            'summarize': summarize,
-            'message': message or "",
-            'print_tensor_name': print_tensor_name,
-            'print_tensor_type': print_tensor_type,
-            'print_tensor_shape': print_tensor_shape,
-            'print_tensor_layout': print_tensor_layout,
-            'print_tensor_lod': print_tensor_lod,
-            'print_phase': print_phase.upper(),
-        },
-    )
-    return output
-
-
 # (TODO: Mine) There exists dependency. It will be removed later.
 class BlockGuard:
     """
@@ -1512,6 +1413,7 @@ def expand_undefined_var(nest1, nest2, names):
     return nest1_out, nest2_out
 
 
+# TODO: It will be deleted later.
 class Switch:
     """
     :api_attr: Static Graph
diff --git a/python/paddle/fluid/layers/learning_rate_scheduler.py b/python/paddle/fluid/layers/learning_rate_scheduler.py
index 6d86fa9448c..65fe1d1e77b 100644
--- a/python/paddle/fluid/layers/learning_rate_scheduler.py
+++ b/python/paddle/fluid/layers/learning_rate_scheduler.py
@@ -359,9 +359,11 @@ def polynomial_decay(
                 shape=[1], dtype='float32', value=1.0
             )
 
-            with control_flow.Switch() as switch:
-                with switch.case(global_step == zero_var):
-                    paddle.assign(one_var, output=div_res)
+            div_val = paddle.static.nn.cond(
+                global_step == zero_var, lambda: one_var, lambda: div_res
+            )
+            paddle.assign(div_val, output=div_res)
+
             decay_steps = decay_steps * div_res
         else:
             decay_steps_var = paddle.tensor.fill_constant(
@@ -432,7 +434,7 @@ def piecewise_decay(boundaries, values):
             persistable=True,
             name="learning_rate",
         )
-
+        # TODO: fluid.layers.control_flow.Switch should be replaced by paddle.static.nn.case (or cond) if possible
         with control_flow.Switch() as switch:
             for i in range(len(boundaries)):
                 boundary_val = paddle.tensor.fill_constant(
@@ -455,7 +457,6 @@ def piecewise_decay(boundaries, values):
                     value=float(values[len(values) - 1]),
                     out=lr,
                 )
-
     return lr
 
 
@@ -589,17 +590,19 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
     )
 
     global_step = _decay_step_counter()
-
-    with control_flow.Switch() as switch:
-        with switch.case(global_step < warmup_steps):
-            decayed_lr = start_lr + linear_step * (
-                global_step / float(warmup_steps)
+    if not isinstance(learning_rate, Variable):
+        learning_rate = paddle.tensor.fill_constant(
+            shape=[1], dtype=dtype, value=float(learning_rate)
+        )
+    lr_val = paddle.static.nn.case(
+        pred_fn_pairs=[
+            (
+                global_step < warmup_steps,
+                lambda: start_lr
+                + linear_step * (global_step / float(warmup_steps)),
             )
-            paddle.assign(decayed_lr, lr)
-        with switch.default():
-            if not isinstance(learning_rate, Variable):
-                learning_rate = paddle.tensor.fill_constant(
-                    shape=[1], dtype=dtype, value=float(learning_rate)
-                )
-            paddle.assign(learning_rate, lr)
+        ],
+        default=lambda: learning_rate,
+    )
+    paddle.assign(lr_val, lr)
 
     return lr
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index 1dca88d61e3..f6bd5dbd37c 100755
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -4373,12 +4373,12 @@ class ExponentialMovingAverage:
                 ema = block._clone_variable(self._ema_vars[param.name])
                 paddle.assign(param, output=tmp)
                 # bias correction
-                with layers.control_flow.Switch() as switch:
-                    with switch.case(global_step > 0):
-                        paddle.assign(ema / (1.0 - decay_pow), output=param)
-                    with switch.default():
-                        paddle.assign(ema, output=param)
-
+                param_val = paddle.static.nn.cond(
+                    global_step > 0,
+                    lambda: ema / (1.0 - decay_pow),
+                    lambda: ema,
+                )
+                paddle.assign(param_val, output=param)
         self.restore_program = Program()
         block = self.restore_program.global_block()
         with program_guard(main_program=self.restore_program):
@@ -4399,13 +4399,12 @@ class ExponentialMovingAverage:
         if self._thres_steps is not None:
             decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
-            with layers.control_flow.Switch() as switch:
-                with switch.case(decay_t < self._decay):
-                    paddle.assign(decay_t, decay_var)
-                with switch.default():
-                    paddle.assign(
-                        np.array([self._decay], dtype=np.float32), decay_var
-                    )
+            decay_val = paddle.static.nn.cond(
+                decay_t < self._decay,
+                lambda: decay_t,
+                lambda: np.array([self._decay], dtype=np.float32),
+            )
+            paddle.assign(decay_val, decay_var)
         return decay_var
 
     def _get_decay_pow(self, block):
@@ -7408,26 +7407,30 @@ class LookaheadOptimizer:
             )
 
             mod = paddle.remainder(step, k)
-            with layers.control_flow.Switch() as switch:
-                with switch.case(step == one_var):
-                    for param_name in params:
-                        fast_var = main_block.var(param_name)
-                        slow_var = param_to_slow[param_name]
-                        paddle.assign(fast_var, output=slow_var)
-                with switch.case(mod == zero_var):
-                    for param_name in params:
-                        fast_var = main_block.var(param_name)
-                        slow_var = param_to_slow[param_name]
-                        tmp_var = paddle.add(
-                            paddle.multiply(fast_var, alpha),
-                            paddle.multiply(
-                                slow_var, paddle.subtract(one_var, alpha)
-                            ),
-                        )
-                        paddle.assign(tmp_var, output=slow_var)
-                        paddle.assign(tmp_var, output=fast_var)
-                with switch.default():
-                    pass
+            for param_name in params:
+                fast_var = main_block.var(param_name)
+                slow_var = param_to_slow[param_name]
+                tmp_var = paddle.add(
+                    paddle.multiply(fast_var, alpha),
+                    paddle.multiply(slow_var, paddle.subtract(one_var, alpha)),
+                )
+                slow_val = paddle.static.nn.case(
+                    [
+                        (step == one_var, lambda: fast_var),
+                        (mod == zero_var, lambda: tmp_var),
+                    ],
+                    default=lambda: slow_var,
+                )
+                paddle.assign(slow_val, slow_var)
+
+                fast_val = paddle.static.nn.case(
+                    [
+                        (mod == zero_var, lambda: tmp_var),
+                    ],
+                    default=lambda: fast_var,
+                )
+                paddle.assign(fast_val, fast_var)
+
         return mini_out
diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py
index 5dd23f13525..a5010e275aa 100644
--- a/python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py
+++ b/python/paddle/fluid/tests/unittests/dist_fleet_heter_pipeline_ctr.py
@@ -126,7 +126,7 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
             input=predict, label=label, reduction='none', use_softmax=False
         )
         avg_cost = paddle.mean(x=cost)
-        fluid.layers.Print(avg_cost, message="avg_cost")
+        paddle.static.Print(avg_cost, message="avg_cost")
 
         self.feeds = datas
         self.train_file_path = ["fake1", "fake2"]
diff --git a/python/paddle/fluid/tests/unittests/test_switch.py b/python/paddle/fluid/tests/unittests/test_switch.py
index d5d118867b1..428e5537f8b 100644
--- a/python/paddle/fluid/tests/unittests/test_switch.py
+++ b/python/paddle/fluid/tests/unittests/test_switch.py
@@ -15,7 +15,7 @@
 import unittest
 
 import paddle
-from paddle.fluid import core, framework, layers
+from paddle.fluid import core, framework
 from paddle.fluid.executor import Executor
 from paddle.fluid.framework import default_startup_program
 
@@ -40,15 +40,15 @@ class TestSwitch(unittest.TestCase):
             shape=[1], value=-1.0, dtype='float32', persistable=True
         )
 
-        with layers.Switch() as switch:
-            with switch.case(paddle.less_than(x, zero_var)):
-                paddle.assign(zero_var, result)
-            with switch.case(paddle.less_than(x, one_var)):
-                paddle.assign(one_var, result)
-            with switch.case(paddle.less_than(x, two_var)):
-                paddle.assign(two_var, result)
-            with switch.default():
-                paddle.assign(three_var, result)
+        res = paddle.static.nn.case(
+            pred_fn_pairs=[
+                (paddle.less_than(x, zero_var), lambda: zero_var),
+                (paddle.less_than(x, one_var), lambda: one_var),
+                (paddle.less_than(x, two_var), lambda: two_var),
+            ],
+            default=lambda: three_var,
+        )
+        paddle.assign(res, result)
 
         cpu = core.CPUPlace()
         exe = Executor(cpu)
@@ -85,17 +85,19 @@ class TestSwitchCaseError(unittest.TestCase):
 
         # 1. The type of 'condition' in case must be Variable.
         def test_condition_type():
-            with layers.Switch() as switch:
-                with switch.case(1):
-                    paddle.assign(zero_var, result)
+            res = paddle.static.nn.case(
+                [(1, lambda: zero_var)], default=lambda: result
+            )
+            paddle.assign(res, result)
 
         self.assertRaises(TypeError, test_condition_type)
 
         # 2. The dtype of 'condition' in case must be 'bool'.
         def test_condition_dtype():
-            with layers.Switch() as switch:
-                with switch.case(cond):
-                    paddle.assign(zero_var, result)
+            res = paddle.static.nn.case(
+                [(cond, lambda: zero_var)], default=lambda: result
+            )
+            paddle.assign(res, result)
 
         self.assertRaises(TypeError, test_condition_dtype)
diff --git a/python/paddle/jit/dy2static/convert_operators.py b/python/paddle/jit/dy2static/convert_operators.py
index ad9abcc9849..52d6b7cb854 100644
--- a/python/paddle/jit/dy2static/convert_operators.py
+++ b/python/paddle/jit/dy2static/convert_operators.py
@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.base import (
     in_declarative_mode,
 )
 from paddle.fluid.framework import Variable, core
-from paddle.fluid.layers import Print, control_flow
+from paddle.fluid.layers import control_flow
 from paddle.fluid.layers.control_flow import while_loop
 
 from .utils import (
@@ -749,7 +749,7 @@ def convert_print(*objects, sep=' ', end='\n', file=None, flush=False):
     """
     for obj in objects:
         if isinstance(obj, Variable):
-            Print(obj)
+            paddle.static.Print(obj)
     print(*objects, sep=sep, end=end, file=file, flush=flush)
diff --git a/python/paddle/nn/layer/rnn.py b/python/paddle/nn/layer/rnn.py
index cc8ab648b88..2a0c9157a7a 100644
--- a/python/paddle/nn/layer/rnn.py
+++ b/python/paddle/nn/layer/rnn.py
@@ -299,13 +299,11 @@ def _rnn_static_graph(
         pre_state = paddle.utils.map_structure(
             lambda x: paddle.tensor.array_read(x, start_i), init_array
         )
-        # pre_state = paddle.fluid.layers.Print( pre_state, message="pre")
         outputs, new_states = cell(step_in, pre_state, **kwargs)
         assert isinstance(outputs, paddle.fluid.framework.Variable)
         paddle.utils.assert_same_structure(new_states, pre_state)
         if sequence_length:
             step_mask = paddle.unsqueeze(mask[start_i], 1)
-            # paddle.fluid.layers.Print( step_mask, message="mask")
             # new_states = map_structure(
             #     partial(_maybe_copy, step_mask=step_mask),
             #     pre_state, new_states
diff --git a/python/paddle/reader/decorator.py b/python/paddle/reader/decorator.py
index e5c47ebdb34..bd40c4553e8 100644
--- a/python/paddle/reader/decorator.py
+++ b/python/paddle/reader/decorator.py
@@ -553,8 +553,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
             with fluid.program_guard(fluid.Program(), fluid.Program()):
                 place = fluid.CPUPlace()
                 # the 1st 2 is batch size
+
                 image = paddle.static.data(name='image', dtype='int64', shape=[2, 1, 2])
-                fluid.layers.Print(image)
+                paddle.static.Print(image)
                 # print detailed tensor info of image variable
 
                 reader = fluid.io.PyReader(feed_list=[image], capacity=2)
diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py
index f63971b966a..084579a58e5 100644
--- a/python/paddle/static/__init__.py
+++ b/python/paddle/static/__init__.py
@@ -65,7 +65,7 @@ from ..fluid.framework import Operator  # noqa: F401
 from ..fluid.framework import Parameter  # noqa: F401
 from ..fluid.framework import ipu_shard_guard  # noqa: F401
 from ..fluid.framework import set_ipu_shard  # noqa: F401
-from ..fluid.layers.control_flow import Print  # noqa: F401
+from .nn.control_flow import Print  # noqa: F401
 from ..fluid.param_attr import WeightNormParamAttr  # noqa: F401
 from ..fluid.optimizer import Optimizer  # noqa: F401
 from ..fluid.optimizer import Adam  # noqa: F401
diff --git a/python/paddle/static/nn/control_flow.py b/python/paddle/static/nn/control_flow.py
index c5d52654775..bc5f1d2d5d6 100644
--- a/python/paddle/static/nn/control_flow.py
+++ b/python/paddle/static/nn/control_flow.py
@@ -24,7 +24,7 @@ from paddle.common_ops_import import (
     convert_dtype,
 )
 from paddle.fluid import core
-from paddle.fluid.framework import Operator, Program, Variable
+from paddle.fluid.framework import Operator, Program, Variable, static_only
 
 # Temporary solution, it will be deleted later
 from paddle.fluid.layers.control_flow import ConditionalBlock, select_input
@@ -1329,3 +1329,101 @@ def change_none_to_undefinedvar(nest1, nest2):
     nest1_out = pack_sequence_as(nest1, list(map(map_fn, flatten(nest1))))
     nest2_out = pack_sequence_as(nest2, list(map(map_fn, flatten(nest2))))
     return nest1_out, nest2_out
+
+
+@static_only
+def Print(
+    input,
+    first_n=-1,
+    message=None,
+    summarize=20,
+    print_tensor_name=True,
+    print_tensor_type=True,
+    print_tensor_shape=True,
+    print_tensor_layout=True,
+    print_tensor_lod=True,
+    print_phase='both',
+):
+    '''
+    :api_attr: Static Graph
+
+    **Print operator**
+
+    This creates a print op that will print when the tensor is accessed.
+
+    Wraps the tensor passed in so that whenever the tensor is accessed,
+    the message `message` is printed, along with the current value of the
+    tensor `input`.
+
+    Args:
+        input (Tensor): A Tensor to print.
+        first_n (int, optional): Only log `first_n` number of times. Default: -1.
+        message (str, optional): A string message to print as a prefix. Default: None.
+        summarize (int, optional): Number of elements in the tensor to be printed. If
+            its value is -1, then all elements in the tensor will be printed.
+        print_tensor_name (bool, optional): Print the tensor name. Default: True.
+        print_tensor_type (bool, optional): Print the tensor type. Default: True.
+        print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
+        print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
+        print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
+        print_phase (str, optional): Which phase to display, including 'forward',
+            'backward' and 'both'. Default: 'both'. If set to 'backward', will
+            only print the gradients of the input tensor; if set to 'both', will
+            print both the input tensor itself and the gradients of the input tensor.
+
+    Returns:
+        Tensor: Output tensor.
+
+    NOTES:
+        The input and output are two different Tensors. In subsequent code,
+        use the output Tensor rather than the input; otherwise, the print
+        layer will not have a backward pass.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+
+            paddle.enable_static()
+
+            x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
+            out = paddle.static.Print(x, message="The content of input layer:")
+
+            main_program = paddle.static.default_main_program()
+            exe = paddle.static.Executor(place=paddle.CPUPlace())
+            res = exe.run(main_program, fetch_list=[out])
+            # Variable: fill_constant_1.tmp_0
+            #  - message: The content of input layer:
+            #  - lod: {}
+            #  - place: CPUPlace
+            #  - shape: [2, 3]
+            #  - layout: NCHW
+            #  - dtype: long
+            #  - data: [3 3 3 3 3 3]
+    '''
+    check_variable_and_dtype(
+        input,
+        'input',
+        ['float32', 'float64', 'int32', 'int64', 'bool'],
+        'paddle.static.Print',
+    )
+
+    helper = LayerHelper('print' + "_" + input.name, **locals())
+    output = helper.create_variable_for_type_inference(input.dtype)
+    helper.append_op(
+        type='print',
+        inputs={'In': input},
+        outputs={'Out': output},
+        attrs={
+            'first_n': first_n,
+            'summarize': summarize,
+            'message': message or "",
+            'print_tensor_name': print_tensor_name,
+            'print_tensor_type': print_tensor_type,
+            'print_tensor_shape': print_tensor_shape,
+            'print_tensor_layout': print_tensor_layout,
+            'print_tensor_lod': print_tensor_lod,
+            'print_phase': print_phase.upper(),
+        },
+    )
+    return output
diff --git a/test/contrib/test_multi_precision_fp16_train.py b/test/contrib/test_multi_precision_fp16_train.py
index a364d2161eb..218cfcd542d 100644
--- a/test/contrib/test_multi_precision_fp16_train.py
+++ b/test/contrib/test_multi_precision_fp16_train.py
@@ -295,12 +295,10 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
             one_var = paddle.tensor.fill_constant(
                 shape=[1], dtype='int64', value=1
             )
-            with fluid.layers.control_flow.Switch() as switch:
-                with switch.case(label != zero_var):
-                    paddle.assign(zero_var, output=label)
-                with switch.default():
-                    paddle.assign(one_var, output=label)
-
+            label_val = paddle.static.nn.cond(
+                label != zero_var, lambda: zero_var, lambda: one_var
+            )
+            paddle.assign(label_val, output=label)
             net = resnet_cifar10(image)
             logits = paddle.static.nn.fc(
                 x=net, size=10, activation="softmax"
diff --git a/test/ipu/test_print_op_ipu.py b/test/ipu/test_print_op_ipu.py
index 358866ee0a8..10449cd48ae 100644
--- a/test/ipu/test_print_op_ipu.py
+++ b/test/ipu/test_print_op_ipu.py
@@ -55,7 +55,7 @@ class TestBase(IPUOpTest):
             dtype=self.feed_dtype[0],
         )
         out = paddle.static.nn.conv2d(x, num_filters=3, filter_size=3)
-        out = paddle.fluid.layers.Print(out, **self.attrs)
+        out = paddle.static.Print(out, **self.attrs)
 
         if self.is_training:
             loss = paddle.mean(out)
--
GitLab
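
Migration note: for code that still uses the removed fluid.layers.Switch, the
multi-branch replacement pattern this patch applies throughout is
paddle.static.nn.case, whose result is then assigned back with paddle.assign.
A minimal, self-contained sketch of that mapping; the variable names (x, zero,
one) and constants are illustrative only, not taken from the patch:

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.full(shape=[1], fill_value=0.3, dtype='float32')
        zero = paddle.full(shape=[1], fill_value=0.0, dtype='float32')
        one = paddle.full(shape=[1], fill_value=1.0, dtype='float32')

        # Old style:
        #     with fluid.layers.Switch() as switch:
        #         with switch.case(x < zero):
        #             paddle.assign(zero, result)
        #         with switch.default():
        #             paddle.assign(one, result)
        # New style: pred_fn_pairs are checked in order; the first true
        # predicate selects its branch, and `default` runs when none match.
        result = paddle.static.nn.case(
            pred_fn_pairs=[(x < zero, lambda: zero)],
            default=lambda: one,
        )

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    (res,) = exe.run(main_prog, fetch_list=[result])
    print(res)  # [1.]: 0.3 < 0.0 is False, so the default branch is taken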
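Where the old Switch had exactly one case plus a default (polynomial_decay,
ExponentialMovingAverage, and the fp16 test above), the patch uses the
two-branch paddle.static.nn.cond instead. A minimal sketch of that mapping,
again with illustrative values only:

    import paddle

    paddle.enable_static()

    main_prog = paddle.static.Program()
    startup_prog = paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        global_step = paddle.full(shape=[1], fill_value=5.0, dtype='float32')
        ema = paddle.full(shape=[1], fill_value=0.9, dtype='float32')
        decay_pow = paddle.full(shape=[1], fill_value=0.5, dtype='float32')

        # cond(pred, true_fn, false_fn): true_fn plays the role of the single
        # switch.case branch, false_fn the role of switch.default.
        val = paddle.static.nn.cond(
            global_step > 0,
            lambda: ema / (1.0 - decay_pow),
            lambda: ema,
        )

    exe = paddle.static.Executor(paddle.CPUPlace())
    exe.run(startup_prog)
    (out,) = exe.run(main_prog, fetch_list=[val])
    print(out)  # [1.8]: global_step > 0, so ema / (1.0 - decay_pow) is taken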