未验证 提交 86c40e20 编写于 作者: Z Zeng Jinle 提交者: GitHub

Expose fluid.dygraph.TracedLayer API (#21518)

* expost fluid.dygraph.TracedLayer apis, test=develop

* polish doc, test=develop

* follow comments, test=develop, test=document_fix

* follow comments, test=develop

* remove save_inference_model return value, test=develop
上级 911eef43
...@@ -41,6 +41,9 @@ from .learning_rate_scheduler import * ...@@ -41,6 +41,9 @@ from .learning_rate_scheduler import *
from . import backward_strategy from . import backward_strategy
from .backward_strategy import * from .backward_strategy import *
from . import jit
from .jit import *
__all__ = [] __all__ = []
__all__ += layers.__all__ __all__ += layers.__all__
__all__ += base.__all__ __all__ += base.__all__
...@@ -51,3 +54,4 @@ __all__ += parallel.__all__ ...@@ -51,3 +54,4 @@ __all__ += parallel.__all__
__all__ += checkpoint.__all__ __all__ += checkpoint.__all__
__all__ += learning_rate_scheduler.__all__ __all__ += learning_rate_scheduler.__all__
__all__ += backward_strategy.__all__ __all__ += backward_strategy.__all__
__all__ += jit.__all__
...@@ -20,7 +20,6 @@ from paddle.fluid import core ...@@ -20,7 +20,6 @@ from paddle.fluid import core
from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard, _current_expected_place, in_dygraph_mode
from paddle.fluid.executor import Executor, scope_guard from paddle.fluid.executor import Executor, scope_guard
from paddle.fluid.compiler import CompiledProgram from paddle.fluid.compiler import CompiledProgram
import paddle.fluid.io as fluid_io
def create_program_from_desc(program_desc): def create_program_from_desc(program_desc):
...@@ -81,11 +80,15 @@ def _trace(layer, ...@@ -81,11 +80,15 @@ def _trace(layer,
class TracedLayer(object): class TracedLayer(object):
""" """
TracedLayer is a callable object which is converted from dygraph model. TracedLayer is used to convert a forward dygraph model to a static
Inside TracedLayer, the dygraph model is converted into a static graph graph model. This is mainly used to save the dygraph model for online
model, and it would run the static graph model using inference using C++. Besides, users can also do inference in Python
:code:`Executor` and :code:`CompiledProgram` . The static graph model using the converted static graph model, which usually has better
would share parameters with the dygraph model. performance than the original dygraph model.
TracedLayer would run the static graph model using :code:`Executor`
and :code:`CompiledProgram` . The static graph model would share
parameters with the dygraph model.
All TracedLayer objects should not be created by constructor and should All TracedLayer objects should not be created by constructor and should
be created by static method :code:`TracedLayer.trace(layer, inputs)` . be created by static method :code:`TracedLayer.trace(layer, inputs)` .
...@@ -133,37 +136,43 @@ class TracedLayer(object): ...@@ -133,37 +136,43 @@ class TracedLayer(object):
model and convert it into a static graph model. model and convert it into a static graph model.
Args: Args:
layer (paddle.fluid.dygraph.Layer): the layer object to be traced. layer (dygraph.Layer): the layer object to be traced.
inputs (list(Variable)): the input variables of the layer object. inputs (list(Variable)): the input variables of the layer object.
Returns: Returns:
A tuple of 2 items, whose the first item is the output of tuple: A tuple of 2 items, whose the first item is the output of
:code:`layer(*inputs)` , and the second item is the created :code:`layer(*inputs)` , and the second item is the created
TracedLayer object. TracedLayer object.
Examples:
Examples:
.. code-block:: python: .. code-block:: python:
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import FC, to_variable, TracedLayer from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
import paddle.fluid.dygraph.jit as jit
import numpy as np import numpy as np
class ExampleLayer(fluid.dygraph.Layer): class ExampleLayer(fluid.dygraph.Layer):
def __init__(self, name_scope): def __init__(self):
super(ExampleLayer, self).__init__(name_scope) super(ExampleLayer, self).__init__()
self._fc = FC(self.full_name(), 10) self._fc = Linear(3, 10)
def forward(self, input): def forward(self, input):
return self._fc(input) return self._fc(input)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
layer = ExampleLayer("example_layer") layer = ExampleLayer()
in_np = np.random.random([2, 3]).astype('float32') in_np = np.random.random([2, 3]).astype('float32')
in_var = to_variable(in_np) in_var = to_variable(in_np)
out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var]) out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
out_static_graph = static_layer([in_var])
# run the static graph model using Executor inside
out_static_graph = static_layer([in_var])
print(len(out_static_graph)) # 1
print(out_static_graph[0].shape) # (2, 10)
# save the static graph model for inference
static_layer.save_inference_model(dirname='./saved_infer_model')
""" """
outs, prog, feed, fetch = _trace(layer, inputs) outs, prog, feed, fetch = _trace(layer, inputs)
traced = TracedLayer(prog, layer.parameters(), feed, fetch) traced = TracedLayer(prog, layer.parameters(), feed, fetch)
...@@ -174,7 +183,7 @@ class TracedLayer(object): ...@@ -174,7 +183,7 @@ class TracedLayer(object):
Set the strategies when running static graph model. Set the strategies when running static graph model.
Args: Args:
build_strategy (BuildStrategy, optional): build strategy of build_strategy (BuildStrategy, optional): build strategy of
:code:`CompiledProgram` inside TracedLayer. Default None. :code:`CompiledProgram` inside TracedLayer. Default None.
exec_strategy (ExecutionStrategy, optional): execution strategy of exec_strategy (ExecutionStrategy, optional): execution strategy of
:code:`CompiledProgram` inside TracedLayer. Default None. :code:`CompiledProgram` inside TracedLayer. Default None.
...@@ -183,24 +192,22 @@ class TracedLayer(object): ...@@ -183,24 +192,22 @@ class TracedLayer(object):
None None
Examples: Examples:
.. code-block:: python: .. code-block:: python:
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import FC, to_variable, TracedLayer from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
import paddle.fluid.dygraph.jit as jit
import numpy as np import numpy as np
class ExampleLayer(fluid.dygraph.Layer): class ExampleLayer(fluid.dygraph.Layer):
def __init__(self, name_scope): def __init__(self):
super(ExampleLayer, self).__init__(name_scope) super(ExampleLayer, self).__init__()
self._fc = FC(self.full_name(), 10) self._fc = Linear(3, 10)
def forward(self, input): def forward(self, input):
return self._fc(input) return self._fc(input)
with fluid.dygraph.guard(): with fluid.dygraph.guard():
layer = ExampleLayer("example_layer") layer = ExampleLayer()
in_np = np.random.random([2, 3]).astype('float32') in_np = np.random.random([2, 3]).astype('float32')
in_var = to_variable(in_np) in_var = to_variable(in_np)
...@@ -257,13 +264,13 @@ class TracedLayer(object): ...@@ -257,13 +264,13 @@ class TracedLayer(object):
@switch_to_static_graph @switch_to_static_graph
def save_inference_model(self, dirname, feed=None, fetch=None): def save_inference_model(self, dirname, feed=None, fetch=None):
""" """
Save the TracedLayer to an model for inference. The saved Save the TracedLayer to a model for inference. The saved
inference model can be loaded by C++ inference APIs. inference model can be loaded by C++ inference APIs.
Args: Args:
dirname (str): the directory to save the inference model. dirname (str): the directory to save the inference model.
feed (list[int], optional): the input variable indices of the saved feed (list[int], optional): the input variable indices of the saved
inference model. If None, all input variables of the inference model. If None, all input variables of the
TracedLayer object would be the inputs of the saved inference TracedLayer object would be the inputs of the saved inference
model. Default None. model. Default None.
fetch (list[int], optional): the output variable indices of the fetch (list[int], optional): the output variable indices of the
...@@ -272,35 +279,41 @@ class TracedLayer(object): ...@@ -272,35 +279,41 @@ class TracedLayer(object):
model. Default None. model. Default None.
Returns: Returns:
The fetch variables' name list None
Return Type:
list(str)
Examples: Examples:
.. code-block:: python: .. code-block:: python:
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.dygraph import FC, to_variable, TracedLayer from paddle.fluid.dygraph import Linear, to_variable, TracedLayer
import paddle.fluid.dygraph.jit as jit
import numpy as np import numpy as np
class ExampleLayer(fluid.dygraph.Layer): class ExampleLayer(fluid.dygraph.Layer):
def __init__(self, name_scope): def __init__(self):
super(ExampleLayer, self).__init__(name_scope) super(ExampleLayer, self).__init__()
self._fc = FC(self.full_name(), 10) self._fc = Linear(3, 10)
def forward(self, input): def forward(self, input):
return self._fc(input) return self._fc(input)
save_dirname = './saved_infer_model'
in_np = np.random.random([2, 3]).astype('float32')
with fluid.dygraph.guard(): with fluid.dygraph.guard():
layer = ExampleLayer("example_layer") layer = ExampleLayer()
in_np = np.random.random([2, 3]).astype('float32')
in_var = to_variable(in_np) in_var = to_variable(in_np)
out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var]) out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
static_layer.save_inference_model('./saved_infer_model') static_layer.save_inference_model(save_dirname, feed=[0], fetch=[0])
place = fluid.CPUPlace()
exe = fluid.Executor(place)
program, feed_vars, fetch_vars = fluid.io.load_inference_model(save_dirname,
exe)
fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
print(fetch.shape) # (2, 10)
""" """
from paddle.fluid.io import save_inference_model
def get_feed_fetch(all_vars, partial_vars): def get_feed_fetch(all_vars, partial_vars):
if partial_vars is None: if partial_vars is None:
...@@ -317,7 +330,7 @@ class TracedLayer(object): ...@@ -317,7 +330,7 @@ class TracedLayer(object):
assert target_var is not None, "{} cannot be found".format(name) assert target_var is not None, "{} cannot be found".format(name)
target_vars.append(target_var) target_vars.append(target_var)
return fluid_io.save_inference_model( save_inference_model(
dirname=dirname, dirname=dirname,
feeded_var_names=feeded_var_names, feeded_var_names=feeded_var_names,
target_vars=target_vars, target_vars=target_vars,
......
...@@ -27,7 +27,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear ...@@ -27,7 +27,7 @@ from paddle.fluid.dygraph.nn import Conv2D, Pool2D, Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.dygraph.jit import TracedLayer from paddle.fluid.dygraph import TracedLayer
class SimpleImgConvPool(fluid.dygraph.Layer): class SimpleImgConvPool(fluid.dygraph.Layer):
......
...@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.nn import Embedding ...@@ -21,7 +21,7 @@ from paddle.fluid.dygraph.nn import Embedding
import paddle.fluid.framework as framework import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from paddle.fluid.dygraph.jit import TracedLayer from paddle.fluid.dygraph import TracedLayer
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
import numpy as np import numpy as np
import six import six
......
...@@ -25,7 +25,7 @@ from paddle.fluid import Conv2D, Pool2D, BatchNorm, Linear ...@@ -25,7 +25,7 @@ from paddle.fluid import Conv2D, Pool2D, BatchNorm, Linear
from paddle.fluid.dygraph.base import to_variable from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.dygraph.jit import TracedLayer from paddle.fluid.dygraph import TracedLayer
batch_size = 8 batch_size = 8
train_parameters = { train_parameters = {
......
...@@ -24,8 +24,6 @@ from paddle.fluid.dygraph.base import to_variable ...@@ -24,8 +24,6 @@ from paddle.fluid.dygraph.base import to_variable
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
import numpy as np import numpy as np
import six import six
from utils import DyGraphProgramDescTracerTestHelper, is_equal_program
from paddle.fluid.dygraph.jit import TracedLayer
class SimpleNet(fluid.Layer): class SimpleNet(fluid.Layer):
...@@ -121,8 +119,6 @@ class TestDygraphSimpleNet(unittest.TestCase): ...@@ -121,8 +119,6 @@ class TestDygraphSimpleNet(unittest.TestCase):
dy_param_init = dict() dy_param_init = dict()
dy_loss = None dy_loss = None
helper = DyGraphProgramDescTracerTestHelper(self)
program = None
backward_strategy = fluid.dygraph.BackwardStrategy() backward_strategy = fluid.dygraph.BackwardStrategy()
backward_strategy.sort_sum_gradient = is_sort_sum_gradient backward_strategy.sort_sum_gradient = is_sort_sum_gradient
......
...@@ -18,7 +18,7 @@ import unittest ...@@ -18,7 +18,7 @@ import unittest
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid import Embedding, LayerNorm, Linear, Layer from paddle.fluid import Embedding, LayerNorm, Linear, Layer
from paddle.fluid.dygraph import to_variable, guard from paddle.fluid.dygraph import to_variable, guard
from paddle.fluid.dygraph.jit import TracedLayer from paddle.fluid.dygraph import TracedLayer
from test_imperative_base import new_program_scope from test_imperative_base import new_program_scope
from paddle.fluid import core from paddle.fluid import core
import numpy as np import numpy as np
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册