未验证 提交 e5af90aa 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Decorator 'dygraph_to_static_program' and ProgramTranslator.save_inference_model (#23227)

1. Add Decorator 'dygraph_to_static_program'
2. Add corresponding ProgramTranslator.get_program
3. Add ProgramTranslator.save_inference_model
4. Modified some warning information of dy2stat
5. Change program cache to contain startup_program because for users who gets program to run, they may like to initialize startup program
上级 a647bcd3
......@@ -42,7 +42,10 @@ from paddle.fluid.dygraph.dygraph_to_static.utils import dygraph_class_to_static
__all__ = ['DygraphToStaticAst', 'convert_to_static']
DECORATOR_NAMES = ['dygraph_to_static_output', 'dygraph_to_static_func']
DECORATOR_NAMES = [
'dygraph_to_static_code', 'dygraph_to_static_program',
'dygraph_to_static_func', 'dygraph_to_static_output'
]
class DygraphToStaticAst(gast.NodeTransformer):
......
......@@ -100,7 +100,8 @@ class ProgramCache(object):
# Always set program to default_main_program. Because once `__call__` is called,
# it means layers(or Ops) are added into default_main_program switched by outer
# `with` statement.
self._program = framework.default_main_program()
self._main_program = framework.default_main_program()
self._startup_program = framework.default_startup_program()
self._func_cache = FunctionCache()
# Stores the entry function of Net or Model.
self._forward_func = None
......@@ -142,7 +143,7 @@ class ProgramCache(object):
static_func = self._func_cache.get_or_cache_func(dyfunc)
# self._forward_func is entry function of Net or Model.
# It can be called for multiple times, but layers from these functions
# call stack will be added into self._program only once.
# call stack will be added into self._main_program only once.
# After that, cached program will be always returned by default.
if static_func == self._forward_func:
self._is_repeated = True
......@@ -157,7 +158,7 @@ class ProgramCache(object):
Returns program of the input function. If called at first time,
builds a new program and caches it.
"""
with framework.program_guard(self._program):
with framework.program_guard(self._main_program, self._startup_program):
if func == self._forward_func:
# Replaces input data with `layers.data`
args = list(args)
......@@ -178,7 +179,7 @@ class ProgramCache(object):
"""
if not self._feed_name_to_idx:
self._feed_name_to_idx = self._get_name_to_idx(self._forward_func)
with framework.program_guard(self._program):
with framework.program_guard(self._main_program, self._startup_program):
for feed_name, idx in self.feed_name_to_idx.items():
batch_data = args[idx]
assert isinstance(
......@@ -201,8 +202,12 @@ class ProgramCache(object):
return feed_name_to_idx
@property
def program(self):
return self._program
def main_program(self):
return self._main_program
@property
def startup_program(self):
return self._startup_program
@property
def inputs(self):
......@@ -222,7 +227,7 @@ class ProgramCache(object):
class ProgramTranslator(object):
_singleton_lock = threading.Lock()
_instance = None
@synchronized
......@@ -235,7 +240,8 @@ class ProgramTranslator(object):
@classmethod
def get_instance(cls):
if cls._instance is None:
raise ValueError("ProgramTranslator hasn\'t been created!")
with cls._singleton_lock:
cls._instance = cls()
return cls._instance
@classmethod
......@@ -266,8 +272,9 @@ class ProgramTranslator(object):
"""
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_output' doesn't work in dygraph mode."
" Please use it in static mode.")
"The ProgramTranslator.get_output doesn't work in dygraph "
"mode. We will just return dygraph output. Use the it in "
"static mode if you would like to translate to static graph.")
return dygraph_func(*args, **kwargs)
program_cache = self.get_program_cache()
......@@ -283,12 +290,29 @@ class ProgramTranslator(object):
"""
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_graph' doesn't work in dygraph mode."
" Please use it in static mode.")
"The ProgramTranslator.get_func doesn't work in dygraph "
"mode. We will just return dygraph function. Use the it in "
"static mode if you would like to translate to static graph.")
return dygraph_func
static_func = convert_function_with_cache(dygraph_func)
return static_func
def get_program(self, dygraph_func, *args, **kwargs):
"""
Returns the translated static program and input/output variables from
dygraph function.
"""
if in_dygraph_mode():
warnings.warn(
"The ProgramTranslator.get_program doesn't work in dygraph "
"mode. We will just return dygraph output. Use it in static "
"mode if you would like to translate to static graph.")
return dygraph_func(*args, **kwargs)
program_cache = self.get_program_cache()
outputs = program_cache.build_program_and_return_output(dygraph_func,
*args, **kwargs)
return self.main_program, self.startup_program, program_cache.inputs, outputs
def get_code(self, dygraph_func):
"""
Returns the translated static function code from dygraph code
......@@ -312,7 +336,7 @@ class ProgramTranslator(object):
"""
feed_dict, fetch_list = self._prepare(args)
main_program = self._program_cache.program
main_program = self._program_cache.main_program
outputs = self._exe.run(main_program,
feed=feed_dict,
fetch_list=fetch_list)
......@@ -332,6 +356,25 @@ class ProgramTranslator(object):
format(type(loss_name)))
self._loss_name = loss_name
def save_inference_model(self, dirname, feed=None, fetch=None):
"""
Save current model as the inference model.
"""
program_cache = self.get_program_cache()
if feed is None:
feeded_var_names = [i.name for i in program_cache.inputs]
else:
feeded_var_names = [program_cache.inputs[i].name for i in feed]
target_vars = program_cache.outputs
from paddle.fluid.io import save_inference_model
save_inference_model(
dirname=dirname,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
executor=self._exe,
main_program=self.main_program.clone())
def _prepare(self, args):
"""
Prepares with feed_dict, fetch_list, optimizer and initialize vars
......@@ -347,7 +390,7 @@ class ProgramTranslator(object):
self._add_optimizer()
if self._need_startup:
self._exe.run(framework.default_startup_program())
self._exe.run(self.startup_program)
self._need_startup = False
return feed_dict, fetch_list
......@@ -358,8 +401,9 @@ class ProgramTranslator(object):
In some models and unittest, program will be switched frequently by `program_guard`.
If does, the cached program and other properties are not available and should be reset.
"""
if self._program_cache.program:
if self._program_cache.program != framework.default_main_program():
if self._program_cache.main_program:
if self._program_cache.main_program != framework.default_main_program(
):
ProgramTranslator.reset()
def _update_batch_data(self, args):
......@@ -379,7 +423,8 @@ class ProgramTranslator(object):
"""
Supports to set or update the optimizer used to minimize loss.
"""
main_program = self._program_cache.program
main_program = self._program_cache.main_program
startup_program = self._program_cache.startup_program
all_vars = main_program.block(0).vars
loss_var = all_vars.get(self._loss_name, None)
......@@ -388,7 +433,7 @@ class ProgramTranslator(object):
"Can't find {} in main_program, please confirm whether the loss input is correct"
.format(self._loss_name))
# Adds optimizer to minimize loss
with framework.program_guard(main_program):
with framework.program_guard(main_program, startup_program):
self._optimizer.minimize(loss_var)
# Avoids to set optimizer repeatedly.
......@@ -402,5 +447,9 @@ class ProgramTranslator(object):
return self._program_cache
@property
def program(self):
return self._program_cache.program
def main_program(self):
return self._program_cache.main_program
@property
def startup_program(self):
return self._program_cache.startup_program
......@@ -16,19 +16,19 @@ from __future__ import print_function
__all__ = [
'TracedLayer', 'dygraph_to_static_code', 'dygraph_to_static_func',
'dygraph_to_static_output'
'dygraph_to_static_output', 'dygraph_to_static_program'
]
import warnings
from ..wrapped_decorator import wrap_decorator
from .base import program_desc_tracing_guard, switch_to_static_graph
from .dygraph_to_static import ProgramTranslator, convert_to_static
from .layers import Layer
from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram
from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator
def create_program_from_desc(program_desc):
......@@ -55,27 +55,42 @@ def extract_vars(inputs):
def _dygraph_to_static_code_(dygraph_func):
def __impl__(*args, **kwargs):
program_translator = ProgramTranslator()
return program_translator.get_code(dygraph_func)
return __impl__
dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_)
def _dygraph_to_static_program_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_code' doesn't work in dygraph mode."
" Please use it in static mode.")
"The decorator 'dygraph_to_static_program' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_code(dygraph_func)
return program_translator.get_program(dygraph_func, *args, **kwargs)
return __impl__
dygraph_to_static_code = wrap_decorator(_dygraph_to_static_code_)
dygraph_to_static_program = wrap_decorator(_dygraph_to_static_program_)
def _dygraph_to_static_func_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_func' doesn't work in dygraph mode."
" Please use it in static mode.")
"The decorator 'dygraph_to_static_func' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
static_func = program_translator.get_func(dygraph_func)
......@@ -89,6 +104,13 @@ dygraph_to_static_func = wrap_decorator(_dygraph_to_static_func_)
def _dygraph_to_static_output_(dygraph_func):
def __impl__(*args, **kwargs):
if in_dygraph_mode():
warnings.warn(
"The decorator 'dygraph_to_static_output' doesn't work in "
"dygraph mode. We will just return dygraph output. Use the "
"decorator in static mode if you would like to translate to "
"static graph.")
return dygraph_func(*args, **kwargs)
program_translator = ProgramTranslator()
return program_translator.get_output(dygraph_func, *args, **kwargs)
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import unittest
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
from paddle.fluid.dygraph.jit import dygraph_to_static_output
np.random.seed(2020)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
class SimpleFcLayer(fluid.dygraph.Layer):
def __init__(self, fc_size):
super(SimpleFcLayer, self).__init__()
self._linear = fluid.dygraph.Linear(fc_size, fc_size)
@dygraph_to_static_output
def forward(self, x):
x = fluid.dygraph.to_variable(x)
y = self._linear(x)
z = self._linear(y)
out = fluid.layers.mean(z, name='mean')
return out
class TestDyToStaticSaveInferenceModel(unittest.TestCase):
def test_save_inference_model(self):
fc_size = 20
x = np.random.random((fc_size, fc_size)).astype('float32')
layer = SimpleFcLayer(fc_size)
program_translator = ProgramTranslator.get_instance()
program_cache = ProgramTranslator().get_program_cache
adam = fluid.optimizer.SGD(learning_rate=0.001)
program_translator.set_optimizer(adam, 'mean')
for i in range(5):
out = layer(x)
main_program = ProgramTranslator.get_instance().main_program
expected_persistable_vars = set(
[layer._linear.weight.name, layer._linear.bias.name])
infer_model_dir = "./test_dy2stat_save_inference_model"
ProgramTranslator.get_instance().save_inference_model(infer_model_dir)
saved_var_names = set([
filename for filename in os.listdir(infer_model_dir)
if filename != '__model__'
])
self.assertEqual(saved_var_names, expected_persistable_vars)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.dygraph.jit import dygraph_to_static_program
from paddle.fluid.dygraph.nn import Linear
np.random.seed(2020)
place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
)
def simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
@dygraph_to_static_program
def decorated_simple_func(x, weight_numpy):
weight_initalizer = fluid.initializer.NumpyArrayInitializer(weight_numpy)
linear = Linear(32, 64, param_attr=weight_initalizer)
x = fluid.dygraph.to_variable(x)
y = linear(x)
z = linear(x)
return z
class TestDyToStaticSaveLoad(unittest.TestCase):
def test_save_load_same_result(self):
x = np.random.randn(30, 10, 32).astype('float32')
weight = np.random.randn(32, 64).astype('float32')
with fluid.dygraph.guard(place):
dygraph_result = simple_func(x, weight)
main_program, startup_program, inputs, outputs = decorated_simple_func(
x, weight)
exe = fluid.Executor(place)
exe.run(startup_program)
fluid.save(main_program, "./test_dy2stat_save_load")
# set vars to zero so that we can test load in same file
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
tensor = fluid.global_scope().find_var(var.name).get_tensor()
tensor.set(np.zeros_like(np.array(tensor)), place)
# make sure all the paramerter or optimizer var have been set to zero
tensor_np = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
self.assertEqual(0, np.sum(np.abs(tensor_np)))
fluid.load(main_program, "./test_dy2stat_save_load")
static_result = exe.run(main_program,
feed={inputs[0].name: x},
fetch_list=outputs)
self.assertTrue(np.allclose(dygraph_result.numpy(), static_result))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册