Unverified · Commit 67e88424 · Authored by Zeng Jinle, committed by GitHub

Polish jit trace codes (#21218)

* polish jit trace codes, test=develop

* polish codes again by removing var_id, test=develop
Parent cdb3d279
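Summary: this commit removes the SetNamePrefix/SetFeedVars/SetFetchVars state from ProgramDescTracer. Instead, the feed/fetch variables and the name prefixes are passed directly to CreateProgramDesc, and a new UniqueBlockVarGenerator names every non-persistable variable as prefix + counter. A minimal Python sketch of that naming scheme follows (illustrative only; the helper names here are not from the diff, and the real C++ implementation additionally pre-registers the names of persistable variables so generated names never collide with them):

    def make_name_generator():
        # One counter per prefix (mirrors counter_) and a set of names already
        # taken (mirrors existing_names_).
        counters = {}
        existing = set()

        def name_of(prefix):
            cnt = counters.get(prefix, 0)
            while True:
                name = prefix + str(cnt)
                cnt += 1
                if name not in existing:
                    counters[prefix] = cnt
                    existing.add(name)
                    return name

        return name_of

    name_of = make_name_generator()
    assert name_of('feed_') == 'feed_0'   # feed vars: feed_0, feed_1, ...
    assert name_of('t_') == 't_0'         # temporaries: t_0, t_1, ...
    assert name_of('t_') == 't_1'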
@@ -21,49 +21,81 @@ namespace paddle {
namespace imperative {
namespace jit {
-void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) {
-  name_prefix_ = name_prefix;
-}
-
-void ProgramDescTracer::SetFeedVars(
-    const std::vector<std::shared_ptr<VarBase>> &feed_vars,
-    std::vector<std::string> feed_names) {
-  feed_vars_.clear();
-
-  if (feed_names.empty()) {
-    feed_names.reserve(feed_vars.size());
-    for (auto &var : feed_vars) {
-      feed_names.emplace_back(var->Name());
-    }
-  }
-
-  PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(),
-                    "The feeded variable names number must be equal to the "
-                    "feeded variable number");
-
-  for (size_t i = 0; i < feed_names.size(); ++i) {
-    feed_vars_[feed_vars[i]] = feed_names[i];
-  }
-}
-
-void ProgramDescTracer::SetFetchVars(
-    const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
-    std::vector<std::string> fetch_names) {
-  fetch_vars_.clear();
-
-  if (fetch_names.empty()) {
-    fetch_names.reserve(fetch_vars.size());
-    for (auto &var : fetch_vars) {
-      fetch_names.emplace_back(var->Name());
-    }
-  }
-
-  PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(),
-                    "The fetched variable names number must be equal to the "
-                    "fetched variable number");
-
-  for (size_t i = 0; i < fetch_names.size(); ++i) {
-    fetch_vars_[fetch_vars[i]] = fetch_names[i];
-  }
-}
+// A helper class to generate unique name for each non-persistable var
+class UniqueBlockVarGenerator {
+ public:
+  UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
+                          framework::BlockDesc *block);
+
+  std::string NameOf(const std::weak_ptr<VarBase> &var,
+                     const std::string &prefix);
+
+ private:
+  void InsertNewVarInBlock(const std::weak_ptr<VarBase> &var,
+                           const framework::VarDesc &ref_desc,
+                           const std::string &name);
+
+ private:
+  const VarDescMetaMap &all_vars_;
+  framework::BlockDesc *block_;
+  std::unordered_map<std::string, size_t> counter_;
+
+  std::map<std::weak_ptr<VarBase>, std::string,
+           std::owner_less<std::weak_ptr<VarBase>>>
+      var_to_name_;
+  std::unordered_set<std::string> existing_names_;
+};
+
+UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
+                                                 framework::BlockDesc *block)
+    : all_vars_(all_vars), block_(block) {
+  for (auto &var_pair : all_vars_) {
+    auto *var_desc = var_pair.second.get();
+    if (var_desc->Persistable()) {
+      InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name());
+    }
+  }
+}
+
+std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
+                                            const std::string &prefix) {
+  auto all_vars_iter = all_vars_.find(var);
+  PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true,
+                    platform::errors::NotFound(
+                        "Variable is not found in UniqueBlockVarGenerator"));
+
+  auto iter = var_to_name_.find(var);
+  if (iter != var_to_name_.end()) {
+    VLOG(5) << "Return existing var name " << iter->second;
+    return iter->second;
+  } else {
+    auto generate_unique_name = [this, &prefix] {
+      auto &cnt = counter_[prefix];
+      do {
+        auto name = prefix + std::to_string(cnt++);
+        if (existing_names_.count(name) == 0) {
+          return name;
+        }
+      } while (cnt > 0);
+      PADDLE_THROW(
+          platform::errors::OutOfRange("Too many vars in the program"));
+    };
+
+    auto unique_name = generate_unique_name();
+    VLOG(5) << "Generate new var name " << unique_name;
+    InsertNewVarInBlock(var, *(all_vars_iter->second), unique_name);
+    return unique_name;
+  }
+}
+
+void UniqueBlockVarGenerator::InsertNewVarInBlock(
+    const std::weak_ptr<VarBase> &var, const framework::VarDesc &var_desc,
+    const std::string &name) {
+  var_to_name_[var] = name;
+  existing_names_.insert(name);
+  auto *new_var_desc = block_->Var(name);
+  *new_var_desc = var_desc;
+  new_var_desc->SetName(name);
+}
void ProgramDescTracer::InsertOp(const std::string &type,
@@ -85,70 +117,24 @@ void ProgramDescTracer::InsertOp(const std::string &type,
}
}
-std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
-    const {
+TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
+    const std::vector<std::shared_ptr<VarBase>> &feed_vars,
+    const std::string &feed_prefix,
+    const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
+    const std::string &fetch_prefix, const std::string &tmp_prefix) const {
  std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
  auto *block = prog->MutableBlock(0);
-  size_t var_num = vars_.size();
-  std::vector<framework::VarDesc *> var_descs(var_num, nullptr);
-  std::unordered_map<framework::VarDesc *, std::weak_ptr<VarBase>>
-      var_desc_to_var_base;
-  for (auto &pair : vars_) {
-    size_t var_id = pair.second.first;
-    PADDLE_ENFORCE_LT(var_id, var_num);
-    var_descs[var_id] = pair.second.second.get();
-    PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]);
-    var_desc_to_var_base[var_descs[var_id]] = pair.first;
-  }
-
-  std::unordered_set<std::string> existing_var_names;
-  for (auto *var_desc : var_descs) {
-    if (var_desc->Persistable()) {
-      existing_var_names.insert(var_desc->Name());
-    }
-  }
-
-  for (auto &pair : feed_vars_) {
-    existing_var_names.insert(pair.second);
-  }
-
-  for (auto &pair : fetch_vars_) {
-    existing_var_names.insert(pair.second);
-  }
-
-  size_t counter = 0;
-  auto generate_unique_name = [&]() -> std::string {
-    do {
-      auto name = name_prefix_ + std::to_string(counter++);
-      if (existing_var_names.count(name) == 0) {
-        existing_var_names.insert(name);
-        return name;
-      }
-    } while (counter > 0);
-    PADDLE_THROW("Too many vars in the program");
-  };
-
-  std::map<std::weak_ptr<VarBase>, std::string,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      var_to_name;
-  for (auto *var_desc : var_descs) {
-    auto var_name = var_desc->Name();
-    PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1);
-    std::weak_ptr<VarBase> var_base = var_desc_to_var_base.at(var_desc);
-    if (feed_vars_.count(var_base) > 0) {
-      var_name = feed_vars_.at(var_base);
-    } else if (fetch_vars_.count(var_base) > 0) {
-      var_name = fetch_vars_.at(var_base);
-    } else if (!var_desc->Persistable()) {
-      var_name = generate_unique_name();
-    }
-
-    auto *new_var_desc = block->Var(var_name);
-    *new_var_desc = *var_desc;
-    new_var_desc->SetName(std::move(var_name));
-    var_to_name[var_base] = new_var_desc->Name();
-  }
+  UniqueBlockVarGenerator generator(vars_, block);
+
+  std::vector<std::string> feed_var_names;
+  for (auto &feed_var : feed_vars) {
+    feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix));
+  }
+
+  std::vector<std::string> fetch_var_names;
+  for (auto &fetch_var : fetch_vars) {
+    fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix));
+  }
for (auto &op : ops_) {
@@ -160,10 +146,7 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
std::vector<std::string> names;
names.reserve(pair.second.size());
for (auto &var : pair.second) {
-      auto iter = var_to_name.find(var);
-      PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
-                        "Cannot find input variable");
-      names.emplace_back(iter->second);
+      names.emplace_back(generator.NameOf(var, tmp_prefix));
}
op_desc->SetInput(pair.first, std::move(names));
@@ -173,10 +156,7 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
std::vector<std::string> names;
names.reserve(pair.second.size());
for (auto &var : pair.second) {
-      auto iter = var_to_name.find(var);
-      PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
-                        "Cannot find output variable");
-      names.emplace_back(iter->second);
+      names.emplace_back(generator.NameOf(var, tmp_prefix));
}
op_desc->SetOutput(pair.first, std::move(names));
@@ -184,7 +164,8 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
}
prog->Flush();
-  return prog;
+  return std::make_tuple(std::move(prog), std::move(feed_var_names),
+                         std::move(fetch_var_names));
}
void ProgramDescTracer::InsertVarIfNotExist(
@@ -192,10 +173,8 @@ void ProgramDescTracer::InsertVarIfNotExist(
PADDLE_ENFORCE_NOT_NULL(new_var);
if (vars_.count(new_var) != 0) return;
-  size_t var_id = vars_.size();
  auto new_var_desc = new framework::VarDesc("");
-  vars_[new_var] =
-      std::make_pair(var_id, std::unique_ptr<framework::VarDesc>(new_var_desc));
+  vars_[new_var].reset(new_var_desc);
if (new_var->Persistable()) {
new_var_desc->SetName(new_var->Name());
@@ -225,9 +204,6 @@ void ProgramDescTracer::InsertVarIfNotExist(
void ProgramDescTracer::Reset() {
ops_.clear();
vars_.clear();
-  feed_vars_.clear();
-  fetch_vars_.clear();
-  name_prefix_.clear();
}
} // namespace jit
......
@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <string>
+#include <tuple>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
@@ -29,48 +30,39 @@ namespace paddle {
namespace imperative {
namespace jit {
+using VarDescMetaMap =
+    std::map<std::weak_ptr<VarBase>, std::unique_ptr<framework::VarDesc>,
+             std::owner_less<std::weak_ptr<VarBase>>>;
+
+using TracedProgramTuple =
+    std::tuple<std::unique_ptr<framework::ProgramDesc> /*program*/,
+               std::vector<std::string> /*feed_var_names*/,
+               std::vector<std::string> /*fetch_var_names*/>;
class ProgramDescTracer {
DISABLE_COPY_AND_ASSIGN(ProgramDescTracer);
public:
ProgramDescTracer() = default;
-  void SetNamePrefix(const std::string &name_prefix);
-
-  void SetFeedVars(const std::vector<std::shared_ptr<VarBase>> &feed_vars,
-                   std::vector<std::string> feed_names);
-
-  void SetFetchVars(const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
-                    std::vector<std::string> fetch_names);
void InsertOp(const std::string &type, const NameVarBaseMap &inputs,
const NameVarBaseMap &outputs,
const framework::AttributeMap &attrs);
-  std::unique_ptr<framework::ProgramDesc> CreateProgramDesc() const;
+  TracedProgramTuple CreateProgramDesc(
+      const std::vector<std::shared_ptr<VarBase>> &feed_vars,
+      const std::string &feed_prefix,
+      const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
+      const std::string &fetch_prefix, const std::string &tmp_prefix) const;
void Reset();
private:
void InsertVarIfNotExist(const std::shared_ptr<VarBase> &new_var);
private:
std::vector<std::unique_ptr<OpDescMeta>> ops_;
-  std::map<std::weak_ptr<VarBase>,
-           std::pair<size_t, std::unique_ptr<framework::VarDesc>>,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      vars_;
-
-  // The following fields are used to polish the converted ProgramDesc
-  std::map<std::weak_ptr<VarBase>, std::string,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      feed_vars_;
-  std::map<std::weak_ptr<VarBase>, std::string,
-           std::owner_less<std::weak_ptr<VarBase>>>
-      fetch_vars_;
-  std::string name_prefix_;
+  VarDescMetaMap vars_;
};
} // namespace jit
......
@@ -328,10 +328,6 @@ void BindImperative(py::module *m_ptr) {
});
py::class_<imperative::jit::ProgramDescTracer>(m, "ProgramDescTracer", "")
.def("set_name_prefix",
&imperative::jit::ProgramDescTracer::SetNamePrefix)
.def("set_feed_vars", &imperative::jit::ProgramDescTracer::SetFeedVars)
.def("set_fetch_vars", &imperative::jit::ProgramDescTracer::SetFetchVars)
.def("create_program_desc",
&imperative::jit::ProgramDescTracer::CreateProgramDesc)
.def("reset", &imperative::jit::ProgramDescTracer::Reset);
......
@@ -47,82 +47,19 @@ def extract_vars(inputs):
@dygraph_only
-def _trace(layer, inputs, feed_names=None, fetch_names=None):
-    """
-    Trace a dygraph network into a :code:`Program`. The returned
-    :code:`Program` can be run in static graph mode. This method simply
-    records all operators executed in the network when it runs with
-    :code:`inputs`. Users should guarantee that the traced dygraph network
-    is independent of the input data and input shapes, and does not change
-    between batches. Otherwise, the traced result may differ.
-
-    Args:
-        layer(Layer): the layer to be traced.
-        inputs(list): the input arguments of the :code:`layer.forward()` method.
-        feed_names(list(str), optional): the input variable names in the
-                traced :code:`Program` corresponding to :code:`inputs`. If it
-                is None, the variable names of :code:`inputs` would be used.
-                It is suggested that users set :code:`feed_names` manually;
-                otherwise, the input variable names would differ between
-                batches. Default None.
-        fetch_names(list(str), optional): the output variable names in the
-                traced :code:`Program` corresponding to the output variables
-                of the :code:`layer.forward()` method. If it is None, the
-                variable names of the outputs of :code:`layer.forward()`
-                would be used. It is suggested that users set
-                :code:`fetch_names` manually; otherwise, the output variable
-                names would differ between batches. Default None.
-
-    Returns:
-        A tuple of 4 items: the outputs of the :code:`layer.forward()` method,
-        the traced :code:`Program`, the names of the feed variables, and the
-        names of the fetch variables.
-
-    Examples:
-
-        .. code-block:: python
-
-            import paddle.fluid as fluid
-            from paddle.fluid.dygraph import FC, to_variable
-            import paddle.fluid.dygraph.jit as jit
-            import numpy as np
-
-            class ExampleLayer(fluid.dygraph.Layer):
-                def __init__(self, name_scope):
-                    super(ExampleLayer, self).__init__(name_scope)
-                    self._fc = FC(self.full_name(), 10)
-
-                def forward(self, input):
-                    return self._fc(input)
-
-            with fluid.dygraph.guard():
-                layer = ExampleLayer("example_layer")
-                in_np = np.random.random([2, 3]).astype('float32')
-                in_var = to_variable(in_np)
-                out, program, _, _ = jit._trace(layer, inputs=[in_var],
-                                                feed_names=['input'],
-                                                fetch_names=['fc_out'])
-    """
+def _trace(layer,
+           inputs,
+           feed_prefix='feed_',
+           fetch_prefix='fetch_',
+           tmp_prefix='t_'):
assert isinstance(layer, Layer)
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
-    if feed_names is None:
-        feed_names = []
-
-    if fetch_names is None:
-        fetch_names = []
tracer = _dygraph_tracer()._get_program_desc_tracer()
var_list = extract_vars(inputs)
-    if callable(feed_names):
-        feed_names = feed_names(len(var_list))
-
-    tracer.set_feed_vars(var_list, feed_names)
with program_desc_tracing_guard(True):
original_outputs = layer(*inputs)
@@ -132,13 +69,8 @@ def _trace(layer, inputs, feed_names=None, fetch_names=None):
outputs = original_outputs
out_vars = [var._ivar for var in outputs]
-    if callable(fetch_names):
-        fetch_names = fetch_names(len(out_vars))
-
-    tracer.set_fetch_vars(out_vars, fetch_names)
-    tracer.set_name_prefix('t_')
-    program_desc = tracer.create_program_desc()
+    program_desc, feed_names, fetch_names = tracer.create_program_desc(
+        var_list, feed_prefix, out_vars, fetch_prefix, tmp_prefix)
tracer.reset()
with _dygraph_guard(None):
@@ -233,9 +165,7 @@ class TracedLayer(object):
out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
out_static_graph = static_layer([in_var])
"""
-        feed_func = lambda n: ['feed_{}'.format(i) for i in range(n)]
-        fetch_func = lambda n: ['fetch_{}'.format(i) for i in range(n)]
-        outs, prog, feed, fetch = _trace(layer, inputs, feed_func, fetch_func)
+        outs, prog, feed, fetch = _trace(layer, inputs)
traced = TracedLayer(prog, layer.parameters(), feed, fetch)
return outs, traced
......
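After this change, TracedLayer.trace no longer builds feed/fetch name callbacks; names come from the default prefixes of the new _trace signature. A usage sketch under the PaddlePaddle 1.6-era dygraph API (adapted from the docstring removed above; ExampleLayer and the input shape are illustrative, and the TracedLayer/FC import path is assumed from that docstring):

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.dygraph import FC, TracedLayer, to_variable

    class ExampleLayer(fluid.dygraph.Layer):
        def __init__(self, name_scope):
            super(ExampleLayer, self).__init__(name_scope)
            self._fc = FC(self.full_name(), 10)

        def forward(self, input):
            return self._fc(input)

    with fluid.dygraph.guard():
        layer = ExampleLayer("example_layer")
        in_var = to_variable(np.random.random([2, 3]).astype('float32'))
        # Feed/fetch variables in the traced program are auto-named
        # feed_0, fetch_0, ...; temporaries get the t_ prefix.
        out_dygraph, static_layer = TracedLayer.trace(layer, inputs=[in_var])
        out_static = static_layer([in_var])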