Unverified commit ed45ecc6 authored by Q qizhaoaoe, committed by GitHub

[fluid clean] remove Print. (#51778)

* fluid clean: move Print/Switch from fluid to static

* remove Switch in static.__init__

* fix conflicts.

* replace Switch with case (see the migration sketch after this list).

* fix piecewise_lr decay.

* fix typo

* fix conflicts.

* fix lr dtype

* keep Switch in paddle.static.nn.control_flow and fix piecewise_lr.

* fix conflicts.

* keep Switch in fluid.

* fix Switch doc

* fix example in Switch doc

* fix Switch doc.

* fix static/__init__.
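
The recurring pattern in this diff replaces an in-place `Switch` block with `paddle.static.nn.case` (or `paddle.static.nn.cond` for a single predicate) followed by one `paddle.assign`. Below is a minimal sketch of the before/after shape; the variable names are illustrative, not taken from this diff:

    import paddle

    paddle.enable_static()

    x = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=0.3)
    zero = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=0.0)
    one = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=1.0)
    result = paddle.tensor.fill_constant(shape=[1], dtype='float32', value=-1.0)

    # Before: branches assign into `result` imperatively.
    # with fluid.layers.control_flow.Switch() as switch:
    #     with switch.case(x < zero):
    #         paddle.assign(zero, result)
    #     with switch.default():
    #         paddle.assign(one, result)

    # After: case/cond returns the selected tensor; assign it once.
    val = paddle.static.nn.case(
        pred_fn_pairs=[(x < zero, lambda: zero)],
        default=lambda: one,
    )
    paddle.assign(val, result)

    exe = paddle.static.Executor(paddle.CPUPlace())
    print(exe.run(paddle.static.default_main_program(), fetch_list=[result]))
    # [array([1.], dtype=float32)]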
Parent bddeecd1
......@@ -50,7 +50,6 @@ from paddle import _C_ops, _legacy_C_ops
__all__ = [
'Switch',
'StaticRNN',
'Print',
'while_loop',
]
......@@ -141,104 +140,6 @@ def select_input(inputs, mask):
return out
@static_only
def Print(
input,
first_n=-1,
message=None,
summarize=20,
print_tensor_name=True,
print_tensor_type=True,
print_tensor_shape=True,
print_tensor_layout=True,
print_tensor_lod=True,
print_phase='both',
):
'''
:api_attr: Static Graph
**Print operator**
This creates a print op that will print when a tensor is accessed.
Wraps the tensor passed in so that whenever the tensor is accessed,
the message `message` is printed, along with the current value of the
input tensor.
Args:
input (Tensor): A Tensor to print.
first_n (int, optional): Only log the first `first_n` times. Default: -1.
message (str, optional): A string message to print as a prefix. Default: None.
summarize (int, optional): Number of elements in the tensor to be printed. If
its value is -1, all elements in the tensor will be printed. Default: 20.
print_tensor_name (bool, optional): Print the tensor name. Default: True.
print_tensor_type (bool, optional): Print the tensor type. Default: True.
print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
print_phase (str, optional): Which phase to display, including 'forward',
'backward' and 'both'. Default: 'both'. If set to 'backward', only the
gradients of the input tensor are printed; if set to 'both', both the
input tensor itself and its gradients are printed.
Returns:
Tensor: Output tensor.
NOTES:
The input and output are two different Tensors. In subsequent
computation, use the output Tensor rather than the input;
otherwise the print layer will have no backward pass.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
out = paddle.static.Print(x, message="The content of input layer:")
main_program = paddle.static.default_main_program()
exe = paddle.static.Executor(place=paddle.CPUPlace())
res = exe.run(main_program, fetch_list=[out])
# Variable: fill_constant_1.tmp_0
# - message: The content of input layer:
# - lod: {}
# - place: CPUPlace
# - shape: [2, 3]
# - layout: NCHW
# - dtype: long
# - data: [3 3 3 3 3 3]
'''
check_variable_and_dtype(
input,
'input',
['float32', 'float64', 'int32', 'int64', 'bool'],
'fluid.layers.Print',
)
helper = LayerHelper('print' + "_" + input.name, **locals())
output = helper.create_variable_for_type_inference(input.dtype)
helper.append_op(
type='print',
inputs={'In': input},
outputs={'Out': output},
attrs={
'first_n': first_n,
'summarize': summarize,
'message': message or "",
'print_tensor_name': print_tensor_name,
'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape,
'print_tensor_layout': print_tensor_layout,
'print_tensor_lod': print_tensor_lod,
'print_phase': print_phase.upper(),
},
)
return output
# (TODO: Mine) There is a dependency here. It will be removed later.
class BlockGuard:
"""
......@@ -1512,6 +1413,7 @@ def expand_undefined_var(nest1, nest2, names):
return nest1_out, nest2_out
# TODO: It will be deleted later.
class Switch:
"""
:api_attr: Static Graph
......
......@@ -359,9 +359,11 @@ def polynomial_decay(
shape=[1], dtype='float32', value=1.0
)
with control_flow.Switch() as switch:
with switch.case(global_step == zero_var):
paddle.assign(one_var, output=div_res)
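# At global_step == 0, pick one_var so that decay_steps below is not multiplied by zero.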
div_val = paddle.static.nn.cond(
global_step == zero_var, lambda: one_var, lambda: div_res
)
paddle.assign(div_val, output=div_res)
decay_steps = decay_steps * div_res
else:
decay_steps_var = paddle.tensor.fill_constant(
......@@ -432,7 +434,7 @@ def piecewise_decay(boundaries, values):
persistable=True,
name="learning_rate",
)
# TODO: fluid.layers.control_flow.Switch should be replaced by paddle.static.nn.case (or cond) if possible
with control_flow.Switch() as switch:
for i in range(len(boundaries)):
boundary_val = paddle.tensor.fill_constant(
......@@ -455,7 +457,6 @@ def piecewise_decay(boundaries, values):
value=float(values[len(values) - 1]),
out=lr,
)
return lr
......@@ -589,17 +590,19 @@ def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
)
global_step = _decay_step_counter()
with control_flow.Switch() as switch:
with switch.case(global_step < warmup_steps):
decayed_lr = start_lr + linear_step * (
global_step / float(warmup_steps)
if not isinstance(learning_rate, Variable):
learning_rate = paddle.tensor.fill_constant(
shape=[1], dtype=dtype, value=float(learning_rate)
)
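# Warmup: while global_step < warmup_steps, ramp linearly from start_lr; otherwise fall back to the configured learning_rate.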
lr_val = paddle.static.nn.case(
pred_fn_pairs=[
(
global_step < warmup_steps,
lambda: start_lr
+ linear_step * (global_step / float(warmup_steps)),
)
paddle.assign(decayed_lr, lr)
with switch.default():
if not isinstance(learning_rate, Variable):
learning_rate = paddle.tensor.fill_constant(
shape=[1], dtype=dtype, value=float(learning_rate)
)
paddle.assign(learning_rate, lr)
],
default=lambda: learning_rate,
)
paddle.assign(lr_val, lr)
return lr
......@@ -4373,12 +4373,12 @@ class ExponentialMovingAverage:
ema = block._clone_variable(self._ema_vars[param.name])
paddle.assign(param, output=tmp)
# bias correction
with layers.control_flow.Switch() as switch:
with switch.case(global_step > 0):
paddle.assign(ema / (1.0 - decay_pow), output=param)
with switch.default():
paddle.assign(ema, output=param)
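# EMA bias correction: divide by (1 - decay_pow); the global_step > 0 guard avoids a zero denominator before any update has run.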
param_val = paddle.static.nn.cond(
global_step > 0,
lambda: ema / (1.0 - decay_pow),
lambda: ema,
)
paddle.assign(param_val, output=param)
self.restore_program = Program()
block = self.restore_program.global_block()
with program_guard(main_program=self.restore_program):
......@@ -4399,13 +4399,12 @@ class ExponentialMovingAverage:
if self._thres_steps is not None:
decay_t = (self._thres_steps + 1.0) / (self._thres_steps + 10.0)
with layers.control_flow.Switch() as switch:
with switch.case(decay_t < self._decay):
paddle.assign(decay_t, decay_var)
with switch.default():
paddle.assign(
np.array([self._decay], dtype=np.float32), decay_var
)
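# With thres_steps set, use the ramped decay (thres_steps + 1) / (thres_steps + 10) while it is still below the configured decay.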
decay_val = paddle.static.nn.cond(
decay_t < self._decay,
lambda: decay_t,
lambda: np.array([self._decay], dtype=np.float32),
)
paddle.assign(decay_val, decay_var)
return decay_var
def _get_decay_pow(self, block):
......@@ -7408,26 +7407,30 @@ class LookaheadOptimizer:
)
mod = paddle.remainder(step, k)
with layers.control_flow.Switch() as switch:
with switch.case(step == one_var):
for param_name in params:
fast_var = main_block.var(param_name)
slow_var = param_to_slow[param_name]
paddle.assign(fast_var, output=slow_var)
with switch.case(mod == zero_var):
for param_name in params:
fast_var = main_block.var(param_name)
slow_var = param_to_slow[param_name]
tmp_var = paddle.add(
paddle.multiply(fast_var, alpha),
paddle.multiply(
slow_var, paddle.subtract(one_var, alpha)
),
)
paddle.assign(tmp_var, output=slow_var)
paddle.assign(tmp_var, output=fast_var)
with switch.default():
pass
for param_name in params:
fast_var = main_block.var(param_name)
slow_var = param_to_slow[param_name]
tmp_var = paddle.add(
paddle.multiply(fast_var, alpha),
paddle.multiply(slow_var, paddle.subtract(one_var, alpha)),
)
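# Lookahead: at step 1 the slow weights are initialized from the fast weights; every k steps (mod == 0) both are set to the interpolation tmp_var; otherwise both keep their values.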
slow_val = paddle.static.nn.case(
[
(step == one_var, lambda: fast_var),
(mod == zero_var, lambda: tmp_var),
],
default=lambda: slow_var,
)
paddle.assign(slow_val, slow_var)
fast_val = paddle.static.nn.case(
[
(mod == zero_var, lambda: tmp_var),
],
default=lambda: fast_var,
)
paddle.assign(fast_val, fast_var)
return mini_out
......
......@@ -126,7 +126,7 @@ class TestHeterPipelinePsCTR2x2(FleetDistHeterRunnerBase):
input=predict, label=label, reduction='none', use_softmax=False
)
avg_cost = paddle.mean(x=cost)
fluid.layers.Print(avg_cost, message="avg_cost")
paddle.static.Print(avg_cost, message="avg_cost")
self.feeds = datas
self.train_file_path = ["fake1", "fake2"]
......
......@@ -15,7 +15,7 @@
import unittest
import paddle
from paddle.fluid import core, framework, layers
from paddle.fluid import core, framework
from paddle.fluid.executor import Executor
from paddle.fluid.framework import default_startup_program
......@@ -40,15 +40,15 @@ class TestSwitch(unittest.TestCase):
shape=[1], value=-1.0, dtype='float32', persistable=True
)
with layers.Switch() as switch:
with switch.case(paddle.less_than(x, zero_var)):
paddle.assign(zero_var, result)
with switch.case(paddle.less_than(x, one_var)):
paddle.assign(one_var, result)
with switch.case(paddle.less_than(x, two_var)):
paddle.assign(two_var, result)
with switch.default():
paddle.assign(three_var, result)
res = paddle.static.nn.case(
pred_fn_pairs=[
(paddle.less_than(x, zero_var), lambda: zero_var),
(paddle.less_than(x, one_var), lambda: one_var),
(paddle.less_than(x, two_var), lambda: two_var),
],
default=lambda: three_var,
)
paddle.assign(res, result)
cpu = core.CPUPlace()
exe = Executor(cpu)
......@@ -85,17 +85,19 @@ class TestSwitchCaseError(unittest.TestCase):
# 1. The type of 'condition' in case must be Variable.
def test_condition_type():
with layers.Switch() as switch:
with switch.case(1):
paddle.assign(zero_var, result)
res = paddle.static.nn.case(
[(1, lambda: zero_var)], default=lambda: result
)
paddle.assign(res, result)
self.assertRaises(TypeError, test_condition_type)
# 2. The dtype of 'condition' in case must be 'bool'.
def test_condition_dtype():
with layers.Switch() as switch:
with switch.case(cond):
paddle.assign(zero_var, result)
res = paddle.static.nn.case(
[(cond, lambda: zero_var)], default=lambda: result
)
paddle.assign(res, result)
self.assertRaises(TypeError, test_condition_dtype)
......
......@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.base import (
in_declarative_mode,
)
from paddle.fluid.framework import Variable, core
from paddle.fluid.layers import Print, control_flow
from paddle.fluid.layers import control_flow
from paddle.fluid.layers.control_flow import while_loop
from .utils import (
......@@ -749,7 +749,7 @@ def convert_print(*objects, sep=' ', end='\n', file=None, flush=False):
"""
for obj in objects:
if isinstance(obj, Variable):
Print(obj)
paddle.static.Print(obj)
print(*objects, sep=sep, end=end, file=file, flush=flush)
......
......@@ -299,13 +299,11 @@ def _rnn_static_graph(
pre_state = paddle.utils.map_structure(
lambda x: paddle.tensor.array_read(x, start_i), init_array
)
# pre_state = paddle.fluid.layers.Print( pre_state, message="pre")
outputs, new_states = cell(step_in, pre_state, **kwargs)
assert isinstance(outputs, paddle.fluid.framework.Variable)
paddle.utils.assert_same_structure(new_states, pre_state)
if sequence_length:
step_mask = paddle.unsqueeze(mask[start_i], 1)
# paddle.fluid.layers.Print( step_mask, message="mask")
# new_states = map_structure(
# partial(_maybe_copy, step_mask=step_mask),
# pre_state, new_states
......
......@@ -553,8 +553,9 @@ def multiprocess_reader(readers, use_pipe=True, queue_size=1000):
with fluid.program_guard(fluid.Program(), fluid.Program()):
place = fluid.CPUPlace()
# the first 2 in the shape is the batch size
image = paddle.static.data(name='image', dtype='int64', shape=[2, 1, 2])
fluid.layers.Print(image)
paddle.static.Print(image)
# print detailed tensor info of image variable
reader = fluid.io.PyReader(feed_list=[image], capacity=2)
......
......@@ -65,7 +65,7 @@ from ..fluid.framework import Operator # noqa: F401
from ..fluid.framework import Parameter # noqa: F401
from ..fluid.framework import ipu_shard_guard # noqa: F401
from ..fluid.framework import set_ipu_shard # noqa: F401
from ..fluid.layers.control_flow import Print # noqa: F401
from .nn.control_flow import Print # noqa: F401
from ..fluid.param_attr import WeightNormParamAttr # noqa: F401
from ..fluid.optimizer import Optimizer # noqa: F401
from ..fluid.optimizer import Adam # noqa: F401
......
......@@ -24,7 +24,7 @@ from paddle.common_ops_import import (
convert_dtype,
)
from paddle.fluid import core
from paddle.fluid.framework import Operator, Program, Variable
from paddle.fluid.framework import Operator, Program, Variable, static_only
# Temporary solution, it will be deleted later
from paddle.fluid.layers.control_flow import ConditionalBlock, select_input
......@@ -1329,3 +1329,101 @@ def change_none_to_undefinedvar(nest1, nest2):
nest1_out = pack_sequence_as(nest1, list(map(map_fn, flatten(nest1))))
nest2_out = pack_sequence_as(nest2, list(map(map_fn, flatten(nest2))))
return nest1_out, nest2_out
@static_only
def Print(
input,
first_n=-1,
message=None,
summarize=20,
print_tensor_name=True,
print_tensor_type=True,
print_tensor_shape=True,
print_tensor_layout=True,
print_tensor_lod=True,
print_phase='both',
):
'''
:api_attr: Static Graph
**Print operator**
This creates a print op that will print when a tensor is accessed.
Wraps the tensor passed in so that whenever the tensor is accessed,
the message `message` is printed, along with the current value of the
input tensor.
Args:
input (Tensor): A Tensor to print.
first_n (int, optional): Only log the first `first_n` times. Default: -1.
message (str, optional): A string message to print as a prefix. Default: None.
summarize (int, optional): Number of elements in the tensor to be printed. If
its value is -1, all elements in the tensor will be printed. Default: 20.
print_tensor_name (bool, optional): Print the tensor name. Default: True.
print_tensor_type (bool, optional): Print the tensor type. Default: True.
print_tensor_shape (bool, optional): Print the tensor shape. Default: True.
print_tensor_layout (bool, optional): Print the tensor layout. Default: True.
print_tensor_lod (bool, optional): Print the tensor lod. Default: True.
print_phase (str, optional): Which phase to display, including 'forward',
'backward' and 'both'. Default: 'both'. If set to 'backward', only the
gradients of the input tensor are printed; if set to 'both', both the
input tensor itself and its gradients are printed.
Returns:
Tensor: Output tensor.
NOTES:
The input and output are two different Tensors. In subsequent
computation, use the output Tensor rather than the input;
otherwise the print layer will have no backward pass.
Examples:
.. code-block:: python
import paddle
paddle.enable_static()
x = paddle.full(shape=[2, 3], fill_value=3, dtype='int64')
out = paddle.static.Print(x, message="The content of input layer:")
main_program = paddle.static.default_main_program()
exe = paddle.static.Executor(place=paddle.CPUPlace())
res = exe.run(main_program, fetch_list=[out])
# Variable: fill_constant_1.tmp_0
# - message: The content of input layer:
# - lod: {}
# - place: CPUPlace
# - shape: [2, 3]
# - layout: NCHW
# - dtype: long
# - data: [3 3 3 3 3 3]
'''
check_variable_and_dtype(
input,
'input',
['float32', 'float64', 'int32', 'int64', 'bool'],
'paddle.static.Print',
)
helper = LayerHelper('print' + "_" + input.name, **locals())
output = helper.create_variable_for_type_inference(input.dtype)
helper.append_op(
type='print',
inputs={'In': input},
outputs={'Out': output},
attrs={
'first_n': first_n,
'summarize': summarize,
'message': message or "",
'print_tensor_name': print_tensor_name,
'print_tensor_type': print_tensor_type,
'print_tensor_shape': print_tensor_shape,
'print_tensor_layout': print_tensor_layout,
'print_tensor_lod': print_tensor_lod,
'print_phase': print_phase.upper(),
},
)
return output
......@@ -295,12 +295,10 @@ class TestAmpWithNonIterableDataLoader(unittest.TestCase):
one_var = paddle.tensor.fill_constant(
shape=[1], dtype='int64', value=1
)
with fluid.layers.control_flow.Switch() as switch:
with switch.case(label != zero_var):
paddle.assign(zero_var, output=label)
with switch.default():
paddle.assign(one_var, output=label)
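# Binarize the label: non-zero labels map to zero_var, zero labels to one_var.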
label_val = paddle.static.nn.cond(
label != zero_var, lambda: zero_var, lambda: one_var
)
paddle.assign(label_val, output=label)
net = resnet_cifar10(image)
logits = paddle.static.nn.fc(
x=net, size=10, activation="softmax"
......
......@@ -55,7 +55,7 @@ class TestBase(IPUOpTest):
dtype=self.feed_dtype[0],
)
out = paddle.static.nn.conv2d(x, num_filters=3, filter_size=3)
out = paddle.fluid.layers.Print(out, **self.attrs)
out = paddle.static.Print(out, **self.attrs)
if self.is_training:
loss = paddle.mean(out)
......