From 378fc4fb1c88de358442825c9625b9d4bb1d2a52 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Thu, 24 Oct 2019 21:10:29 +0800 Subject: [PATCH] add some docs to jit.trace, test=develop (#20811) --- .../imperative/jit/program_desc_tracer.h | 2 - python/paddle/fluid/dygraph/jit.py | 63 +++++++++++++++++-- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.h b/paddle/fluid/imperative/jit/program_desc_tracer.h index 599dbe49cea..08e5957bdb8 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.h +++ b/paddle/fluid/imperative/jit/program_desc_tracer.h @@ -14,10 +14,8 @@ #pragma once -#include #include #include -#include #include #include #include diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 9f4d30b078a..02de2729612 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -14,7 +14,6 @@ __all__ = ['trace'] -from . import layers from .base import program_desc_tracing_guard from .layers import Layer from paddle.fluid.framework import Program, Block, Variable, _dygraph_tracer, dygraph_only, _dygraph_guard @@ -44,8 +43,64 @@ def extract_vars(inputs): @dygraph_only -def trace(module, inputs, feed_names=None, fetch_names=None): - assert isinstance(module, Layer) +def trace(layer, inputs, feed_names=None, fetch_names=None): + """ + Trace dygraph network into a :code:`Program`. The returned :code:`Program` + can be run in static graph mode. This method would simply record all + operators in the network with :code:`inputs` . Users should guarantee that + the traced dygraph network is independent with input data, input shapes, + and would not be changed between different batches. Otherwise, the traced + result may be different. + + Parameters: + layer(Layer): the layer to be traced. + inputs(list): the input arguments of :code:`layer.forward()` method. + feed_names(list(str), optional): the input variable names in the + traced :code:`Program` corresponding to :code:`inputs` . If it + is None, the variable name of :code:`inputs` would be used. + It is suggested that users should set :code:`feed_names` + manually. Otherwise, the input variable names would be + different between different batches. Default None. + fetch_names(list(str), optional): the output variable names in the + traced :code:`Program` corresponding to the output variables + of :code:`layer.forward()` method. If it is None, the variable + name of the outputs of :code:`layer.forward()` would be used. + It is suggested that users should set :code:`fetch_names` + manually. Otherwise, the output variable names would be + different between different batches. Default None. + + Returns: + A tuple of 2 items, whose first item is the outputs of + :code:`layer.forward()` method, and second item is the traced + :code:`Program` . + + Examples: + + .. code-blocks: python: + + import paddle.fluid as fluid + from paddle.fluid.dygraph import FC, to_variable + import paddle.fluid.dygraph.jit as jit + import numpy as np + + class ExampleLayer(fluid.dygraph.Layer): + def __init__(self, name_scope): + super(ExampleLayer, self).__init__(name_scope) + self._fc = FC(self.full_name(), 10) + + def forward(self, input): + return self._fc(input) + + with fluid.dygraph.guard(): + layer = ExampleLayer("example_layer") + in_np = np.random.random([2, 3]).astype('float32') + in_var = to_variable(in_np) + out, program = jit.trace(layer, inputs=[in_var], + feed_names=['input'], + fetch_names=['fc_out']) + + """ + assert isinstance(layer, Layer) if not isinstance(inputs, (list, tuple)): inputs = [inputs] @@ -62,7 +117,7 @@ def trace(module, inputs, feed_names=None, fetch_names=None): tracer.set_feed_vars(var_list, feed_names) with program_desc_tracing_guard(True): - original_outputs = module.__call__(*inputs) + original_outputs = layer(*inputs) if not isinstance(original_outputs, (list, tuple)): outputs = [original_outputs] else: -- GitLab