Unverified commit 67e88424, authored by Zeng Jinle, committed by GitHub

Polish jit trace codes (#21218)

* polish jit trace codes, test=develop

* polish codes again by removing var_id, test=develop
Parent: cdb3d279
@@ -21,49 +21,81 @@ namespace paddle {
 namespace imperative {
 namespace jit {
 
-void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) {
-  name_prefix_ = name_prefix;
-}
-
-void ProgramDescTracer::SetFeedVars(
-    const std::vector<std::shared_ptr<VarBase>> &feed_vars,
-    std::vector<std::string> feed_names) {
-  feed_vars_.clear();
-
-  if (feed_names.empty()) {
-    feed_names.reserve(feed_vars.size());
-    for (auto &var : feed_vars) {
-      feed_names.emplace_back(var->Name());
-    }
-  }
-
-  PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(),
-                    "The feeded variable names number must be equal to the "
-                    "feeded variable number");
-
-  for (size_t i = 0; i < feed_names.size(); ++i) {
-    feed_vars_[feed_vars[i]] = feed_names[i];
-  }
-}
-
-void ProgramDescTracer::SetFetchVars(
-    const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
-    std::vector<std::string> fetch_names) {
-  fetch_vars_.clear();
-
-  if (fetch_names.empty()) {
-    fetch_names.reserve(fetch_vars.size());
-    for (auto &var : fetch_vars) {
-      fetch_names.emplace_back(var->Name());
-    }
-  }
-
-  PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(),
-                    "The fetched variable names number must be equal to the "
-                    "fetched variable number");
-
-  for (size_t i = 0; i < fetch_names.size(); ++i) {
-    fetch_vars_[fetch_vars[i]] = fetch_names[i];
-  }
-}
+// A helper class to generate a unique name for each non-persistable var
+class UniqueBlockVarGenerator {
+ public:
+  UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
+                          framework::BlockDesc *block);
+
+  std::string NameOf(const std::weak_ptr<VarBase> &var,
+                     const std::string &prefix);
+
+ private:
+  void InsertNewVarInBlock(const std::weak_ptr<VarBase> &var,
+                           const framework::VarDesc &ref_desc,
+                           const std::string &name);
+
+ private:
+  const VarDescMetaMap &all_vars_;
+  framework::BlockDesc *block_;
+  std::unordered_map<std::string, size_t> counter_;
+
+  std::map<std::weak_ptr<VarBase>, std::string,
+           std::owner_less<std::weak_ptr<VarBase>>>
+      var_to_name_;
+  std::unordered_set<std::string> existing_names_;
+};
+
+UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
+                                                 framework::BlockDesc *block)
+    : all_vars_(all_vars), block_(block) {
+  for (auto &var_pair : all_vars_) {
+    auto *var_desc = var_pair.second.get();
+    if (var_desc->Persistable()) {
+      InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name());
+    }
+  }
+}
+
+std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
+                                            const std::string &prefix) {
+  auto all_vars_iter = all_vars_.find(var);
+  PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true,
+                    platform::errors::NotFound(
+                        "Variable is not found in UniqueBlockVarGenerator"));
+
+  auto iter = var_to_name_.find(var);
+  if (iter != var_to_name_.end()) {
+    VLOG(5) << "Return existing var name " << iter->second;
+    return iter->second;
+  } else {
+    auto generate_unique_name = [this, &prefix] {
+      auto &cnt = counter_[prefix];
+      do {
+        auto name = prefix + std::to_string(cnt++);
+        if (existing_names_.count(name) == 0) {
+          return name;
+        }
+      } while (cnt > 0);
+      PADDLE_THROW(
+          platform::errors::OutOfRange("Too many vars in the program"));
+    };
+
+    auto unique_name = generate_unique_name();
+    VLOG(5) << "Generate new var name " << unique_name;
+    InsertNewVarInBlock(var, *(all_vars_iter->second), unique_name);
+    return unique_name;
+  }
+}
+
+void UniqueBlockVarGenerator::InsertNewVarInBlock(
+    const std::weak_ptr<VarBase> &var, const framework::VarDesc &var_desc,
+    const std::string &name) {
+  var_to_name_[var] = name;
+  existing_names_.insert(name);
+  auto *new_var_desc = block_->Var(name);
+  *new_var_desc = var_desc;
+  new_var_desc->SetName(name);
+}
 
 void ProgramDescTracer::InsertOp(const std::string &type,
@@ -85,70 +117,24 @@ void ProgramDescTracer::InsertOp(const std::string &type,
   }
 }
 
-std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
-    const {
+TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
+    const std::vector<std::shared_ptr<VarBase>> &feed_vars,
+    const std::string &feed_prefix,
+    const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
+    const std::string &fetch_prefix, const std::string &tmp_prefix) const {
   std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
   auto *block = prog->MutableBlock(0);
 
-  size_t var_num = vars_.size();
-  std::vector<framework::VarDesc *> var_descs(var_num, nullptr);
-  std::unordered_map<framework::VarDesc *, std::weak_ptr<VarBase>>
-      var_desc_to_var_base;
-  for (auto &pair : vars_) {
-    size_t var_id = pair.second.first;
-    PADDLE_ENFORCE_LT(var_id, var_num);
-    var_descs[var_id] = pair.second.second.get();
-    PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]);
-    var_desc_to_var_base[var_descs[var_id]] = pair.first;
-  }
-
-  std::unordered_set<std::string> existing_var_names;
-  for (auto *var_desc : var_descs) {
-    if (var_desc->Persistable()) {
-      existing_var_names.insert(var_desc->Name());
-    }
-  }
-
-  for (auto &pair : feed_vars_) {
-    existing_var_names.insert(pair.second);
-  }
-
-  for (auto &pair : fetch_vars_) {
-    existing_var_names.insert(pair.second);
-  }
-
-  size_t counter = 0;
-  auto generate_unique_name = [&]() -> std::string {
-    do {
-      auto name = name_prefix_ + std::to_string(counter++);
-      if (existing_var_names.count(name) == 0) {
-        existing_var_names.insert(name);
-        return name;
-      }
-    } while (counter > 0);
-    PADDLE_THROW("Too many vars in the program");
-  };
-
-  std::map<std::weak_ptr<VarBase>, std::string,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      var_to_name;
-  for (auto *var_desc : var_descs) {
-    auto var_name = var_desc->Name();
-    PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1);
-    std::weak_ptr<VarBase> var_base = var_desc_to_var_base.at(var_desc);
-    if (feed_vars_.count(var_base) > 0) {
-      var_name = feed_vars_.at(var_base);
-    } else if (fetch_vars_.count(var_base) > 0) {
-      var_name = fetch_vars_.at(var_base);
-    } else if (!var_desc->Persistable()) {
-      var_name = generate_unique_name();
-    }
-
-    auto *new_var_desc = block->Var(var_name);
-    *new_var_desc = *var_desc;
-    new_var_desc->SetName(std::move(var_name));
-    var_to_name[var_base] = new_var_desc->Name();
+  UniqueBlockVarGenerator generator(vars_, block);
+
+  std::vector<std::string> feed_var_names;
+  for (auto &feed_var : feed_vars) {
+    feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix));
+  }
+
+  std::vector<std::string> fetch_var_names;
+  for (auto &fetch_var : fetch_vars) {
+    fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix));
   }
 
   for (auto &op : ops_) {
@@ -160,10 +146,7 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
       std::vector<std::string> names;
       names.reserve(pair.second.size());
       for (auto &var : pair.second) {
-        auto iter = var_to_name.find(var);
-        PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
-                          "Cannot find input variable");
-        names.emplace_back(iter->second);
+        names.emplace_back(generator.NameOf(var, tmp_prefix));
       }
 
       op_desc->SetInput(pair.first, std::move(names));
@@ -173,10 +156,7 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
       std::vector<std::string> names;
       names.reserve(pair.second.size());
       for (auto &var : pair.second) {
-        auto iter = var_to_name.find(var);
-        PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
-                          "Cannot find output variable");
-        names.emplace_back(iter->second);
+        names.emplace_back(generator.NameOf(var, tmp_prefix));
      }
 
       op_desc->SetOutput(pair.first, std::move(names));
@@ -184,7 +164,8 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
   }
 
   prog->Flush();
-  return prog;
+  return std::make_tuple(std::move(prog), std::move(feed_var_names),
+                         std::move(fetch_var_names));
 }
 
 void ProgramDescTracer::InsertVarIfNotExist(
@@ -192,10 +173,8 @@ void ProgramDescTracer::InsertVarIfNotExist(
   PADDLE_ENFORCE_NOT_NULL(new_var);
   if (vars_.count(new_var) != 0) return;
 
-  size_t var_id = vars_.size();
   auto new_var_desc = new framework::VarDesc("");
-  vars_[new_var] =
-      std::make_pair(var_id, std::unique_ptr<framework::VarDesc>(new_var_desc));
+  vars_[new_var].reset(new_var_desc);
 
   if (new_var->Persistable()) {
     new_var_desc->SetName(new_var->Name());
@@ -225,9 +204,6 @@ void ProgramDescTracer::InsertVarIfNotExist(
 void ProgramDescTracer::Reset() {
   ops_.clear();
   vars_.clear();
-  feed_vars_.clear();
-  fetch_vars_.clear();
-  name_prefix_.clear();
 }
 
 }  // namespace jit
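The heart of the refactor is `UniqueBlockVarGenerator::NameOf`: a variable that was already named returns its cached name, and any other variable gets `prefix + counter`, skipping names that are already taken (persistable variables reserve their original names up front in the constructor). A minimal Python sketch of that scheme, with hypothetical class and variable names, purely for illustration:

```python
class UniqueNameGenerator:
    """Plain-Python mirror of UniqueBlockVarGenerator's naming logic."""

    def __init__(self, persistable_names):
        # Persistable vars keep their original names, reserved up front.
        self.existing = set(persistable_names)
        self.counter = {}      # one counter per prefix (counter_ in C++)
        self.var_to_name = {}  # each var is named at most once

    def name_of(self, var, prefix):
        if var in self.var_to_name:
            return self.var_to_name[var]  # cached: same var, same name
        cnt = self.counter.get(prefix, 0)
        while True:
            name = prefix + str(cnt)
            cnt += 1
            if name not in self.existing:
                break
        self.counter[prefix] = cnt
        self.existing.add(name)
        self.var_to_name[var] = name
        return name


gen = UniqueNameGenerator(persistable_names={'fc_0.w_0', 't_0'})
print(gen.name_of('x', 'feed_'))  # feed_0
print(gen.name_of('y', 't_'))     # t_1 ('t_0' is already reserved)
print(gen.name_of('x', 'feed_'))  # feed_0 again, from the cache
```

Keeping one counter per prefix lets feed, fetch, and temporary variables number independently (`feed_0`, `fetch_0`, `t_0`, ...), which is what makes the generated names stable across batches.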
......
@@ -17,6 +17,7 @@
 #include <map>
 #include <memory>
 #include <string>
+#include <tuple>
 #include <utility>
 #include <vector>
 #include "paddle/fluid/framework/program_desc.h"
@@ -29,48 +30,39 @@ namespace paddle {
 namespace imperative {
 namespace jit {
 
+using VarDescMetaMap =
+    std::map<std::weak_ptr<VarBase>, std::unique_ptr<framework::VarDesc>,
+             std::owner_less<std::weak_ptr<VarBase>>>;
+
+using TracedProgramTuple =
+    std::tuple<std::unique_ptr<framework::ProgramDesc> /*program*/,
+               std::vector<std::string> /*feed_var_names*/,
+               std::vector<std::string> /*fetch_var_names*/>;
+
 class ProgramDescTracer {
   DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);
 
  public:
   ProgramDescTracer() = default;
 
-  void SetNamePrefix(const std::string &name_prefix);
-
-  void SetFeedVars(const std::vector<std::shared_ptr<VarBase>> &feed_vars,
-                   std::vector<std::string> feed_names);
-
-  void SetFetchVars(const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
-                    std::vector<std::string> fetch_names);
-
   void InsertOp(const std::string &type, const NameVarBaseMap &inputs,
                 const NameVarBaseMap &outputs,
                 const framework::AttributeMap &attrs);
 
-  std::unique_ptr<framework::ProgramDesc> CreateProgramDesc() const;
+  TracedProgramTuple CreateProgramDesc(
+      const std::vector<std::shared_ptr<VarBase>> &feed_vars,
+      const std::string &feed_prefix,
+      const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
+      const std::string &fetch_prefix, const std::string &tmp_prefix) const;
 
   void Reset();
 
  private:
   void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var);
 
  private:
   std::vector<std::unique_ptr<OpDescMeta>> ops_;
 
-  std::map<std::weak_ptr<VarBase>,
-           std::pair<size_t, std::unique_ptr<framework::VarDesc>>,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      vars_;
-
-  // The following fields are used to polish the converted ProgramDesc
-  std::map<std::weak_ptr<VarBase>, std::string,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      feed_vars_;
-
-  std::map<std::weak_ptr<VarBase>, std::string,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      fetch_vars_;
-
-  std::string name_prefix_;
+  VarDescMetaMap vars_;
 };
 
 }  // namespace jit
......
@@ -328,10 +328,6 @@ void BindImperative(py::module *m_ptr) {
       });
 
   py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
-      .def("set_name_prefix",
-           &imperative::jit::ProgramDescTracer::SetNamePrefix)
-      .def("set_feed_vars", &imperative::jit::ProgramDescTracer::SetFeedVars)
-      .def("set_fetch_vars", &imperative::jit::ProgramDescTracer::SetFetchVars)
       .def("create_program_desc",
            &imperative::jit::ProgramDescTracer::CreateProgramDesc)
       .def("reset", &imperative::jit::ProgramDescTracer::Reset);
......
@@ -47,82 +47,19 @@ def extract_vars(inputs):
 
 @dygraph_only
-def _trace(layer, inputs, feed_names=None, fetch_names=None):
-    """
-    Trace dygraph network into a :code:`Program`. The returned :code:`Program`
-    can be run in static graph mode. This method would simply record all
-    operators in the network with :code:`inputs` . Users should guarantee that
-    the traced dygraph network is independent with input data, input shapes,
-    and would not be changed between different batches. Otherwise, the traced
-    result may be different.
-
-    Args:
-        layer(Layer): the layer to be traced.
-        inputs(list): the input arguments of :code:`layer.forward()` method.
-        feed_names(list(str), optional): the input variable names in the
-                traced :code:`Program` corresponding to :code:`inputs` . If it
-                is None, the variable name of :code:`inputs` would be used.
-                It is suggested that users should set :code:`feed_names`
-                manually. Otherwise, the input variable names would be
-                different between different batches. Default None.
-        fetch_names(list(str), optional): the output variable names in the
-                traced :code:`Program` corresponding to the output variables
-                of :code:`layer.forward()` method. If it is None, the variable
-                name of the outputs of :code:`layer.forward()` would be used.
-                It is suggested that users should set :code:`fetch_names`
-                manually. Otherwise, the output variable names would be
-                different between different batches. Default None.
-
-    Returns:
-        A tuple of 4 items, whose first item is the outputs of
-        :code:`layer.forward()` method, and second item is the traced
-        :code:`Program`, and the third item is names of feed variables,
-        and the fourth item is names of fetch variables.
-
-    Examples:
-
-        .. code-block:: python:
-
-            import paddle.fluid as fluid
-            from paddle.fluid.dygraph import FC, to_variable
-            import paddle.fluid.dygraph.jit as jit
-            import numpy as np
-
-            class ExampleLayer(fluid.dygraph.Layer):
-                def __init__(self, name_scope):
-                    super(ExampleLayer, self).__init__(name_scope)
-                    self._fc = FC(self.full_name(), 10)
-
-                def forward(self, input):
-                    return self._fc(input)
-
-            with fluid.dygraph.guard():
-                layer = ExampleLayer("example_layer")
-                in_np = np.random.random([2, 3]).astype('float32')
-                in_var = to_variable(in_np)
-                out, program, _, _ = jit._trace(layer, inputs=[in_var],
-                                                feed_names=['input'],
-                                                fetch_names=['fc_out'])
-    """
+def _trace(layer,
+           inputs,
+           feed_prefix='feed_',
+           fetch_prefix='fetch_',
+           tmp_prefix='t_'):
     assert isinstance(layer, Layer)
 
     if not isinstance(inputs, (list, tuple)):
         inputs = [inputs]
 
-    if feed_names is None:
-        feed_names = []
-
-    if fetch_names is None:
-        fetch_names = []
-
     tracer = _dygraph_tracer()._get_program_desc_tracer()
 
     var_list = extract_vars(inputs)
-    if callable(feed_names):
-        feed_names = feed_names(len(var_list))
-
-    tracer.set_feed_vars(var_list, feed_names)
 
     with program_desc_tracing_guard(True):
         original_outputs = layer(*inputs)
@@ -132,13 +69,8 @@ def _trace(layer, inputs, feed_names=None, fetch_names=None):
     outputs = original_outputs
 
     out_vars = [var._ivar for var in outputs]
 
-    if callable(fetch_names):
-        fetch_names = fetch_names(len(out_vars))
-
-    tracer.set_fetch_vars(out_vars, fetch_names)
-    tracer.set_name_prefix('t_')
-
-    program_desc = tracer.create_program_desc()
+    program_desc, feed_names, fetch_names = tracer.create_program_desc(
+        var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
     tracer.reset()
 
     with _dygraph_guard(None):
@@ -233,9 +165,7 @@ class TracedLayer(object):
             out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
             out_static_graph = static_layer([in_var])
         """
-        feed_func = lambda n: ['feed_{}'.format(i) for i in range(n)]
-        fetch_func = lambda n: ['fetch_{}'.format(i) for i in range(n)]
-        outs, prog, feed, fetch = _trace(layer, inputs, feed_func, fetch_func)
+        outs, prog, feed, fetch = _trace(layer, inputs)
         traced = TracedLayer(prog, layer.parameters(), feed, fetch)
         return outs, traced
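For reference, end-to-end usage after this change, adapted from the example that the removed `_trace` docstring carried (a sketch; `ExampleLayer` is illustrative, and feed/fetch names are now auto-generated):

```python
import numpy as np
import paddle.fluid as fluid
from paddle.fluid.dygraph import FC, to_variable
from paddle.fluid.dygraph.jit import TracedLayer

class ExampleLayer(fluid.dygraph.Layer):
    def __init__(self, name_scope):
        super(ExampleLayer, self).__init__(name_scope)
        self._fc = FC(self.full_name(), 10)

    def forward(self, input):
        return self._fc(input)

with fluid.dygraph.guard():
    layer = ExampleLayer("example_layer")
    in_var = to_variable(np.random.random([2, 3]).astype('float32'))
    # Names are generated as 'feed_0', ..., 'fetch_0', ... by default.
    out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
    out_static = static_layer([in_var])
```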
......