From 67e88424e54eac3e9be2771af5a2d76b637d8f25 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Tue, 19 Nov 2019 10:07:42 +0800 Subject: [PATCH] Polish jit trace codes (#21218) * polish jit trace codes, test=develop * polish codes again by removing var_id, test=develop --- .../imperative/jit/program_desc_tracer.cc | 190 ++++++++---------- .../imperative/jit/program_desc_tracer.h | 42 ++-- paddle/fluid/pybind/imperative.cc | 4 - python/paddle/fluid/dygraph/jit.py | 86 +------- 4 files changed, 108 insertions(+), 214 deletions(-) diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.cc b/paddle/fluid/imperative/jit/program_desc_tracer.cc index 8c08ab47e1..2e92facb06 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.cc +++ b/paddle/fluid/imperative/jit/program_desc_tracer.cc @@ -21,49 +21,81 @@ namespace paddle { namespace imperative { namespace jit { -void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) { - name_prefix_ = name_prefix; -} +// A helper class to generate unique name for each non-persistable var +class UniqueBlockVarGenerator { + public: + UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, + framework::BlockDesc *block); -void ProgramDescTracer::SetFeedVars( - const std::vector> &feed_vars, - std::vector feed_names) { - feed_vars_.clear(); + std::string NameOf(const std::weak_ptr &var, + const std::string &prefix); - if (feed_names.empty()) { - feed_names.reserve(feed_vars.size()); - for (auto &var : feed_vars) { - feed_names.emplace_back(var->Name()); - } - } + private: + void InsertNewVarInBlock(const std::weak_ptr &var, + const framework::VarDesc &ref_desc, + const std::string &name); - PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(), - "The feeded variable names number must be equal to the " - "feeded variable number"); + private: + const VarDescMetaMap &all_vars_; + framework::BlockDesc *block_; + std::unordered_map counter_; - for (size_t i = 0; i < feed_names.size(); ++i) { - feed_vars_[feed_vars[i]] = feed_names[i]; + std::map, std::string, + std::owner_less>> + var_to_name_; + std::unordered_set existing_names_; +}; + +UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars, + framework::BlockDesc *block) + : all_vars_(all_vars), block_(block) { + for (auto &var_pair : all_vars_) { + auto *var_desc = var_pair.second.get(); + if (var_desc->Persistable()) { + InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name()); + } } } -void ProgramDescTracer::SetFetchVars( - const std::vector> &fetch_vars, - std::vector fetch_names) { - fetch_vars_.clear(); - - if (fetch_names.empty()) { - fetch_names.reserve(fetch_vars.size()); - for (auto &var : fetch_vars) { - fetch_names.emplace_back(var->Name()); - } +std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr &var, + const std::string &prefix) { + auto all_vars_iter = all_vars_.find(var); + PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true, + platform::errors::NotFound( + "Variable is not found in UniqueBlockVarGenerator")); + + auto iter = var_to_name_.find(var); + if (iter != var_to_name_.end()) { + VLOG(5) << "Return existing var name " << iter->second; + return iter->second; + } else { + auto generate_unique_name = [this, &prefix] { + auto &cnt = counter_[prefix]; + do { + auto name = prefix + std::to_string(cnt++); + if (existing_names_.count(name) == 0) { + return name; + } + } while (cnt > 0); + PADDLE_THROW( + platform::errors::OutOfRange("Too many vars in the program")); + }; + + auto unique_name = generate_unique_name(); + VLOG(5) << "Generate new var name " << unique_name; + InsertNewVarInBlock(var, *(all_vars_iter->second), unique_name); + return unique_name; } +} - PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(), - "The fetched variable names number must be equal to the " - "fetched variable number"); - for (size_t i = 0; i < fetch_names.size(); ++i) { - fetch_vars_[fetch_vars[i]] = fetch_names[i]; - } +void UniqueBlockVarGenerator::InsertNewVarInBlock( + const std::weak_ptr &var, const framework::VarDesc &var_desc, + const std::string &name) { + var_to_name_[var] = name; + existing_names_.insert(name); + auto *new_var_desc = block_->Var(name); + *new_var_desc = var_desc; + new_var_desc->SetName(name); } void ProgramDescTracer::InsertOp(const std::string &type, @@ -85,70 +117,24 @@ void ProgramDescTracer::InsertOp(const std::string &type, } } -std::unique_ptr ProgramDescTracer::CreateProgramDesc() - const { +TracedProgramTuple ProgramDescTracer::CreateProgramDesc( + const std::vector> &feed_vars, + const std::string &feed_prefix, + const std::vector> &fetch_vars, + const std::string &fetch_prefix, const std::string &tmp_prefix) const { std::unique_ptr prog(new framework::ProgramDesc()); auto *block = prog->MutableBlock(0); - size_t var_num = vars_.size(); - std::vector var_descs(var_num, nullptr); - std::unordered_map> - var_desc_to_var_base; - - for (auto &pair : vars_) { - size_t var_id = pair.second.first; - PADDLE_ENFORCE_LT(var_id, var_num); - var_descs[var_id] = pair.second.second.get(); - PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]); - var_desc_to_var_base[var_descs[var_id]] = pair.first; - } - - std::unordered_set existing_var_names; - for (auto *var_desc : var_descs) { - if (var_desc->Persistable()) { - existing_var_names.insert(var_desc->Name()); - } - } - - for (auto &pair : feed_vars_) { - existing_var_names.insert(pair.second); - } + UniqueBlockVarGenerator generator(vars_, block); - for (auto &pair : fetch_vars_) { - existing_var_names.insert(pair.second); + std::vector feed_var_names; + for (auto &feed_var : feed_vars) { + feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix)); } - size_t counter = 0; - auto generate_unique_name = [&]() -> std::string { - do { - auto name = name_prefix_ + std::to_string(counter++); - if (existing_var_names.count(name) == 0) { - existing_var_names.insert(name); - return name; - } - } while (counter > 0); - PADDLE_THROW("Too many vars in the program"); - }; - - std::map, std::string, - std::owner_less>> - var_to_name; - for (auto *var_desc : var_descs) { - auto var_name = var_desc->Name(); - PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1); - std::weak_ptr var_base = var_desc_to_var_base.at(var_desc); - if (feed_vars_.count(var_base) > 0) { - var_name = feed_vars_.at(var_base); - } else if (fetch_vars_.count(var_base) > 0) { - var_name = fetch_vars_.at(var_base); - } else if (!var_desc->Persistable()) { - var_name = generate_unique_name(); - } - - auto *new_var_desc = block->Var(var_name); - *new_var_desc = *var_desc; - new_var_desc->SetName(std::move(var_name)); - var_to_name[var_base] = new_var_desc->Name(); + std::vector fetch_var_names; + for (auto &fetch_var : fetch_vars) { + fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix)); } for (auto &op : ops_) { @@ -160,10 +146,7 @@ std::unique_ptr ProgramDescTracer::CreateProgramDesc() std::vector names; names.reserve(pair.second.size()); for (auto &var : pair.second) { - auto iter = var_to_name.find(var); - PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true, - "Cannot find input variable"); - names.emplace_back(iter->second); + names.emplace_back(generator.NameOf(var, tmp_prefix)); } op_desc->SetInput(pair.first, std::move(names)); @@ -173,10 +156,7 @@ std::unique_ptr ProgramDescTracer::CreateProgramDesc() std::vector names; names.reserve(pair.second.size()); for (auto &var : pair.second) { - auto iter = var_to_name.find(var); - PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true, - "Cannot find output variable"); - names.emplace_back(iter->second); + names.emplace_back(generator.NameOf(var, tmp_prefix)); } op_desc->SetOutput(pair.first, std::move(names)); @@ -184,7 +164,8 @@ std::unique_ptr ProgramDescTracer::CreateProgramDesc() } prog->Flush(); - return prog; + return std::make_tuple(std::move(prog), std::move(feed_var_names), + std::move(fetch_var_names)); } void ProgramDescTracer::InsertVarIfNotExist( @@ -192,10 +173,8 @@ void ProgramDescTracer::InsertVarIfNotExist( PADDLE_ENFORCE_NOT_NULL(new_var); if (vars_.count(new_var) != 0) return; - size_t var_id = vars_.size(); auto new_var_desc = new framework::VarDesc(""); - vars_[new_var] = - std::make_pair(var_id, std::unique_ptr(new_var_desc)); + vars_[new_var].reset(new_var_desc); if (new_var->Persistable()) { new_var_desc->SetName(new_var->Name()); @@ -225,9 +204,6 @@ void ProgramDescTracer::InsertVarIfNotExist( void ProgramDescTracer::Reset() { ops_.clear(); vars_.clear(); - feed_vars_.clear(); - fetch_vars_.clear(); - name_prefix_.clear(); } } // namespace jit diff --git a/paddle/fluid/imperative/jit/program_desc_tracer.h b/paddle/fluid/imperative/jit/program_desc_tracer.h index 08e5957bdb..4ef29d0f44 100644 --- a/paddle/fluid/imperative/jit/program_desc_tracer.h +++ b/paddle/fluid/imperative/jit/program_desc_tracer.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include "paddle/fluid/framework/program_desc.h" @@ -29,48 +30,39 @@ namespace paddle { namespace imperative { namespace jit { +using VarDescMetaMap = + std::map, std::unique_ptr, + std::owner_less>>; + +using TracedProgramTuple = + std::tuple /*program*/, + std::vector /*feed_var_names*/, + std::vector /*fetch_var_names*/>; + class ProgramDescTracer { DISABLE_COPY_AND_ASSIGN(ProgramDescTracer); public: ProgramDescTracer() = default; - void SetNamePrefix(const std::string &name_prefix); - - void SetFeedVars(const std::vector> &feed_vars, - std::vector feed_names); - - void SetFetchVars(const std::vector> &fetch_vars, - std::vector fetch_names); - void InsertOp(const std::string &type, const NameVarBaseMap &inputs, const NameVarBaseMap &outputs, const framework::AttributeMap &attrs); - std::unique_ptr CreateProgramDesc() const; + TracedProgramTuple CreateProgramDesc( + const std::vector> &feed_vars, + const std::string &feed_prefix, + const std::vector> &fetch_vars, + const std::string &fetch_prefix, const std::string &tmp_prefix) const; void Reset(); private: void InsertVarIfNotExist(const std::shared_ptr &new_var); + private: std::vector> ops_; - - std::map, - std::pair>, - std::owner_less>> - vars_; - - // The following fields are used to polish the converted ProgramDesc - std::map, std::string, - std::owner_less>> - feed_vars_; - - std::map, std::string, - std::owner_less>> - fetch_vars_; - - std::string name_prefix_; + VarDescMetaMap vars_; }; } // namespace jit diff --git a/paddle/fluid/pybind/imperative.cc b/paddle/fluid/pybind/imperative.cc index 36739ae462..92b4bf1522 100644 --- a/paddle/fluid/pybind/imperative.cc +++ b/paddle/fluid/pybind/imperative.cc @@ -328,10 +328,6 @@ void BindImperative(py::module *m_ptr) { }); py::class_(m, "ProgramDescTracer", "") - .def("set_name_prefix", - &imperative::jit::ProgramDescTracer::SetNamePrefix) - .def("set_feed_vars", &imperative::jit::ProgramDescTracer::SetFeedVars) - .def("set_fetch_vars", &imperative::jit::ProgramDescTracer::SetFetchVars) .def("create_program_desc", &imperative::jit::ProgramDescTracer::CreateProgramDesc) .def("reset", &imperative::jit::ProgramDescTracer::Reset); diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index e3d56ad009..4ea82e2ca0 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -47,82 +47,19 @@ def extract_vars(inputs): @dygraph_only -def _trace(layer, inputs, feed_names=None, fetch_names=None): - """ - Trace dygraph network into a :code:`Program`. The returned :code:`Program` - can be run in static graph mode. This method would simply record all - operators in the network with :code:`inputs` . Users should guarantee that - the traced dygraph network is independent with input data, input shapes, - and would not be changed between different batches. Otherwise, the traced - result may be different. - - Args: - layer(Layer): the layer to be traced. - inputs(list): the input arguments of :code:`layer.forward()` method. - feed_names(list(str), optional): the input variable names in the - traced :code:`Program` corresponding to :code:`inputs` . If it - is None, the variable name of :code:`inputs` would be used. - It is suggested that users should set :code:`feed_names` - manually. Otherwise, the input variable names would be - different between different batches. Default None. - fetch_names(list(str), optional): the output variable names in the - traced :code:`Program` corresponding to the output variables - of :code:`layer.forward()` method. If it is None, the variable - name of the outputs of :code:`layer.forward()` would be used. - It is suggested that users should set :code:`fetch_names` - manually. Otherwise, the output variable names would be - different between different batches. Default None. - - Returns: - A tuple of 4 items, whose first item is the outputs of - :code:`layer.forward()` method, and second item is the traced - :code:`Program`, and the third item is names of feed variables, - and the fourth item is names of fetch variables. - - Examples: - - .. code-block:: python: - - import paddle.fluid as fluid - from paddle.fluid.dygraph import FC, to_variable - import paddle.fluid.dygraph.jit as jit - import numpy as np - - class ExampleLayer(fluid.dygraph.Layer): - def __init__(self, name_scope): - super(ExampleLayer, self).__init__(name_scope) - self._fc = FC(self.full_name(), 10) - - def forward(self, input): - return self._fc(input) - - with fluid.dygraph.guard(): - layer = ExampleLayer("example_layer") - in_np = np.random.random([2, 3]).astype('float32') - in_var = to_variable(in_np) - out, program, _, _ = jit._trace(layer, inputs=[in_var], - feed_names=['input'], - fetch_names=['fc_out']) - - """ +def _trace(layer, + inputs, + feed_prefix='feed_', + fetch_prefix='fetch_', + tmp_prefix='t_'): assert isinstance(layer, Layer) if not isinstance(inputs, (list, tuple)): inputs = [inputs] - if feed_names is None: - feed_names = [] - - if fetch_names is None: - fetch_names = [] - tracer = _dygraph_tracer()._get_program_desc_tracer() var_list = extract_vars(inputs) - if callable(feed_names): - feed_names = feed_names(len(var_list)) - - tracer.set_feed_vars(var_list, feed_names) with program_desc_tracing_guard(True): original_outputs = layer(*inputs) @@ -132,13 +69,8 @@ def _trace(layer, inputs, feed_names=None, fetch_names=None): outputs = original_outputs out_vars = [var._ivar for var in outputs] - if callable(fetch_names): - fetch_names = fetch_names(len(out_vars)) - - tracer.set_fetch_vars(out_vars, fetch_names) - tracer.set_name_prefix('t_') - - program_desc = tracer.create_program_desc() + program_desc, feed_names, fetch_names = tracer.create_program_desc( + var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix) tracer.reset() with _dygraph_guard(None): @@ -233,9 +165,7 @@ class TracedLayer(object): out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var]) out_static_graph = static_layer([in_var]) """ - feed_func = lambda n: ['feed_{}'.format(i) for i in range(n)] - fetch_func = lambda n: ['fetch_{}'.format(i) for i in range(n)] - outs, prog, feed, fetch = _trace(layer, inputs, feed_func, fetch_func) + outs, prog, feed, fetch = _trace(layer, inputs) traced = TracedLayer(prog, layer.parameters(), feed, fetch) return outs, traced -- GitLab