Unverified commit 8bbb829a, authored by Aurelius84 and committed by GitHub

[Dy2stat]Support nested input and output (#24752)

* support nested input and output test=develop

* remove code of convert type(output) in unittest test=develop

* add warning test=develop
Parent a7e21cbe
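The heart of the change is a flatten/pack round-trip at the dygraph-to-static boundary: nested Python containers are flattened into a single list so that only their tensor leaves flow through the static program, and the results are packed back into the original structure afterwards. A minimal sketch of that round-trip, using the same `flatten` / `pack_sequence_as` utilities the patch imports (the values here are illustrative, not from the commit):

    from paddle.fluid.layers.utils import flatten, pack_sequence_as

    nested = {'z': [1, 2], 'a': (3, {'b': 4})}
    leaves = flatten(nested)                      # flat list of all leaf values
    doubled = [v * 2 for v in leaves]             # process the flat list
    restored = pack_sequence_as(nested, doubled)  # same nesting, new leaves
    assert flatten(restored) == doubled           # the structure round-trips exactly

`FunctionSpec.to_static_inputs` and the new `NestSequence` wrapper below both follow this pattern.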
@@ -17,7 +17,6 @@ from __future__ import print_function
 import gast
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
-from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
...
@@ -17,7 +17,6 @@ from __future__ import print_function
 import gast
 from paddle.fluid import unique_name
-from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
 from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
 from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
...
@@ -23,9 +23,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
 from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
 from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
-from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
-from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
 from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
 from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
 from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
...
@@ -14,11 +14,75 @@
 from __future__ import print_function
 import numpy as np
+import logging
+
+from paddle.fluid import log_helper
 from paddle.fluid import framework, backward, core
 from paddle.fluid.dygraph import layers
+from paddle.fluid.layers.utils import flatten
+from paddle.fluid.layers.utils import pack_sequence_as
 from paddle.fluid.dygraph.base import switch_to_static_graph
 import paddle.compat as cpt
+
+_logger = log_helper.get_logger(
+    __name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
+
+
+class NestSequence(object):
+    """
+    A wrapper class that makes it easy to flatten and restore the nested
+    structure of a given sequence.
+    """
+
+    def __init__(self, raw_input, need_check=False):
+        self.__raw_input = raw_input
+        self.__var_ids = self._get_var_ids()
+        self._check_non_variable(need_check)
+
+    def tolist(self):
+        """
+        Flattens the nested sequence into a single flat list.
+        """
+        return flatten(self.__raw_input)
+
+    def restore(self, value_list):
+        """
+        Restores the nested sequence from a flat list of values.
+        """
+        assert len(self.tolist()) == len(value_list)
+        return pack_sequence_as(self.__raw_input, value_list)
+
+    def _get_var_ids(self):
+        var_ids = []
+        for idx, var in enumerate(self.tolist()):
+            if isinstance(var, (framework.Variable, core.VarBase)):
+                var_ids.append(idx)
+        return var_ids
+
+    def _check_non_variable(self, need_check):
+        """
+        Raises a warning if the output of the traced function contains
+        non-tensor values.
+        """
+        if need_check:
+            warning_types = set()
+            for var in self.tolist():
+                if not isinstance(var, (framework.Variable, core.VarBase)):
+                    warning_types.add(type(var))
+            if warning_types:
+                _logger.warning(
+                    "Output of traced function contains non-tensor type values: {}. "
+                    "Currently, we don't support updating them while training and "
+                    "will return what we first saw. Please try to return them as "
+                    "tensors.".format(list(warning_types)))
+
+    @property
+    def var_ids(self):
+        return self.__var_ids
+
+    def __getitem__(self, item):
+        return self.tolist()[item]
+
+
 class PartialProgramLayer(layers.Layer):
     """
@@ -43,8 +107,8 @@ class PartialProgramLayer(layers.Layer):
     def __init__(self, main_program, inputs, outputs, parameters=None):
         super(PartialProgramLayer, self).__init__()
-        self.inputs = inputs
-        self.outputs = outputs
+        self._inputs = NestSequence(inputs)
+        self._outputs = NestSequence(outputs, need_check=True)
         self._params = parameters if parameters is not None else []

         # Check all params from main program can be found in self._params:
         # 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
@@ -65,7 +129,7 @@ class PartialProgramLayer(layers.Layer):
     def _append_backward_desc(self):
         program = self._infer_program.clone()
         targets = []
-        for out in self.outputs:
+        for out in self._outputs.tolist():
             if isinstance(out, framework.Variable):
                 targets.append(program.global_block().var(out.name))
@@ -101,37 +165,37 @@ class PartialProgramLayer(layers.Layer):
                 'is_test': not self.training
             })
-        outs = out_vars
-        if len(outs) == 1:
-            outs = outs[0]
-        return outs
+        return self._restore_out(out_vars)

     def _prepare(self, inputs):
         """
         Prepare inputs, outputs, attrs.
         """
         assert isinstance(inputs, (tuple, list))
+        # Flatten inputs with nested structure into a single list.
+        flatten_inputs = flatten(inputs)
         # Convert variable into VarBase and feed in training data.
         input_vars = []
-        for i, value in enumerate(inputs):
+        for i, value in enumerate(flatten_inputs):
             if isinstance(value, np.ndarray):
                 var = core.VarBase(
                     value=value,
-                    name=self.inputs[i].desc.name(),
+                    name=self._inputs[i].desc.name(),
                     persistable=False,
                     place=framework._current_expected_place(),
                     zero_copy=True)
             elif isinstance(value, core.VarBase):
                 var = value
-                var.name = self.inputs[i].desc.name()
+                var.name = self._inputs[i].desc.name()
             else:
                 continue
             input_vars.append(var)

         # Create VarBase to receive output data.
         out_vars = []
-        for var in self.outputs:
-            if not isinstance(var, framework.Variable):
-                continue
+        for idx in self._outputs.var_ids:
+            var = self._outputs[idx]
+            assert isinstance(var, framework.Variable)
             var_desc = var.desc
             var_base = core.VarBase(var_desc.dtype(),
                                     var_desc.shape(),
@@ -147,6 +211,20 @@ class PartialProgramLayer(layers.Layer):
         return input_vars, out_vars, tmp_scope_vec

+    def _restore_out(self, out_vars):
+        """
+        Restores the nested outputs, replacing each Variable with the
+        corresponding computed VarBase.
+        """
+        flatten_outputs = self._outputs.tolist()
+        for i, idx in enumerate(self._outputs.var_ids):
+            flatten_outputs[idx] = out_vars[i]
+
+        outs = self._outputs.restore(flatten_outputs)
+        if len(outs) == 1:
+            outs = outs[0]
+
+        return outs
+
     def _set_grad_type(self, params):
         # NOTE: if user set sparse gradient mode, the param's gradient
         # will be SelectedRows, not LoDTensor. But tracer will just
...
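Taken together, `NestSequence` and `_restore_out` implement the output side of the round-trip: remember where the tensor leaves sit (`var_ids`), run the program on just those, then splice the computed values back and re-nest. A hedged usage sketch, assuming `NestSequence` is importable from `paddle.fluid.dygraph.dygraph_to_static.partial_program` as defined in this patch (all data is made up):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph.dygraph_to_static.partial_program import NestSequence

    with fluid.dygraph.guard():
        x = fluid.dygraph.to_variable(np.ones([2, 2], dtype='float32'))
        # Mixed structure: tensor leaves plus non-tensor leaves ('cmd', 64),
        # so need_check=True logs the new warning about non-tensor outputs.
        ns = NestSequence({'z': x, 'a': [x, 'cmd', 64]}, need_check=True)

        flat = ns.tolist()  # flattened leaves, tensors and non-tensors mixed
        print(ns.var_ids)   # positions of the tensor leaves within `flat`

        # The same bookkeeping _restore_out performs: overwrite only the tensor
        # slots with freshly computed values, keep other leaves as first seen.
        new_vals = [fluid.dygraph.to_variable(
            np.zeros([2, 2], dtype='float32')) for _ in ns.var_ids]
        for i, idx in enumerate(ns.var_ids):
            flat[idx] = new_vals[i]
        restored = ns.restore(flat)  # original nesting with updated tensors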
@@ -25,6 +25,8 @@ from paddle.fluid import framework
 from paddle.fluid import executor
 from paddle.fluid import unique_name
 from paddle.fluid.dygraph import layers
+from paddle.fluid.layers.utils import flatten
+from paddle.fluid.layers.utils import pack_sequence_as
 from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
 from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
@@ -108,7 +110,7 @@ class FunctionSpec(object):
     def to_static_inputs(self, main_program):
         inputs = []
         block = main_program.global_block()
-        for input_var in self.args:
+        for input_var in flatten(self.args):
             if isinstance(input_var, np.ndarray):
                 feed_layer = block.create_var(
                     name=unique_name.generate('feed'),
@@ -127,7 +129,8 @@ class FunctionSpec(object):
                 feed_layer = input_var

             inputs.append(feed_layer)
-        return inputs
+        # Restore the nested structure of self.args.
+        return pack_sequence_as(self.args, inputs)

     @property
     def dyfunc(self):
@@ -175,12 +178,12 @@ class ConcreteProgram(object):
         of program as fetch_list.
         """
         # Transforms dygraph function into static function and caches it.
-        dygaph_function = func_spec.dyfunc
-        static_func = convert_function_with_cache(dygaph_function)
+        dygraph_function = func_spec.dyfunc
+        static_func = convert_function_with_cache(dygraph_function)

         main_program, startup_program = framework.Program(), framework.Program()
         # Note: The random seed should be synchronized into cached program
-        # if set in `fluid.dygrap_guard` because some ops rely on it, such as
+        # if set in `fluid.dygraph_guard` because some ops rely on it, such as
         # `fluid.layers.dropout`.
         main_program.random_seed = framework.default_main_program().random_seed
         startup_program.random_seed = framework.default_startup_program(
@@ -203,7 +206,7 @@ class ConcreteProgram(object):
             inputs=inputs,
             outputs=outputs,
             parameters=all_parameters,
-            func=dygaph_function,
+            func=dygraph_function,
             main_program=main_program,
             startup_program=startup_program)
...
@@ -186,12 +186,9 @@ class TestTransform(TestTransformBase):
         if not isinstance(dy_outs, tuple):
             dy_outs = (dy_outs, )

-        # NOTE: return type is difference
         st_outs = self.get_static_output()
-        if not isinstance(st_outs, list):
+        if not isinstance(st_outs, tuple):
             st_outs = (st_outs, )
-        else:
-            st_outs = tuple(st_outs)

         for x, y in zip(dy_outs, st_outs):
             self.assertTrue(np.allclose(x.numpy(), y.numpy()))
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import numpy as np
import paddle.fluid as fluid
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import declarative

import unittest

SEED = 2020


def nested_input(x, y):
    sum_res = x + y[0]

    z_elem = y[3]['z']
    sub_res = z_elem[0] - z_elem[1]

    mul_res = y[-1]['d']['da'] * y[-1]['d']['dc']
    mean_func = fluid.layers.mean
    out = mean_func(sub_res) + mean_func(sum_res) + mean_func(mul_res)
    return out


def nested_output(x, y):
    sum_res = x + y
    sub_res = x - y
    mul_res = x * y

    out = {}
    out['z'] = sum_res
    out['a'] = [sub_res, 64, [mul_res, "cmd"]]
    return out


def fake_data(shape):
    x_data = np.random.random(shape).astype('float32')
    return fluid.dygraph.to_variable(x_data)
class TestWithNestedInput(unittest.TestCase):
    def setUp(self):
        self.x = None
        self.y = None

    def fake_input(self):
        self.x = fake_data([10, 16])
        self.y = [
            fake_data([10, 16]), "preprocess_cmd", 64, {
                'z': [fake_data([10, 12]), fake_data([10, 12])],
                'c': fake_data([10, 10]),
                'd': {
                    'da': 12,
                    'dc': fake_data([10, 10])
                }
            }
        ]

    def _run(self, to_static):
        with fluid.dygraph.guard():
            if self.x is None or self.y is None:
                self.fake_input()

            if to_static:
                out = declarative(nested_input)(self.x, self.y)
            else:
                out = nested_input(self.x, self.y)

        return out.numpy()

    def test_nest(self):
        dygraph_res = self._run(to_static=False)
        static_res = self._run(to_static=True)
        self.assertTrue(np.allclose(dygraph_res, static_res))
class TestWithNestedOutput(unittest.TestCase):
    def setUp(self):
        self.x = None
        self.y = None

    def _run(self, to_static):
        with fluid.dygraph.guard():
            if self.x is None or self.y is None:
                self.x = fake_data([10, 16])
                self.y = fake_data([10, 16])

            if to_static:
                out = declarative(nested_output)(self.x, self.y)
            else:
                out = nested_output(self.x, self.y)

        return out

    def test_nest(self):
        dygraph_res = self._run(to_static=False)
        dygraph_res = flatten(dygraph_res)

        static_res = self._run(to_static=True)
        static_res = flatten(static_res)

        self.assertTrue(len(dygraph_res) == len(static_res))

        for dy_var, st_var in zip(dygraph_res, static_res):
            if isinstance(dy_var, fluid.core.VarBase):
                self.assertTrue(np.allclose(dy_var.numpy(), st_var.numpy()))
            else:
                self.assertEqual(dy_var, st_var)


if __name__ == '__main__':
    unittest.main()