未验证 提交 8bbb829a 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Support nested input and output (#24752)

* support nested input and output test=develop

* remove code of convert type(output) in unittest test=develop

* add warning test=develop
上级 a7e21cbe
......@@ -17,7 +17,6 @@ from __future__ import print_function
import gast
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrapper
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
......
......@@ -17,7 +17,6 @@ from __future__ import print_function
import gast
from paddle.fluid import unique_name
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
from paddle.fluid.dygraph.dygraph_to_static.variable_trans_func import create_fill_constant_node
......
......@@ -23,9 +23,7 @@ from paddle.fluid.dygraph.dygraph_to_static.static_analysis import AstNodeWrappe
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import NodeVarType
from paddle.fluid.dygraph.dygraph_to_static.static_analysis import StaticAnalysisVisitor
from paddle.fluid.dygraph.dygraph_to_static.utils import ast_to_source_code
from paddle.fluid.dygraph.dygraph_to_static.utils import create_api_shape_node
from paddle.fluid.dygraph.dygraph_to_static.utils import generate_name_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_constant_variable_node
from paddle.fluid.dygraph.dygraph_to_static.utils import get_attribute_full_name
from paddle.fluid.dygraph.dygraph_to_static.utils import is_control_flow_to_transform
from paddle.fluid.dygraph.dygraph_to_static.utils import ForNodeVisitor
......
......@@ -14,11 +14,75 @@
from __future__ import print_function
import numpy as np
import logging
from paddle.fluid import log_helper
from paddle.fluid import framework, backward, core
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
import paddle.compat as cpt
_logger = log_helper.get_logger(
__name__, logging.WARNING, fmt='%(asctime)s-%(levelname)s: %(message)s')
class NestSequence(object):
    """
    Thin wrapper around an arbitrarily nested structure. It can flatten the
    structure into a plain list and later rebuild the same nesting from a
    list of values.
    """

    def __init__(self, raw_input, need_check=False):
        # The untouched structure doubles as the template used by `restore`.
        self.__raw_input = raw_input
        self.__var_ids = self._get_var_ids()
        self._check_non_variable(need_check)

    def tolist(self):
        """
        Returns the wrapped structure flattened by `flatten`.
        """
        return flatten(self.__raw_input)

    def restore(self, value_list):
        """
        Packs `value_list` back into the nesting of the wrapped structure.
        `value_list` must contain exactly one value per flattened entry.
        """
        assert len(self.tolist()) == len(value_list)
        return pack_sequence_as(self.__raw_input, value_list)

    def _get_var_ids(self):
        """
        Collects the positions (in the flattened list) of every entry that is
        a `framework.Variable` or a `core.VarBase`.
        """
        tensor_types = (framework.Variable, core.VarBase)
        return [
            position for position, item in enumerate(self.tolist())
            if isinstance(item, tensor_types)
        ]

    def _check_non_variable(self, need_check):
        """
        Logs a warning listing the types of all flattened entries that are
        neither `framework.Variable` nor `core.VarBase`. Does nothing when
        `need_check` is false or when every entry is a tensor.
        """
        if not need_check:
            return

        tensor_types = (framework.Variable, core.VarBase)
        other_types = {
            type(item)
            for item in self.tolist() if not isinstance(item, tensor_types)
        }
        if not other_types:
            return

        _logger.warning(
            "Output of traced function contains non-tensor type values: {}. "
            "Currently, We don't support to update them while training and will return "
            "what we first saw. Please try to return them as tensor.".format(
                list(other_types)))

    @property
    def var_ids(self):
        # Indices computed once in __init__ by `_get_var_ids`.
        return self.__var_ids

    def __getitem__(self, item):
        # Indexing addresses the flattened view, not the nested structure.
        return self.tolist()[item]
class PartialProgramLayer(layers.Layer):
"""
......@@ -43,8 +107,8 @@ class PartialProgramLayer(layers.Layer):
def __init__(self, main_program, inputs, outputs, parameters=None):
super(PartialProgramLayer, self).__init__()
self.inputs = inputs
self.outputs = outputs
self._inputs = NestSequence(inputs)
self._outputs = NestSequence(outputs, need_check=True)
self._params = parameters if parameters is not None else []
# Check all params from main program can be found in self._params:
# 1. parameter in self._params should be type `framework.ParamBase` which are created in dygraph.
......@@ -65,7 +129,7 @@ class PartialProgramLayer(layers.Layer):
def _append_backward_desc(self):
program = self._infer_program.clone()
targets = []
for out in self.outputs:
for out in self._outputs.tolist():
if isinstance(out, framework.Variable):
targets.append(program.global_block().var(out.name))
......@@ -101,37 +165,37 @@ class PartialProgramLayer(layers.Layer):
'is_test': not self.training
})
outs = out_vars
if len(outs) == 1:
outs = outs[0]
return outs
return self._restore_out(out_vars)
def _prepare(self, inputs):
"""
Prepare inputs, outputs, attrs.
"""
assert isinstance(inputs, (tuple, list))
# Flatten inputs with nested structure into single list.
flatten_inputs = flatten(inputs)
# Convert variable into VarBase and feed in training data.
input_vars = []
for i, value in enumerate(inputs):
for i, value in enumerate(flatten_inputs):
if isinstance(value, np.ndarray):
var = core.VarBase(
value=value,
name=self.inputs[i].desc.name(),
name=self._inputs[i].desc.name(),
persistable=False,
place=framework._current_expected_place(),
zero_copy=True)
elif isinstance(value, core.VarBase):
var = value
var.name = self.inputs[i].desc.name()
var.name = self._inputs[i].desc.name()
else:
continue
input_vars.append(var)
# Create VarBase to receive output data.
out_vars = []
for var in self.outputs:
if not isinstance(var, framework.Variable):
continue
for idx in self._outputs.var_ids:
var = self._outputs[idx]
assert isinstance(var, framework.Variable)
var_desc = var.desc
var_base = core.VarBase(var_desc.dtype(),
var_desc.shape(),
......@@ -147,6 +211,20 @@ class PartialProgramLayer(layers.Layer):
return input_vars, out_vars, tmp_scope_vec
def _restore_out(self, out_vars):
"""
Restores same nested outputs by only replacing the Variable with VarBase.
"""
flatten_outputs = self._outputs.tolist()
for i, idx in enumerate(self._outputs.var_ids):
flatten_outputs[idx] = out_vars[i]
outs = self._outputs.restore(flatten_outputs)
if len(outs) == 1:
outs = outs[0]
return outs
def _set_grad_type(self, params):
# NOTE: if user set sparse gradient mode, the param's gradient
# will be SelectedRows, not LoDTensor. But tracer will just
......
......@@ -25,6 +25,8 @@ from paddle.fluid import framework
from paddle.fluid import executor
from paddle.fluid import unique_name
from paddle.fluid.dygraph import layers
from paddle.fluid.layers.utils import flatten
from paddle.fluid.layers.utils import pack_sequence_as
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import convert_to_static
from paddle.fluid.dygraph.dygraph_to_static.ast_transformer import DygraphToStaticAst
......@@ -108,7 +110,7 @@ class FunctionSpec(object):
def to_static_inputs(self, main_program):
inputs = []
block = main_program.global_block()
for input_var in self.args:
for input_var in flatten(self.args):
if isinstance(input_var, np.ndarray):
feed_layer = block.create_var(
name=unique_name.generate('feed'),
......@@ -127,7 +129,8 @@ class FunctionSpec(object):
feed_layer = input_var
inputs.append(feed_layer)
return inputs
# Restores the nested structure as self.args
return pack_sequence_as(self.args, inputs)
@property
def dyfunc(self):
......@@ -175,12 +178,12 @@ class ConcreteProgram(object):
of program as fetch_list.
"""
# Transforms dygraph function into static function and caches it.
dygaph_function = func_spec.dyfunc
static_func = convert_function_with_cache(dygaph_function)
dygraph_function = func_spec.dyfunc
static_func = convert_function_with_cache(dygraph_function)
main_program, startup_program = framework.Program(), framework.Program()
# Note: The random seed should be synchronized into cached program
# if set in `fluid.dygrap_guard` because some ops rely on it, such as
# if set in `fluid.dygraph_guard` because some ops rely on it, such as
# `fluid.layers.dropout`.
main_program.random_seed = framework.default_main_program().random_seed
startup_program.random_seed = framework.default_startup_program(
......@@ -203,7 +206,7 @@ class ConcreteProgram(object):
inputs=inputs,
outputs=outputs,
parameters=all_parameters,
func=dygaph_function,
func=dygraph_function,
main_program=main_program,
startup_program=startup_program)
......
......@@ -186,12 +186,9 @@ class TestTransform(TestTransformBase):
if not isinstance(dy_outs, tuple):
dy_outs = (dy_outs, )
# NOTE: return type is different
st_outs = self.get_static_output()
if not isinstance(st_outs, list):
if not isinstance(st_outs, tuple):
st_outs = (st_outs, )
else:
st_outs = tuple(st_outs)
for x, y in zip(dy_outs, st_outs):
self.assertTrue(np.allclose(x.numpy(), y.numpy()))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.layers.utils import flatten
from paddle.fluid.dygraph import declarative
import unittest
SEED = 2020
def nested_input(x, y):
    """
    Reduces a nested input to a single value.

    `y` is indexed as a sequence whose first element is added to `x`, whose
    element 3 is a mapping with a two-item entry under 'z', and whose last
    element is a mapping holding {'d': {'da': ..., 'dc': ...}}. The result
    is the sum of the means of the difference, the sum and the product.
    """
    total = x + y[0]
    pair = y[3]['z']
    diff = pair[0] - pair[1]
    inner = y[-1]['d']
    prod = inner['da'] * inner['dc']
    reduce_mean = fluid.layers.mean
    return reduce_mean(diff) + reduce_mean(total) + reduce_mean(prod)
def nested_output(x, y):
    """
    Returns a nested structure mixing computed values with constants:
    {'z': x + y, 'a': [x - y, 64, [x * y, "cmd"]]}.
    """
    total = x + y
    diff = x - y
    prod = x * y
    return {'z': total, 'a': [diff, 64, [prod, "cmd"]]}
def fake_data(shape):
    """
    Draws random float32 data of the given shape and wraps it with
    `fluid.dygraph.to_variable`.
    """
    raw = np.random.random(shape)
    return fluid.dygraph.to_variable(raw.astype('float32'))
class TestWithNestedInput(unittest.TestCase):
    """
    Checks that `nested_input` gives the same result when run directly and
    when run through `declarative`, given a nested list/dict argument.
    """

    def setUp(self):
        # Inputs are created lazily inside the dygraph guard (see `_run`).
        self.x = None
        self.y = None

    def fake_input(self):
        """
        Builds `x` and the nested `y`, which mixes tensors with plain Python
        values (a string and ints) at several nesting levels.
        """
        self.x = fake_data([10, 16])
        head = fake_data([10, 16])
        z_pair = [fake_data([10, 12]), fake_data([10, 12])]
        c_value = fake_data([10, 10])
        d_value = {'da': 12, 'dc': fake_data([10, 10])}
        self.y = [
            head,
            "preprocess_cmd",
            64,
            {
                'z': z_pair,
                'c': c_value,
                'd': d_value,
            },
        ]

    def _run(self, to_static):
        """
        Runs `nested_input` under a dygraph guard, through `declarative` when
        `to_static` is true, and returns the result as a numpy array.
        """
        with fluid.dygraph.guard():
            if self.x is None or self.y is None:
                self.fake_input()

            func = declarative(nested_input) if to_static else nested_input
            out = func(self.x, self.y)
            return out.numpy()

    def test_nest(self):
        # The dygraph run goes first so both runs share the same inputs.
        expected = self._run(to_static=False)
        actual = self._run(to_static=True)
        self.assertTrue(np.allclose(expected, actual))
class TestWithNestedOutput(unittest.TestCase):
    """
    Checks that `nested_output` returns the same nested structure when run
    directly and when run through `declarative`.
    """

    def setUp(self):
        # Inputs are created lazily inside the dygraph guard (see `_run`).
        self.x = None
        self.y = None

    def _run(self, to_static):
        """
        Runs `nested_output` under a dygraph guard, through `declarative`
        when `to_static` is true, and returns its (nested) result.
        """
        with fluid.dygraph.guard():
            if self.x is None or self.y is None:
                self.x = fake_data([10, 16])
                self.y = fake_data([10, 16])

            if to_static:
                out = declarative(nested_output)(self.x, self.y)
            else:
                out = nested_output(self.x, self.y)

            return out

    def test_nest(self):
        dygraph_res = flatten(self._run(to_static=False))
        static_res = flatten(self._run(to_static=True))

        self.assertEqual(len(dygraph_res), len(static_res))

        for dy_var, st_var in zip(dygraph_res, static_res):
            if isinstance(dy_var, fluid.core.VarBase):
                self.assertTrue(np.allclose(dy_var.numpy(), st_var.numpy()))
            else:
                # assertTrue(a, b) would treat `b` as the failure message
                # and never compare the values; compare them explicitly.
                self.assertEqual(dy_var, st_var)
# Run the test cases when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册