Unverified commit e5af90aa, authored by Huihuang Zheng, committed by GitHub

Add Decorator 'dygraph_to_static_program' and ProgramTranslator.save_inference_model (#23227)

1. Add decorator 'dygraph_to_static_program'
2. Add the corresponding ProgramTranslator.get_program
3. Add ProgramTranslator.save_inference_model
4. Modify some dy2stat warning messages
5. Change the program cache to also hold the startup_program, because users who fetch the program to run may want to initialize the startup program first
Parent a647bcd3
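Taken together, the new pieces enable the workflow below. This is a minimal sketch distilled from the unit tests added in this commit; the function name `simple_net` and the shapes are illustrative, not part of the API.

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.jit import dygraph_to_static_program
from paddle.fluid.dygraph.nn import Linear


@dygraph_to_static_program
def simple_net(x):
    # Layers created here are recorded into the cached main/startup programs.
    linear = Linear(32, 64)
    x = fluid.dygraph.to_variable(x)
    return linear(x)


x = np.random.randn(4, 32).astype('float32')
# In static mode the decorated call returns the translated programs and the
# input/output variables instead of eagerly computing a result.
main_program, startup_program, inputs, outputs = simple_net(x)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_program)  # the cached startup program initializes parameters
result = exe.run(main_program, feed={inputs[0].name: x}, fetch_list=outputs)
```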
@@ -42,7 +42,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static

 __all__ = ['DygraphToStaticAst', 'convert_to_static']

-DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_func']
+DECORATOR_NAMES = [
+    'dygraph_to_static_code', 'dygraph_to_static_program',
+    'dygraph_to_static_func', 'dygraph_to_static_output'
+]


 class DygraphToStaticAst(gast.NodeTransformer):
...
@@ -100,7 +100,8 @@ class ProgramCache(object):
         # Always set program to default_main_program. Because once `__call__` is called,
         # it means layers(or Ops) are added into default_main_program switched by outer
         # `with` statement.
-        self._program = framework.default_main_program()
+        self._main_program = framework.default_main_program()
+        self._startup_program = framework.default_startup_program()
         self._func_cache = FunctionCache()
         # Stores the entry function of Net or Model.
         self._forward_func = None
@@ -142,7 +143,7 @@ class ProgramCache(object):
         static_func = self._func_cache.get_or_cache_func(dyfunc)
         # self._forward_func is entry function of Net or Model.
         # It can be called for multiple times, but layers from these functions
-        # call stack will be added into self._program only once.
+        # call stack will be added into self._main_program only once.
         # After that, cached program will be always returned by default.
         if static_func == self._forward_func:
             self._is_repeated = True
@@ -157,7 +158,7 @@ class ProgramCache(object):
         Returns program of the input function. If called at first time,
         builds a new program and caches it.
         """
-        with framework.program_guard(self._program):
+        with framework.program_guard(self._main_program, self._startup_program):
             if func == self._forward_func:
                 # Replaces input data with `layers.data`
                 args = list(args)
@@ -178,7 +179,7 @@ class ProgramCache(object):
         """
         if not self._feed_name_to_idx:
             self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
-        with framework.program_guard(self._program):
+        with framework.program_guard(self._main_program, self._startup_program):
             for feed_name, idx in self.feed_name_to_idx.items():
                 batch_data = args[idx]
                 assert isinstance(
@@ -201,8 +202,12 @@ class ProgramCache(object):
         return feed_name_to_idx

     @property
-    def program(self):
-        return self._program
+    def main_program(self):
+        return self._main_program
+
+    @property
+    def startup_program(self):
+        return self._startup_program

     @property
     def inputs(self):
@@ -222,7 +227,7 @@ class ProgramCache(object):

 class ProgramTranslator(object):
+    _singleton_lock = threading.Lock()
     _instance = None

     @synchronized
@@ -235,7 +240,8 @@ class ProgramTranslator(object):
     @classmethod
     def get_instance(cls):
         if cls._instance is None:
-            raise ValueError("ProgramTranslator hasn\'t been created!")
+            with cls._singleton_lock:
+                cls._instance = cls()
         return cls._instance

     @classmethod
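With the class-level lock, `get_instance` now lazily constructs the singleton instead of raising `ValueError` when no instance exists yet. A quick sketch of the resulting behavior:

```python
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator

# No longer raises on the first call; the instance is created under the lock,
# and every later call observes the same object.
translator = ProgramTranslator.get_instance()
assert translator is ProgramTranslator.get_instance()
```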
@@ -266,8 +272,9 @@ class ProgramTranslator(object):
         """
         if in_dygraph_mode():
             warnings.warn(
-                "The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
-                " Please use it in static mode.")
+                "The ProgramTranslator.get_output doesn't work in dygraph "
+                "mode. We will just return dygraph output. Use it in "
+                "static mode if you would like to translate to static graph.")
             return dygraph_func(*args, **kwargs)

         program_cache = self.get_program_cache()
@@ -283,12 +290,29 @@ class ProgramTranslator(object):
         """
         if in_dygraph_mode():
             warnings.warn(
-                "The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
-                " Please use it in static mode.")
+                "The ProgramTranslator.get_func doesn't work in dygraph "
+                "mode. We will just return the dygraph function. Use it in "
+                "static mode if you would like to translate to static graph.")
             return dygraph_func

         static_func = convert_function_with_cache(dygraph_func)
         return static_func
+    def get_program(self, dygraph_func, *args, **kwargs):
+        """
+        Returns the translated static program and input/output variables from
+        the dygraph function.
+        """
+        if in_dygraph_mode():
+            warnings.warn(
+                "The ProgramTranslator.get_program doesn't work in dygraph "
+                "mode. We will just return dygraph output. Use it in static "
+                "mode if you would like to translate to static graph.")
+            return dygraph_func(*args, **kwargs)
+
+        program_cache = self.get_program_cache()
+        outputs = program_cache.build_program_and_return_output(dygraph_func,
+                                                                *args, **kwargs)
+        return self.main_program, self.startup_program, program_cache.inputs, outputs
     def get_code(self, dygraph_func):
         """
         Returns the translated static function code from dygraph code
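For callers that prefer the method over the `@dygraph_to_static_program` decorator, `get_program` can be invoked directly. A hedged sketch, with an illustrative function; the four-tuple matches the return statement above:

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator


def mean_func(x):
    x = fluid.dygraph.to_variable(x)
    return fluid.layers.mean(x)


x = np.ones([2, 2]).astype('float32')
translator = ProgramTranslator()
main_program, startup_program, inputs, outputs = translator.get_program(
    mean_func, x)
```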
@@ -312,7 +336,7 @@ class ProgramTranslator(object):
         """
         feed_dict, fetch_list = self._prepare(args)

-        main_program = self._program_cache.program
+        main_program = self._program_cache.main_program
         outputs = self._exe.run(main_program,
                                 feed=feed_dict,
                                 fetch_list=fetch_list)
@@ -332,6 +356,25 @@ class ProgramTranslator(object):
                     format(type(loss_name)))
         self._loss_name = loss_name
+    def save_inference_model(self, dirname, feed=None, fetch=None):
+        """
+        Save current model as the inference model.
+        """
+        program_cache = self.get_program_cache()
+        if feed is None:
+            feeded_var_names = [i.name for i in program_cache.inputs]
+        else:
+            # `feed` holds indices into the cached input variables
+            feeded_var_names = [program_cache.inputs[i].name for i in feed]
+        if fetch is None:
+            target_vars = program_cache.outputs
+        else:
+            # `fetch` holds indices into the cached output variables
+            target_vars = [program_cache.outputs[i] for i in fetch]
+        from paddle.fluid.io import save_inference_model
+        save_inference_model(
+            dirname=dirname,
+            feeded_var_names=feeded_var_names,
+            target_vars=target_vars,
+            executor=self._exe,
+            main_program=self.main_program.clone())
     def _prepare(self, args):
         """
         Prepares with feed_dict, fetch_list, optimizer and initialize vars
@@ -347,7 +390,7 @@ class ProgramTranslator(object):
             self._add_optimizer()

         if self._need_startup:
-            self._exe.run(framework.default_startup_program())
+            self._exe.run(self.startup_program)
             self._need_startup = False

         return feed_dict, fetch_list
@@ -358,8 +401,9 @@ class ProgramTranslator(object):
         In some models and unittest, program will be switched frequently by `program_guard`.
         If does, the cached program and other properties are not available and should be reset.
         """
-        if self._program_cache.program:
-            if self._program_cache.program != framework.default_main_program():
+        if self._program_cache.main_program:
+            if self._program_cache.main_program != framework.default_main_program(
+            ):
                 ProgramTranslator.reset()
     def _update_batch_data(self, args):
@@ -379,7 +423,8 @@ class ProgramTranslator(object):
         """
         Supports to set or update the optimizer used to minimize loss.
         """
-        main_program = self._program_cache.program
+        main_program = self._program_cache.main_program
+        startup_program = self._program_cache.startup_program
         all_vars = main_program.block(0).vars
         loss_var = all_vars.get(self._loss_name, None)
@@ -388,7 +433,7 @@ class ProgramTranslator(object):
                 "Can't find {} in main_program, please confirm whether the loss input is correct"
                 .format(self._loss_name))
         # Adds optimizer to minimize loss
-        with framework.program_guard(main_program):
+        with framework.program_guard(main_program, startup_program):
             self._optimizer.minimize(loss_var)
         # Avoids to set optimizer repeatedly.
@@ -402,5 +447,9 @@ class ProgramTranslator(object):
         return self._program_cache

     @property
-    def program(self):
-        return self._program_cache.program
+    def main_program(self):
+        return self._program_cache.main_program
+
+    @property
+    def startup_program(self):
+        return self._program_cache.startup_program
@@ -16,19 +16,19 @@ from __future__ import print_function

 __all__ = [
     'TracedLayer', 'dygraph_to_static_code', 'dygraph_to_static_func',
-    'dygraph_to_static_output'
+    'dygraph_to_static_output', 'dygraph_to_static_program'
 ]

 import warnings

 from ..wrapped_decorator import wrap_decorator
 from .base import program_desc_tracing_guard, switch_to_static_graph
-from .dygraph_to_static import ProgramTranslator, convert_to_static
 from .layers import Layer
 from paddle.fluid import core
 from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
 from paddle.fluid.executor import Executor, scope_guard
 from paddle.fluid.compiler import CompiledProgram
+from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator


 def create_program_from_desc(program_desc):
@@ -55,27 +55,42 @@ def extract_vars(inputs):

 def _dygraph_to_static_code_(dygraph_func):
+    def __impl__(*args, **kwargs):
+        program_translator = ProgramTranslator()
+        return program_translator.get_code(dygraph_func)
+
+    return __impl__
+
+
+dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_)
+
+
+def _dygraph_to_static_program_(dygraph_func):
     def __impl__(*args, **kwargs):
         if in_dygraph_mode():
             warnings.warn(
-                "The decorator 'dygraph_to_static_code' doesn't work in dygraph mode."
-                " Please use it in static mode.")
+                "The decorator 'dygraph_to_static_program' doesn't work in "
+                "dygraph mode. We will just return dygraph output. Use the "
+                "decorator in static mode if you would like to translate to "
+                "static graph.")
             return dygraph_func(*args, **kwargs)
         program_translator = ProgramTranslator()
-        return program_translator.get_code(dygraph_func)
+        return program_translator.get_program(dygraph_func, *args, **kwargs)

     return __impl__


-dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_)
+dygraph_to_static_program = wrap_decorator(_dygraph_to_static_program_)
 def _dygraph_to_static_func_(dygraph_func):
     def __impl__(*args, **kwargs):
         if in_dygraph_mode():
             warnings.warn(
-                "The decorator 'dygraph_to_static_func' doesn't work in dygraph mode."
-                " Please use it in static mode.")
+                "The decorator 'dygraph_to_static_func' doesn't work in "
+                "dygraph mode. We will just return dygraph output. Use the "
+                "decorator in static mode if you would like to translate to "
+                "static graph.")
             return dygraph_func(*args, **kwargs)
         program_translator = ProgramTranslator()
         static_func = program_translator.get_func(dygraph_func)
@@ -89,6 +104,13 @@ dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)

 def _dygraph_to_static_output_(dygraph_func):
     def __impl__(*args, **kwargs):
+        if in_dygraph_mode():
+            warnings.warn(
+                "The decorator 'dygraph_to_static_output' doesn't work in "
+                "dygraph mode. We will just return dygraph output. Use the "
+                "decorator in static mode if you would like to translate to "
+                "static graph.")
+            return dygraph_func(*args, **kwargs)
         program_translator = ProgramTranslator()
         return program_translator.get_output(dygraph_func, *args, **kwargs)
...
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_output
np.random.seed(2020)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
class SimpleFcLayer(fluid.dygraph.Layer):
    def __init__(self, fc_size):
        super(SimpleFcLayer, self).__init__()
        self._linear = fluid.dygraph.Linear(fc_size, fc_size)

    @dygraph_to_static_output
    def forward(self, x):
        x = fluid.dygraph.to_variable(x)
        y = self._linear(x)
        z = self._linear(y)
        out = fluid.layers.mean(z, name='mean')
        return out


class TestDyToStaticSaveInferenceModel(unittest.TestCase):
    def test_save_inference_model(self):
        fc_size = 20
        x = np.random.random((fc_size, fc_size)).astype('float32')
        layer = SimpleFcLayer(fc_size)

        program_translator = ProgramTranslator.get_instance()
        program_cache = ProgramTranslator().get_program_cache()
        sgd = fluid.optimizer.SGD(learning_rate=0.001)
        program_translator.set_optimizer(sgd, 'mean')

        for i in range(5):
            out = layer(x)

        main_program = ProgramTranslator.get_instance().main_program
        expected_persistable_vars = set(
            [layer._linear.weight.name, layer._linear.bias.name])

        infer_model_dir = "./test_dy2stat_save_inference_model"
        ProgramTranslator.get_instance().save_inference_model(infer_model_dir)
        saved_var_names = set([
            filename for filename in os.listdir(infer_model_dir)
            if filename != '__model__'
        ])
        self.assertEqual(saved_var_names, expected_persistable_vars)


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.dygraph.jit import dygraph_to_static_program
from paddle.fluid.dygraph.nn import Linear
np.random.seed(2020)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace()
def simple_func(x, weight_numpy):
    weight_initializer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
    linear = Linear(32, 64, param_attr=weight_initializer)
    x = fluid.dygraph.to_variable(x)
    y = linear(x)
    z = linear(x)
    return z


@dygraph_to_static_program
def decorated_simple_func(x, weight_numpy):
    weight_initializer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
    linear = Linear(32, 64, param_attr=weight_initializer)
    x = fluid.dygraph.to_variable(x)
    y = linear(x)
    z = linear(x)
    return z


class TestDyToStaticSaveLoad(unittest.TestCase):
    def test_save_load_same_result(self):
        x = np.random.randn(30, 10, 32).astype('float32')
        weight = np.random.randn(32, 64).astype('float32')
        with fluid.dygraph.guard(place):
            dygraph_result = simple_func(x, weight)

        main_program, startup_program, inputs, outputs = decorated_simple_func(
            x, weight)
        exe = fluid.Executor(place)
        exe.run(startup_program)
        fluid.save(main_program, "./test_dy2stat_save_load")

        # set vars to zero so that we can test load in same file
        for var in main_program.list_vars():
            if isinstance(var, framework.Parameter) or var.persistable:
                tensor = fluid.global_scope().find_var(var.name).get_tensor()
                tensor.set(np.zeros_like(np.array(tensor)), place)

                # make sure all the parameter or optimizer vars have been set to zero
                tensor_np = np.array(fluid.global_scope().find_var(var.name)
                                     .get_tensor())
                self.assertEqual(0, np.sum(np.abs(tensor_np)))

        fluid.load(main_program, "./test_dy2stat_save_load")
        static_result = exe.run(main_program,
                                feed={inputs[0].name: x},
                                fetch_list=outputs)
        self.assertTrue(np.allclose(dygraph_result.numpy(), static_result))


if __name__ == '__main__':
    unittest.main()