Unverified commit a62599a8, authored by Leo Chen, committed by GitHub

[feature] prune program by feed and fetch_list automatically (#22474)

* prune train program by fetch_list, test=develop

* add unittest for prune, test=develop

* fix pruned feed, test=develop

* support ParallelExecutor and feed prune, test=develop

* add comments, test=develop

* update unittest, test=develop

* update unittests, test=develop

* remove debug code, test=develop

* support cond in clone, test=develop

* support cond in prune, test=develop

* support multiple minimize, test=develop

* support cache, test=develop

* fix _copy_param_info_from, test=develop

* support python2 str, test=develop

* remove debug code, test=develop

* fix bug of caching CompiledProgram, test=develop

* fix multi_device issue, test=develop

* tmp

* support tuple in fetch_list and overriding use_prune, test=develop

* dont use nonlocal in python2, test=develop

* remove nonlocal, test=develop

* code clean, test=develop

* code clean, test=develop

* feed list, test=develop

* test adam, test=develop

* follow comments, test=develop

* reduce duplicate code, test=develop

* update comments, test=develop
Parent commit: 7c55a94d
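
Before the diff, the sketch below shows how the behavior described in the commit message might be exercised from user code. It is illustrative only and not part of the commit: the toy network (`fluid.layers.data` / `fc` / `mean`) and the two-optimizer setup are assumptions, while the `use_prune` argument and the acceptance of the tuple returned by `Optimizer.minimize()` in `fetch_list` are what this change introduces.

```python
# Illustrative usage sketch (not part of the diff); assumes the fluid 1.x
# layers API and a CPU place.
import numpy as np
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[4], dtype='float32')
hidden = fluid.layers.fc(input=x, size=4)
loss1 = fluid.layers.mean(fluid.layers.fc(input=hidden, size=1))
loss2 = fluid.layers.mean(fluid.layers.fc(input=hidden, size=1))

# Two independent minimize() calls on the same program ("support multiple minimize").
train1 = fluid.optimizer.SGD(learning_rate=0.01).minimize(loss1)
train2 = fluid.optimizer.SGD(learning_rate=0.01).minimize(loss2)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
feed = {'x': np.random.random((2, 4)).astype('float32')}

# Passing the tuple returned by minimize() in fetch_list overrides use_prune to
# True, so only the ops needed for loss1 and its own optimizer are executed.
exe.run(fluid.default_main_program(), feed=feed,
        fetch_list=[train1, loss1.name])

# With use_prune=True, ops that are not needed to compute the fetch targets
# (here, the whole loss2 branch) are pruned from the test program before running.
test_prog = fluid.default_main_program().clone(for_test=True)
exe.run(test_prog, feed=feed, fetch_list=[loss1.name], use_prune=True)
```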
@@ -113,7 +113,6 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
 }
 
 int GetOpRole(const proto::OpDesc& op_desc) {
-  // The op role >= 0, so -1 is used to indicate "NotFound".
   for (auto& attr : op_desc.attrs()) {
     if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) {
       PADDLE_ENFORCE_EQ(
@@ -124,7 +123,10 @@ int GetOpRole(const proto::OpDesc& op_desc) {
       return attr.i();
     }
   }
-  return -1;
+  // If the attr op_role is not found, the operator may have been created in a
+  // c++ test, like prune_test.cc. In that case, the op_role should be the
+  // default value, which is kNotSpecified.
+  return static_cast<int>(OpRole::kNotSpecified);
 }
 
 void AppendOpInputVarNames(const proto::OpDesc& op_desc,
@@ -145,6 +147,16 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
   }
 }
 
+int FindMapByValue(const std::map<int, int>& m, int val) {
+  // The content in the map should be >= 0, so -1 is used to indicate "NotFound".
+  for (auto& pair : m) {
+    if (pair.second == val) {
+      return pair.first;
+    }
+  }
+  return -1;
+}
+
 // block_id is the idx of the current block in the input desc
 // parent_block_id is the idx of the parent of the current block
 // in the output desc, -1 means the current block is global block
@@ -153,30 +165,41 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
 void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
                 int block_id, int parent_block_id,
                 std::unordered_set<std::string>* dependent_vars,
-                const std::set<std::string> feed_var_names) {
+                const std::set<std::string> feed_var_names,
+                std::map<int, int>* pruned_origin_block_id_map) {
   auto& block = input.blocks(block_id);
   auto& ops = block.ops();
 
   bool expect_feed = true;
   for (auto& op_desc : ops) {
-    PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
-                   "All FeedOps are at the beginning of the ProgramDesc");
+    PADDLE_ENFORCE_EQ(
+        op_desc.type() != kFeedOpType || expect_feed, true,
+        platform::errors::PreconditionNotMet(
+            "All FeedOps are at the beginning of the ProgramDesc"));
     expect_feed = (op_desc.type() == kFeedOpType);
   }
 
   bool expect_fetch = true;
   for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
     auto& op_desc = *op_iter;
-    PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
-                   "All FetchOps must at the end of the ProgramDesc");
+    PADDLE_ENFORCE_EQ(op_desc.type() != kFetchOpType || expect_fetch, true,
+                      platform::errors::PreconditionNotMet(
+                          "All FetchOps must be at the end of the ProgramDesc"));
     expect_fetch = (op_desc.type() == kFetchOpType);
   }
 
   std::vector<bool> should_run;
   for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
     auto& op_desc = *op_iter;
-    if (IsTarget(op_desc) || HasDependentOutputVar(op_desc, *dependent_vars)) {
-      // insert its input to the dependency graph
+
+    if (IsTarget(op_desc) ||
+        (HasDependentOutputVar(op_desc, *dependent_vars) &&
+         (GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
+      // NOTE(zhiqiu): since an optimize op takes the trainable parameters as
+      // both inputs and outputs, it may introduce a wrong dependency graph.
+      // For train mode, the optimize ops should be in targets, so it is neither
+      // needed nor correct to mark optimize ops by their outputs.
+      // For eval / infer mode, there is no optimize op in the program.
       for (auto& var : op_desc.inputs()) {
         for (auto& argu : var.arguments()) {
           if (feed_var_names.count(argu) == 0) {
@@ -203,6 +226,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
   output_block->set_idx(output_block_id);
   output_block->set_parent_idx(parent_block_id);
 
+  (*pruned_origin_block_id_map)[output_block_id] = block_id;
+
   auto* op_field = output_block->mutable_ops();
   op_field->Clear();
   for (size_t i = 0; i < should_run.size(); ++i) {
@@ -244,7 +269,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
         // GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
         // output_block_id is the idx of the current block in the output desc
         prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
-                   &sub_block_dependent_vars, feed_var_names);
+                   &sub_block_dependent_vars, feed_var_names,
+                   pruned_origin_block_id_map);
       }
     }
   }
@@ -284,22 +310,33 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
 }
 
 // TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
-void Prune(const proto::ProgramDesc& input,
-           const std::set<std::string>& feed_var_names,
-           proto::ProgramDesc* output) {
+std::map<int, int> Prune(const proto::ProgramDesc& input,
+                         const std::set<std::string>& feed_var_names,
+                         proto::ProgramDesc* output) {
   std::unordered_set<std::string> dependent_vars;
   output->clear_blocks();
-  prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
-}
-
-int FindMapByValue(const std::map<int, int>& m, int val) {
-  // The content in map should be >= 0, so -1 is used to indicate "NotFound".
-  for (auto& pair : m) {
-    if (pair.second == val) {
-      return pair.first;
-    }
-  }
-  return -1;
+  std::map<int, int> pruned_origin_block_id_map;
+  prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names,
+             &pruned_origin_block_id_map);
+  // update subblock idx
+  for (int i = 0; i < output->blocks_size(); i++) {
+    auto* pruned = output->mutable_blocks(i);
+    auto* ops = pruned->mutable_ops();
+    for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
+      auto& op_desc = *op_iter;
+      if (HasSubBlock(op_desc)) {
+        int origin_sub_idx = GetSubBlockIndex(op_desc);
+        auto sub_idx =
+            FindMapByValue(pruned_origin_block_id_map, origin_sub_idx);
+        PADDLE_ENFORCE_NE(sub_idx, -1,
+                          platform::errors::NotFound(
+                              "The origin sub block id should be found in "
+                              "pruned_origin_block_id_map"));
+        SetSubBlockIndex(&op_desc, sub_idx);
+      }
+    }
+  }
+
+  return pruned_origin_block_id_map;
 }
 
 void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
@@ -348,8 +385,8 @@ void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
     var_names.insert(op_output_vars.begin(), op_output_vars.end());
     for (const auto& name : var_names) {
       if (var_map.count(name)) {
-        // NOTE(zhiqiu): For operator in a conditional block, the related vars may
-        // not exist in current block, but in its futher block.
+        // NOTE(zhiqiu): For an operator in a conditional block, the related vars
+        // may not exist in the current block, but in a further block.
        *pruned_vars->Add() = var_map[name];
       }
     }
@@ -389,6 +426,7 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
   proto::ProgramDesc pruned_desc;
   pruned_desc.clear_blocks();
+
   // Step 2. Prune backward for each block.
   for (size_t i = 0; i < origin_clone.Size(); i++) {
     auto pruned = proto::BlockDesc();
......
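
As an aside on the block-id bookkeeping introduced above: `prune_impl` now records, for every block it writes to the output, which block of the original program it came from, and `FindMapByValue` is the reverse lookup `Prune` uses to rewrite the `sub_block` attribute of control-flow ops. A rough Python sketch of that lookup (names and the example map are illustrative only):

```python
# Illustrative sketch of the reverse lookup done by FindMapByValue in the C++
# code above; the dictionary maps pruned block id -> origin block id.
def find_map_by_value(m, val):
    """Return the key whose value equals val, or -1 to indicate "NotFound"."""
    for key, value in m.items():
        if value == val:
            return key
    return -1

# Example: origin blocks 0, 1 and 3 survive pruning and become blocks 0, 1, 2.
pruned_origin_block_id_map = {0: 0, 1: 1, 2: 3}

# A pruned op whose sub_block attribute still points at origin block 3 must be
# updated to point at pruned block 2.
assert find_map_by_value(pruned_origin_block_id_map, 3) == 2
```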
@@ -26,7 +26,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-void Prune(const proto::ProgramDesc& input,
-           const std::set<std::string>& feed_var_names,
-           proto::ProgramDesc* output);
+std::map<int, int> Prune(const proto::ProgramDesc& input,
+                         const std::set<std::string>& feed_var_names,
+                         proto::ProgramDesc* output);
......
@@ -1154,8 +1154,10 @@ All parameter, weight, gradient are variables in Paddle.
           prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
         }
         proto::ProgramDesc pruned_desc;
-        Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
-        return new ProgramDesc(pruned_desc);
+        auto pruned_origin_block_id_map =
+            Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
+        return std::make_tuple(ProgramDesc(pruned_desc),
+                               pruned_origin_block_id_map);
       });
   m.def("prune_backward",
         [](const framework::ProgramDesc &program) {
......
@@ -23,12 +23,13 @@ import numpy as np
 from .wrapped_decorator import signature_safe_contextmanager
 import six
 from .data_feeder import convert_dtype
-from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_
+from .framework import Program, default_main_program, Variable, Operator, convert_np_dtype_to_dtype_
 from . import core
 from . import compiler
 from .. import compat as cpt
 from .trainer_factory import TrainerFactory
 from .trainer_factory import FetchHandlerMonitor
+import copy
 
 __all__ = ['Executor', 'global_scope', 'scope_guard']
 
@@ -345,14 +346,27 @@ def _fetch_var(name, scope=None, return_numpy=True):
 
 def _to_name_str(var):
-    if isinstance(var, Variable):
-        return var.desc.name()
-    elif isinstance(var, str):
-        return var
-    elif isinstance(var, six.string_types):
-        return str(var)
-    else:
-        raise TypeError(str(var) + " should be Variable or str")
+    def _to_str(var):
+        if isinstance(var, Variable):
+            return var.desc.name()
+        elif isinstance(var, str):
+            return var
+        elif isinstance(var, six.string_types):
+            return str(var)
+        elif isinstance(var, Operator):
+            return var.desc.type()
+        else:
+            raise TypeError(str(var) + " should be Variable, Operator or str")
+
+    # NOTE(zhiqiu): The item in fetch_list may be a tuple returned by Optimizer.minimize(),
+    # see the comments in _split_optimize_ops_in_fetch_list for more details.
+    if isinstance(var, tuple):
+        var = var[0]
+    if isinstance(var, list):
+        s = [_to_str(item) for item in var]
+        return ','.join(s)
+    else:
+        return _to_str(var)
 
 
 def _get_strong_program_cache_key(program, feed, fetch_list):
@@ -360,9 +374,13 @@ def _get_strong_program_cache_key(program, feed, fetch_list):
 
 def _get_program_cache_key(feed, fetch_list):
-    feed_var_names = list(feed.keys())
+    feed_var_names = []
+    if isinstance(feed, dict):
+        feed_var_names = list(feed.keys())
+    elif isinstance(feed, list) or isinstance(feed, tuple):
+        for i, each in enumerate(feed):
+            feed_var_names += list(each.keys())
     fetch_var_names = list(map(_to_name_str, fetch_list))
     return str(feed_var_names + fetch_var_names)
 
@@ -503,10 +521,12 @@ class Executor(object):
         self.ctx_caches = dict()
         self.scope_caches = dict()
         self.var_caches = dict()
+        self.pruned_program_caches = dict()
         p = core.Place()
         p.set_place(self.place)
         self._default_executor = core.Executor(p)
         self._closed = False
+        self.pruned_program_scope_caches = dict()
 
     def _get_scope_cache(self, program_cache_key):
         return self.scope_caches.get(program_cache_key, None)
@@ -520,6 +540,18 @@ class Executor(object):
     def _add_program_cache(self, program_cache_key, program):
         self.program_caches[program_cache_key] = program
 
+    def _get_pruned_program_cache(self, program_cache_key):
+        return self.pruned_program_caches.get(program_cache_key, None)
+
+    def _add_pruned_program_cache(self, program_cache_key, program):
+        self.pruned_program_caches[program_cache_key] = program
+
+    def _get_pruned_program_scope_cache(self, program_cache_key):
+        return self.pruned_program_scope_caches.get(program_cache_key, None)
+
+    def _add_pruned_program_scope_cache(self, program_cache_key, program):
+        self.pruned_program_scope_caches[program_cache_key] = program
+
     def _add_ctx_cache(self, ctx_cache_key, ctx):
         self.ctx_caches[ctx_cache_key] = ctx
 
@@ -551,13 +583,17 @@ class Executor(object):
         # prepend feed operators
         if not has_feed_operators(global_block, feed, feed_var_name):
             for i, name in enumerate(feed):
-                out = global_block.var(name)
-                global_block._prepend_op(
-                    type='feed',
-                    inputs={'X': [feed_var]},
-                    outputs={'Out': [out]},
-                    attrs={'col': i})
+                if global_block.has_var(name):
+                    out = global_block.var(name)
+                    global_block._prepend_op(
+                        type='feed',
+                        inputs={'X': [feed_var]},
+                        outputs={'Out': [out]},
+                        attrs={'col': i})
+                else:
+                    warnings.warn(
+                        "The variable %s is not found in program. It is not declared or is pruned."
+                        % name)
 
         # append fetch_operators
         if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
             for i, var in enumerate(fetch_list):
@@ -595,6 +631,159 @@ class Executor(object):
         ]
         return outs
 
+    def _split_optimize_ops_in_fetch_list(self, fetch_list):
+        """
+        Split optimize_ops from fetch_list, which is provided to specify program pruning.
+        Args:
+            fetch_list(list): The original fetch_list.
+            Possible types of fetch_list are:
+                fetch_list = ['loss']
+                fetch_list = [[sgd, sgd], 'loss']
+                fetch_list = [([sgd, sgd], [(param, grad)]), 'loss']
+
+        Returns:
+            optimize_ops(list): The optimize operators split from fetch_list.
+            fetch_list(list): The updated fetch_list which does not contain optimize operators.
+        """
+        _optimize_ops = []
+        _fetch_list = []
+
+        def _get_targets(_optimize_ops, _fetch_list, item):
+            if isinstance(item, Operator):
+                if item._is_optimize_op():
+                    _optimize_ops.append(item)
+                else:
+                    raise TypeError(
+                        "The operator in fetch_list is not an optimize_op")
+            elif isinstance(item, Variable) or isinstance(
+                    item, str) or isinstance(item, six.string_types):
+                _fetch_list.append(item)
+            else:
+                raise TypeError(
+                    "The item in fetch_list should be str, variable or optimize_op, but received %s."
+                    % type(item))
+
+        for item in fetch_list:
+            # NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list
+            # we should handle tuple and list in fetch_list.
+            # TODO(zhiqiu): find a better way to handle that.
+            if isinstance(item, list):
+                for i in item:
+                    _get_targets(_optimize_ops, _fetch_list, i)
+            elif isinstance(item, tuple):
+                for i in item[0]:
+                    _get_targets(_optimize_ops, _fetch_list, i)
+            else:
+                _get_targets(_optimize_ops, _fetch_list, item)
+
+        return _fetch_list, _optimize_ops
+
+    def _prune_program(self,
+                       program,
+                       feed=None,
+                       fetch_list=None,
+                       optimize_ops=None):
+        """
+        Prune operators and variables which are not needed to generate
+        :code:`fetch_list` and the optimize operators, as well as operators
+        and variables which are only needed to generate the variables to be fed.
+
+        Notes: This is a very low level API. Users should not use this API
+        directly.
+
+        Args:
+            program(Program): the origin program
+            feed(list|dict): feed dict or list.
+            fetch_list(list|Variable): A list of variables need to be fetched
+            optimize_ops(list[Operator]): A list of optimizer operators
+
+        Returns:
+            Program: A new, pruned program.
+        """
+        compiled = isinstance(program, compiler.CompiledProgram)
+        if compiled:
+            if program._program:
+                origin_program = program._program
+            else:
+                warnings.warn(
+                    "The program holds no _program, maybe it is constructed by graph, which can't be pruned yet."
+                )
+                return
+        else:
+            origin_program = program
+
+        feed_names = []
+        if isinstance(feed, dict):
+            feed_names = list(feed.keys())
+        elif isinstance(feed, list) or isinstance(feed, tuple):
+            for i, each in enumerate(feed):
+                feed_names += list(each.keys())
+
+        # if optimize_ops is [], all optimize ops in the program are used.
+        if not optimize_ops:
+            for block in origin_program.blocks:
+                for op in block.ops:
+                    if op._is_optimize_op():
+                        optimize_ops.append(op)
+
+        targets = fetch_list + optimize_ops
+        pruned_program = origin_program._prune_with_input(feed_names, targets)
+
+        if compiled:
+            # for compiled program, update the underlying program, re-generate graph,
+            # and reset the flag so it can be compiled again.
+            program._program = pruned_program
+            program._graph = core.Graph(pruned_program.desc)
+            program._compiled = False
+        else:
+            program = pruned_program
+
+        return program
+
+    def _update_feed(self, program, feed):
+        """
+        Update the feed dict, removing feed items which have been pruned from the program.
+
+        Notes: This is a very low level API. Users should not use this API
+        directly.
+
+        Args:
+            program(Program): the pruned program.
+            feed(list|dict): feed dict or list.
+
+        Returns:
+            feed:(list|dict) updated feed.
+        """
+        compiled = isinstance(program, compiler.CompiledProgram)
+        if compiled:
+            if program._program:
+                global_block = program._program.global_block()
+            else:
+                warnings.warn(
+                    "The program holds no _program, maybe it is constructed by graph."
+                )
+        else:
+            global_block = program.global_block()
+
+        if isinstance(feed, dict):
+            for feed_name in list(feed.keys()):
+                if not global_block.has_var(feed_name):
+                    feed.pop(feed_name)
+                    warnings.warn(
+                        "The variable %s is not found in program. It is not declared or is pruned."
+                        % feed_name)
+        elif isinstance(feed, list) or isinstance(feed, tuple):
+            for i, each in enumerate(feed):
+                for feed_name in list(each.keys()):
+                    if not global_block.has_var(feed_name):
+                        each.pop(feed_name)
+                        warnings.warn(
+                            "The variable %s is not found in program. It is not declared or is pruned."
+                            % feed_name)
+
+        return feed
+
     '''
     TODO(typhoonzero): Define "no longer use" meaning? Can user create
     a new Executor for the same program and run?
@@ -682,7 +871,8 @@ class Executor(object):
             scope=None,
             return_numpy=True,
             use_program_cache=False,
-            return_merged=True):
+            return_merged=True,
+            use_prune=False):
         """
         Run the specified :code:`Program` or :code:`CompiledProgram`. It should be noted that the executor
         will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some
@@ -732,6 +922,13 @@ class Executor(object):
                 set :code:`return_merged` as False, which denotes that the fetched results will not be merged.
                 The default is True, but it is just for the compatibility, and may use False as default value
                 in the future version.
+            use_prune(bool): This parameter indicates whether the input :code:`Program` will be pruned.
+                If the parameter is True, the program will be pruned according to the given feed and fetch_list,
+                which means the operators and variables in the program that generate :code:`feed`, as well as
+                those not needed to generate :code:`fetch_list`, will be pruned. The default is False, which
+                means the program will not be pruned and all the operators and variables will be executed.
+                Note that if the tuple returned from :code:`Optimizer.minimize()` is passed to :code:`fetch_list`,
+                :code:`use_prune` will be overridden to True, and the program will be pruned.
 
         Returns:
 
@@ -844,6 +1041,7 @@ class Executor(object):
                 scope=scope,
                 return_numpy=return_numpy,
                 use_program_cache=use_program_cache,
+                use_prune=use_prune,
                 return_merged=return_merged)
         except Exception as e:
             if not isinstance(e, core.EOFException):
@@ -853,7 +1051,7 @@ class Executor(object):
 
     def _run_impl(self, program, feed, fetch_list, feed_var_name,
                   fetch_var_name, scope, return_numpy, use_program_cache,
-                  return_merged):
+                  return_merged, use_prune):
         if self._closed:
             raise RuntimeError("Attempted to use a closed Executor")
 
@@ -877,7 +1075,9 @@ class Executor(object):
             scope = global_scope()
 
         if fetch_list is not None:
-            if isinstance(fetch_list, Variable) or isinstance(fetch_list, str):
+            if isinstance(fetch_list, Variable) or isinstance(
+                    fetch_list, str) or isinstance(fetch_list,
+                                                   six.string_types):
                 fetch_list = [fetch_list]
             assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \
                 "Currently , The fetch_list type only should be list or tuple, \n"\
@@ -886,6 +1086,38 @@ class Executor(object):
         else:
             fetch_list = []
 
+        # use_prune can be overridden by putting optimize_ops in fetch_list
+        _origin_fetch_list = fetch_list
+        _origin_program = program
+        fetch_list, optimize_ops = self._split_optimize_ops_in_fetch_list(
+            fetch_list)
+        if optimize_ops:
+            use_prune = True
+        if use_prune:
+            cache_key = _get_strong_program_cache_key(program, feed,
+                                                      _origin_fetch_list)
+            cached_pruned_program = self._get_pruned_program_cache(cache_key)
+            if cached_pruned_program is None:
+                if isinstance(program, compiler.CompiledProgram):
+                    program_scope_cache = self._get_pruned_program_scope_cache(
                        str(id(_origin_program)))
+                    # copy the original program, so it can be cached.
+                    program = copy.copy(program)
+                    # share the local scopes for the same original CompiledProgram.
+                    program._share_vars_from = program_scope_cache
+                    if self._get_pruned_program_scope_cache(
+                            str(id(_origin_program))) is None:
+                        self._add_pruned_program_scope_cache(
+                            str(id(_origin_program)), program)
+                pruned_program = self._prune_program(program, feed, fetch_list,
+                                                     optimize_ops)
+                self._add_pruned_program_cache(cache_key, pruned_program)
+            else:
+                pruned_program = cached_pruned_program
+
+            feed = self._update_feed(pruned_program, feed)
+            program = pruned_program
+
         compiled = isinstance(program, compiler.CompiledProgram)
 
         # For backward compatibility, run directly.
......
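
To make the `_run_impl` additions above concrete, here is a standalone sketch of the splitting rule applied to the three `fetch_list` shapes listed in the `_split_optimize_ops_in_fetch_list` docstring. `DummyOp` is a hypothetical stand-in for fluid's `Operator`, used only so the snippet runs without Paddle:

```python
# Standalone illustration of the fetch_list splitting rule; DummyOp is a
# hypothetical stand-in, not a Paddle class.
class DummyOp(object):
    def _is_optimize_op(self):
        return True


def split_fetch_list(fetch_list):
    fetches, optimize_ops = [], []

    def classify(item):
        if isinstance(item, DummyOp):    # an optimize op from minimize()
            optimize_ops.append(item)
        else:                            # str / Variable: a plain fetch target
            fetches.append(item)

    for item in fetch_list:
        if isinstance(item, list):       # e.g. the optimize_ops list itself
            for i in item:
                classify(i)
        elif isinstance(item, tuple):    # (optimize_ops, params_grads) from minimize()
            for i in item[0]:
                classify(i)
        else:
            classify(item)
    return fetches, optimize_ops


sgd_ops = [DummyOp(), DummyOp()]
assert split_fetch_list(['loss']) == (['loss'], [])
assert split_fetch_list([sgd_ops, 'loss']) == (['loss'], sgd_ops)
assert split_fetch_list([(sgd_ops, [('w', 'w@GRAD')]), 'loss']) == (['loss'], sgd_ops)
```

Whenever any optimize ops are found this way, `use_prune` is forced to True and the pruned program is cached under the strong cache key built from the program, the feed, and the original fetch_list.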
@@ -2172,6 +2172,15 @@ class Operator(object):
 
         return attr_map
 
+    def _is_optimize_op(self):
+        op_maker = core.op_proto_and_checker_maker
+        OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
+        op_role = self.desc.attr(op_maker.kOpRoleAttrName())
+        if op_role & int(OPTIMIZE):
+            return True
+        else:
+            return False
+
 
 class Block(object):
     """
@@ -2706,8 +2715,8 @@ class Block(object):
             assert isinstance(p, Parameter)
             v = self.vars.get(p.name, None)
             if v is None:
-                raise ValueError("_copy_param_info_from should be invoked with "
-                                 "same topology")
+                # if the Parameter is pruned, v may be None
+                continue
             assert isinstance(v, Variable)
             new_p = None
             if in_dygraph_mode():
@@ -4056,52 +4065,13 @@ class Program(object):
         directly. This API is in flux and not stable.
 
         Args:
-            targets(list|Variable|Operator): A list of variables or operators
+            targets(list|Variable|Operator): A list of variables, operators, or variable names
                 need to be pruned
 
         Returns:
            Program:  A new, pruned program.
         """
-
-        #NOTE(zhiqiu): we sync the original program first, since its program may diff with
-        # its desc due to modifying desc in c++ space. E.g. save op will add kLookupTablePath in desc.
-        self._sync_with_cpp()
-
-        if not isinstance(targets, list):
-            targets = [targets]
-
-        targets_idx = []
-        for t in targets:
-            if not isinstance(t, Operator):
-                if isinstance(t, Variable):
-                    # After transpiler processing, the op that output this
-                    # variable maybe has been changed, so t.op is not reliable
-                    # and we need to find the current op that generate this
-                    # variable here.
-                    t.op = None
-                    global_block = self.global_block()
-                    for idx, op in enumerate(global_block.ops):
-                        if t.name in op.output_arg_names:
-                            t.op = op
-                            break
-
-                    t = t.op
-                    if t is None:
-                        raise ValueError(
-                            "The target variable must have an "
-                            "associated operator that generates it.")
-                else:
-                    raise ValueError("All targets of prune() can only be "
-                                     "Variable or Operator.")
-
-            targets_idx.append([t.block.idx, t.idx])
-        res = Program()
-        res.desc = core.prune(self.desc, set(), targets_idx)
-        res.blocks = [
-            Block(res, i) for i in six.moves.range(res.desc.num_blocks())
-        ]
-        res._sync_with_cpp()
-        return res
+        return self._prune_with_input([], targets)
 
     def _prune_with_input(self, feeded_var_names, targets):
         """
@@ -4115,7 +4085,7 @@ class Program(object):
         Args:
             feeded_var_names(list|str): A list of variable names from where
                 pruning start. If it is set as [], this API works just like _prune()
-            targets(list|Variable|Operator): A list of variables or operators
+            targets(list|Variable|Operator): A list of variables, operators, or variable names
                 need to be pruned
 
         Returns:
@@ -4140,33 +4110,47 @@ class Program(object):
         for t in targets:
             if not isinstance(t, Operator):
                 if isinstance(t, Variable):
-                    # After transpiler processing, the op that output this
-                    # variable maybe has been changed, so t.op is not reliable
-                    # and we need to find the current op that generate this
-                    # variable here.
-                    t.op = None
-                    global_block = self.global_block()
-                    for idx, op in enumerate(global_block.ops):
-                        if t.name in op.output_arg_names:
-                            t.op = op
-                            break
-
-                    t = t.op
-                    if t is None:
-                        raise ValueError(
-                            "The target variable must have an "
-                            "associated operator that generates it.")
+                    name = t.name
+                elif isinstance(t, six.string_types):
+                    name = str(t)
                 else:
                     raise ValueError("All targets of prune() can only be "
                                      "Variable or Operator.")
+                # After transpiler processing, the op that outputs this
+                # variable may have been changed, so t.op is not reliable
+                # and we need to find the current op that generates this
+                # variable here.
+                target_op = None
+                global_block = self.global_block()
+                for idx, op in enumerate(global_block.ops):
+                    if name in op.output_arg_names:
+                        # NOTE(zhiqiu): Find the op that generates the target name.
+                        # Skip optimize ops except for optimize ops in targets,
+                        # since optimize ops generate parameters.
+                        if op._is_optimize_op() and op not in targets:
+                            continue
+                        else:
+                            target_op = op
+                            break
+
+                t = target_op
+                if t is None:
+                    raise ValueError("The target variable must have an "
+                                     "associated operator that generates it.")
 
             targets_idx.append([t.block.idx, t.idx])
         res = Program()
-        res.desc = core.prune(self.desc, set(feeded_var_names), targets_idx)
+        res.desc, pruned_origin_block_id_map = core.prune(self.desc,
+                                                          set(feeded_var_names),
+                                                          targets_idx)
         res.blocks = [
             Block(res, i) for i in six.moves.range(res.desc.num_blocks())
         ]
         res._sync_with_cpp()
+
+        res._copy_param_info_from(self)
+        res._copy_data_info_from(self, pruned_origin_block_id_map)
+        res._copy_dist_param_info_from(self)
+
         return res
 
     def _inference_optimize(self, prune_read_op=True):
......
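
The net effect of the framework.py changes is that `Program._prune_with_input` now accepts variable names (plain strings) as targets and copies parameter, data, and distributed-parameter information into the pruned program. A hedged sketch of calling this internal API directly (the toy network is an assumption, and the leading underscore marks the method as non-public):

```python
# Illustrative call to the internal pruning API touched above; assumes the
# fluid 1.x layers API.
import paddle.fluid as fluid

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[4], dtype='float32')
    y = fluid.layers.fc(input=x, size=1)
    loss = fluid.layers.mean(y)

# Targets may be Variables, Operators, or (new in this change) variable names.
pruned = main_prog._prune_with_input(
    feeded_var_names=['x'], targets=[loss.name])

# Parameter info is now copied into the pruned program, so the fc weights are
# still visible as parameters of the pruned program.
print([p.name for p in pruned.all_parameters()])
```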
@@ -811,6 +811,9 @@ class Optimizer(object):
             tuple: tuple (optimize_ops, params_grads), A list of operators appended
                 by minimize and a list of (param, grad) variable pairs, param is
                 ``Parameter``, grad is the gradient value corresponding to the parameter.
+                The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
+                indicate program pruning. If so, the program will be pruned by ``feed`` and
+                ``fetch_list`` before run, see details in ``Executor``.
 
         Examples:
             Please refer to the example of current Optimizer.
......