Unverified commit a62599a8, authored by Leo Chen, committed by GitHub

[feature] prune program by feed and fetch_list automatically (#22474)

* prune train program by fetch_list, test=develop

* add unittest for prune, test=develop

* fix pruned feed, test=develop

* support ParallelExecutor and feed prune, test=develop

* add comments, test=develop

* update unittest, test=develop

* update unittests, test=develop

* remove debug code, test=develop

* support cond in clone, test=develop

* support cond in prune, test=develop

* support multiple minimize, test=develop

* support cache, test=develop

* fix _copy_param_info_from, test=develop

* support python2 str, test=develop

* remove debug code, test=develop

* fix bug of caching CompiledProgram, test=develop

* fix multi_device issue, test=develop

* tmp

* support tuple in fetch_list and overriding use_prune, test=develop

* dont use nonlocal in python2, test=develop

* remove nonlocal, test=develop

* code clean, test=develop

* code clean, test=develop

* feed list, test=develop

* test adam, test=develop

* follow comments, test=develop

* reduce duplicate code, test=develop

* update comments, test=develop
Parent 7c55a94d
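
As a usage sketch of the feature added by this commit (not part of the diff; the toy network, variable names, and data are illustrative, and it assumes the 1.x paddle.fluid API of this commit's era):

import numpy as np
import paddle.fluid as fluid

# A toy network: one fc layer, a mean loss, and an SGD optimizer.
x = fluid.layers.data(name='x', shape=[2], dtype='float32')
y = fluid.layers.fc(input=x, size=1)
loss = fluid.layers.mean(y)
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(loss)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())

# With use_prune=True, the executor prunes the main program by the given
# feed and fetch_list before running: operators that are not needed to
# compute `loss` from the feed (here, the whole backward/optimizer part)
# are removed.
loss_val, = exe.run(fluid.default_main_program(),
                    feed={'x': np.random.rand(4, 2).astype('float32')},
                    fetch_list=[loss],
                    use_prune=True)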
......@@ -113,7 +113,6 @@ bool HasSubBlock(const proto::OpDesc& op_desc) {
}
int GetOpRole(const proto::OpDesc& op_desc) {
// The op role >= 0, so -1 is used to indicate "NotFound".
for (auto& attr : op_desc.attrs()) {
if (attr.name() == OpProtoAndCheckerMaker::OpRoleAttrName()) {
PADDLE_ENFORCE_EQ(
......@@ -124,7 +123,10 @@ int GetOpRole(const proto::OpDesc& op_desc) {
return attr.i();
}
}
return -1;
// If attr op_role is not found, it may be an operator created in a c++ test,
// like prune_test.cc. In that case, the op_role should be the default value,
// which is kNotSpecified.
return static_cast<int>(OpRole::kNotSpecified);
}
void AppendOpInputVarNames(const proto::OpDesc& op_desc,
......@@ -145,6 +147,16 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
}
}
int FindMapByValue(const std::map<int, int>& m, int val) {
// The values in the map should be >= 0, so -1 is used to indicate "NotFound".
for (auto& pair : m) {
if (pair.second == val) {
return pair.first;
}
}
return -1;
}
// block_id is the idx of the current block in the input desc
// parent_block_id is the idx of the parent of the current block
// in the output desc, -1 means the current block is global block
......@@ -153,30 +165,41 @@ void AppendOpOutputVarNames(const proto::OpDesc& op_desc,
void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
int block_id, int parent_block_id,
std::unordered_set<std::string>* dependent_vars,
const std::set<std::string> feed_var_names) {
const std::set<std::string> feed_var_names,
std::map<int, int>* pruned_origin_block_id_map) {
auto& block = input.blocks(block_id);
auto& ops = block.ops();
bool expect_feed = true;
for (auto& op_desc : ops) {
PADDLE_ENFORCE(op_desc.type() != kFeedOpType || expect_feed,
"All FeedOps are at the beginning of the ProgramDesc");
PADDLE_ENFORCE_EQ(
op_desc.type() != kFeedOpType || expect_feed, true,
platform::errors::PreconditionNotMet(
"All FeedOps are at the beginning of the ProgramDesc"));
expect_feed = (op_desc.type() == kFeedOpType);
}
bool expect_fetch = true;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
PADDLE_ENFORCE(op_desc.type() != kFetchOpType || expect_fetch,
"All FetchOps must at the end of the ProgramDesc");
PADDLE_ENFORCE_EQ(op_desc.type() != kFetchOpType || expect_fetch, true,
platform::errors::PreconditionNotMet(
"All FetchOps must at the end of the ProgramDesc"));
expect_fetch = (op_desc.type() == kFetchOpType);
}
std::vector<bool> should_run;
for (auto op_iter = ops.rbegin(); op_iter != ops.rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (IsTarget(op_desc) || HasDependentOutputVar(op_desc, *dependent_vars)) {
// insert its input to the dependency graph
if (IsTarget(op_desc) ||
(HasDependentOutputVar(op_desc, *dependent_vars) &&
(GetOpRole(op_desc) & static_cast<int>(OpRole::kOptimize)) == 0)) {
// NOTE(zhiqiu): since an optimize op takes the trainable parameters as both
// inputs and outputs, it may introduce a wrong dependency graph.
// In train mode, the optimize ops should already be in targets, so it is
// neither needed nor correct to mark an optimize op by its outputs.
// In eval / infer mode, there is no optimize op in the program.
for (auto& var : op_desc.inputs()) {
for (auto& argu : var.arguments()) {
if (feed_var_names.count(argu) == 0) {
......@@ -203,6 +226,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
output_block->set_idx(output_block_id);
output_block->set_parent_idx(parent_block_id);
(*pruned_origin_block_id_map)[output_block_id] = block_id;
auto* op_field = output_block->mutable_ops();
op_field->Clear();
for (size_t i = 0; i < should_run.size(); ++i) {
......@@ -244,7 +269,8 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
// GetSubBlockIndex(*op) is the idx of the sub_block in the input desc
// output_block_id is the idx of the current block in the output desc
prune_impl(input, output, GetSubBlockIndex(*op), output_block_id,
&sub_block_dependent_vars, feed_var_names);
&sub_block_dependent_vars, feed_var_names,
pruned_origin_block_id_map);
}
}
}
......@@ -284,22 +310,33 @@ void prune_impl(const proto::ProgramDesc& input, proto::ProgramDesc* output,
}
// TODO(fengjiayi): Prune() could be inplaced to avoid unnecessary copies
void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output) {
std::map<int, int> Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output) {
std::unordered_set<std::string> dependent_vars;
output->clear_blocks();
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names);
}
int FindMapByValue(const std::map<int, int>& m, int val) {
// The content in map should be >= 0, so -1 is used to indicate "NotFound".
for (auto& pair : m) {
if (pair.second == val) {
return pair.first;
std::map<int, int> pruned_origin_block_id_map;
prune_impl(input, output, 0, -1, &dependent_vars, feed_var_names,
&pruned_origin_block_id_map);
// update subblock idx
for (int i = 0; i < output->blocks_size(); i++) {
auto* pruned = output->mutable_blocks(i);
auto* ops = pruned->mutable_ops();
for (auto op_iter = ops->rbegin(); op_iter != ops->rend(); ++op_iter) {
auto& op_desc = *op_iter;
if (HasSubBlock(op_desc)) {
int origin_sub_idx = GetSubBlockIndex(op_desc);
auto sub_idx =
FindMapByValue(pruned_origin_block_id_map, origin_sub_idx);
PADDLE_ENFORCE_NE(sub_idx, -1,
platform::errors::NotFound(
"The origin sub block id should be found in "
"pruned_progin_block_id_map"));
SetSubBlockIndex(&op_desc, sub_idx);
}
}
}
return -1;
return pruned_origin_block_id_map;
}
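
A framework-free Python sketch of the block-id remapping done above (the returned map records pruned block id to original block id, so sub-block attributes are fixed up by a reverse lookup); the values are illustrative only:

def find_map_by_value(m, val):
    # Reverse lookup: return the key whose value equals `val`, or -1 if absent.
    for k, v in m.items():
        if v == val:
            return k
    return -1

# Suppose original blocks 0, 2 and 3 survive pruning and become 0, 1 and 2.
pruned_origin_block_id_map = {0: 0, 1: 2, 2: 3}
# An op that referred to original sub-block 3 now points at pruned block 2.
print(find_map_by_value(pruned_origin_block_id_map, 3))  # 2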
void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
......@@ -348,8 +385,8 @@ void PruneBackwardImpl(proto::BlockDesc* origin, proto::BlockDesc* pruned) {
var_names.insert(op_output_vars.begin(), op_output_vars.end());
for (const auto& name : var_names) {
if (var_map.count(name)) {
// NOTE(zhiqiu): For operator in a conditional block, the related vars may
// not exist in current block, but in its futher block.
// NOTE(zhiqiu): For an operator in a conditional block, the related vars
// may not exist in the current block, but in its parent block.
*pruned_vars->Add() = var_map[name];
}
}
......@@ -389,6 +426,7 @@ std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
proto::ProgramDesc pruned_desc;
pruned_desc.clear_blocks();
// Step 2. Prune backward for each block.
for (size_t i = 0; i < origin_clone.Size(); i++) {
auto pruned = proto::BlockDesc();
......
......@@ -26,9 +26,9 @@ limitations under the License. */
namespace paddle {
namespace framework {
void Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
std::map<int, int> Prune(const proto::ProgramDesc& input,
const std::set<std::string>& feed_var_names,
proto::ProgramDesc* output);
std::tuple<framework::ProgramDesc, std::map<int, int>> PruneBackward(
const framework::ProgramDesc& origin);
......
......@@ -1154,8 +1154,10 @@ All parameter, weight, gradient are variables in Paddle.
prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
}
proto::ProgramDesc pruned_desc;
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return new ProgramDesc(pruned_desc);
auto pruned_origin_block_id_map =
Prune(*prog_with_targets.Proto(), feeded_var_names, &pruned_desc);
return std::make_tuple(ProgramDesc(pruned_desc),
pruned_origin_block_id_map);
});
m.def("prune_backward",
[](const framework::ProgramDesc &program) {
......
......@@ -23,12 +23,13 @@ import numpy as np
from .wrapped_decorator import signature_safe_contextmanager
import six
from .data_feeder import convert_dtype
from .framework import Program, default_main_program, Variable, convert_np_dtype_to_dtype_
from .framework import Program, default_main_program, Variable, Operator, convert_np_dtype_to_dtype_
from . import core
from . import compiler
from .. import compat as cpt
from .trainer_factory import TrainerFactory
from .trainer_factory import FetchHandlerMonitor
import copy
__all__ = ['Executor', 'global_scope', 'scope_guard']
......@@ -345,14 +346,27 @@ def _fetch_var(name, scope=None, return_numpy=True):
def _to_name_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
def _to_str(var):
if isinstance(var, Variable):
return var.desc.name()
elif isinstance(var, str):
return var
elif isinstance(var, six.string_types):
return str(var)
elif isinstance(var, Operator):
return var.desc.type()
else:
raise TypeError(str(var) + " should be Variable, Operator or str")
# NOTE(zhiqiu): The item in fetch_list may be a tuple returned by Optimizer.minimize(),
# see comments in _split_optimize_ops_in_fetch_list for more details.
if isinstance(var, tuple):
var = var[0]
if isinstance(var, list):
s = [_to_str(item) for item in var]
return ','.join(s)
else:
raise TypeError(str(var) + " should be Variable or str")
return _to_str(var)
def _get_strong_program_cache_key(program, feed, fetch_list):
......@@ -360,9 +374,13 @@ def _get_strong_program_cache_key(program, feed, fetch_list):
def _get_program_cache_key(feed, fetch_list):
feed_var_names = list(feed.keys())
feed_var_names = []
if isinstance(feed, dict):
feed_var_names = list(feed.keys())
elif isinstance(feed, list) or isinstance(feed, tuple):
for i, each in enumerate(feed):
feed_var_names += list(each.keys())
fetch_var_names = list(map(_to_name_str, fetch_list))
return str(feed_var_names + fetch_var_names)
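
For illustration, a self-contained re-implementation of the cache-key scheme above (the names are stand-ins for the private helpers; a list/tuple feed means one feed dict per device):

def to_name_str(var):
    # Stand-in for _to_name_str: tuples from Optimizer.minimize() are keyed
    # by their first element; lists are joined by ','.
    if isinstance(var, tuple):
        var = var[0]
    if isinstance(var, list):
        return ','.join(str(item) for item in var)
    return str(var)

def program_cache_key(feed, fetch_list):
    feed_var_names = []
    if isinstance(feed, dict):
        feed_var_names = list(feed.keys())
    elif isinstance(feed, (list, tuple)):
        for each in feed:
            feed_var_names += list(each.keys())
    fetch_var_names = list(map(to_name_str, fetch_list))
    return str(feed_var_names + fetch_var_names)

# A per-device feed (two devices) and a plain fetch name.
print(program_cache_key([{'x': None}, {'x': None}], ['loss']))
# "['x', 'x', 'loss']"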
......@@ -503,10 +521,12 @@ class Executor(object):
self.ctx_caches = dict()
self.scope_caches = dict()
self.var_caches = dict()
self.pruned_program_caches = dict()
p = core.Place()
p.set_place(self.place)
self._default_executor = core.Executor(p)
self._closed = False
self.pruned_program_scope_caches = dict()
def _get_scope_cache(self, program_cache_key):
return self.scope_caches.get(program_cache_key, None)
......@@ -520,6 +540,18 @@ class Executor(object):
def _add_program_cache(self, program_cache_key, program):
self.program_caches[program_cache_key] = program
def _get_pruned_program_cache(self, program_cache_key):
return self.pruned_program_caches.get(program_cache_key, None)
def _add_pruned_program_cache(self, program_cache_key, program):
self.pruned_program_caches[program_cache_key] = program
def _get_pruned_program_scope_cache(self, program_cache_key):
return self.pruned_program_scope_caches.get(program_cache_key, None)
def _add_pruned_program_scope_cache(self, program_cache_key, program):
self.pruned_program_scope_caches[program_cache_key] = program
def _add_ctx_cache(self, ctx_cache_key, ctx):
self.ctx_caches[ctx_cache_key] = ctx
......@@ -551,13 +583,17 @@ class Executor(object):
# prepend feed operators
if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed):
out = global_block.var(name)
global_block._prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
if global_block.has_var(name):
out = global_block.var(name)
global_block._prepend_op(
type='feed',
inputs={'X': [feed_var]},
outputs={'Out': [out]},
attrs={'col': i})
else:
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% name)
# append fetch_operators
if not has_fetch_operators(global_block, fetch_list, fetch_var_name):
for i, var in enumerate(fetch_list):
......@@ -595,6 +631,159 @@ class Executor(object):
]
return outs
def _split_optimize_ops_in_fetch_list(self, fetch_list):
"""
Split optimize_ops from fetch_list, which are provided to specify program pruning.
Args:
fetch_list(list): The original fetch_list.
Possible types of fetch_list are:
fetch_list = ['loss']
fetch_list = [[sgd, sgd], 'loss']
fetch_list = [([sgd, sgd], [(param, grad)]), 'loss']
Returns:
optimize_ops(list): The optimize operators split from fetch_list.
fetch_list(list): The updated fetch_list which does not contain optimize operators.
"""
_optimize_ops = []
_fetch_list = []
def _get_targets(_optimize_ops, _fetch_list, item):
if isinstance(item, Operator):
if item._is_optimize_op():
_optimize_ops.append(item)
else:
raise TypeError(
"The operator in fetch_list is not an optimize_op")
elif isinstance(item, Variable) or isinstance(
item, str) or isinstance(item, six.string_types):
_fetch_list.append(item)
else:
raise TypeError(
"The item in fetch_list should be str, variable or optimize_op, but received %s."
% type(item))
for item in fetch_list:
# NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list
# we should handle tuple and list in fetch_list.
# TODO(zhiqiu): find a better way to handle that.
if isinstance(item, list):
for i in item:
_get_targets(_optimize_ops, _fetch_list, i)
elif isinstance(item, tuple):
for i in item[0]:
_get_targets(_optimize_ops, _fetch_list, i)
else:
_get_targets(_optimize_ops, _fetch_list, item)
return _fetch_list, _optimize_ops
def _prune_program(self,
program,
feed=None,
fetch_list=None,
optimize_ops=None):
"""
Prune operators and variables which are not needed to generate
:code:`fetch_list` and optimize operators.
Also prune the operators and variables that are only needed
to generate the variables to be fed.
Notes: This is a very low level API. Users should not use this API
directly.
Args:
program(Program): the origin program
feed(list|dict): feed dict or list.
fetch_list(list|Variable): A list of variables need to be fetched
optimize_ops(list[Operator]): A list of optimizer operators
Returns:
Program: A new, pruned program.
"""
compiled = isinstance(program, compiler.CompiledProgram)
if compiled:
if program._program:
origin_program = program._program
else:
warnings.warn(
"The program holds no _program, maybe it is constructed by graph, which can't be pruned yet."
)
return
else:
origin_program = program
feed_names = []
if isinstance(feed, dict):
feed_names = list(feed.keys())
elif isinstance(feed, list) or isinstance(feed, tuple):
for i, each in enumerate(feed):
feed_names += list(each.keys())
# If optimize_ops is [], all optimize ops in the program are used.
if not optimize_ops:
for block in origin_program.blocks:
for op in block.ops:
if op._is_optimize_op():
optimize_ops.append(op)
targets = fetch_list + optimize_ops
pruned_program = origin_program._prune_with_input(feed_names, targets)
if compiled:
# for compiled program, update the underlying program, re-generate graph,
# and reset the flag so it can be compiled again.
program._program = pruned_program
program._graph = core.Graph(pruned_program.desc)
program._compiled = False
else:
program = pruned_program
return program
def _update_feed(self, program, feed):
"""
Update the feed dict by removing the feed items whose variables have been pruned from the program.
Notes: This is a very low level API. Users should not use this API
directly.
Args:
program(Program): the pruned program.
feed(list|dict): feed dict or list.
Returns:
feed:(list|dict) updated feed.
"""
compiled = isinstance(program, compiler.CompiledProgram)
if compiled:
if program._program:
global_block = program._program.global_block()
else:
warnings.warn(
"The program holds no _program, maybe it is constructed by graph."
)
else:
global_block = program.global_block()
if isinstance(feed, dict):
for feed_name in list(feed.keys()):
if not global_block.has_var(feed_name):
feed.pop(feed_name)
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% feed_name)
elif isinstance(feed, list) or isinstance(feed, tuple):
for i, each in enumerate(feed):
for feed_name in list(each.keys()):
if not global_block.has_var(feed_name):
each.pop(feed_name)
warnings.warn(
"The variable %s is not found in program. It is not declared or is pruned."
% feed_name)
return feed
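
A rough, framework-free sketch of the _update_feed behavior above, using a plain set of variable names in place of the program's global block (the names are illustrative):

import warnings

def update_feed(program_var_names, feed):
    # `program_var_names` stands in for global_block.has_var(); `feed` may be
    # a single dict or a per-device list/tuple of dicts.
    feed_dicts = feed if isinstance(feed, (list, tuple)) else [feed]
    for each in feed_dicts:
        for name in list(each.keys()):
            if name not in program_var_names:
                each.pop(name)
                warnings.warn("The variable %s is not found in program. "
                              "It is not declared or is pruned." % name)
    return feed

print(update_feed({'x'}, {'x': 1, 'label': 2}))  # {'x': 1}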
'''
TODO(typhoonzero): Define "no longer use" meaning? Can user create
a new Executor for the same program and run?
......@@ -682,7 +871,8 @@ class Executor(object):
scope=None,
return_numpy=True,
use_program_cache=False,
return_merged=True):
return_merged=True,
use_prune=False):
"""
Run the specified :code:`Program` or :code:`CompiledProgram`. It should be noted that the executor
will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some
......@@ -706,7 +896,7 @@ class Executor(object):
so the length of this list should be equal to the number of places.
The default is None.
fetch_list(list): This parameter represents the variables that need to be returned
after the model runs. The default is None.
after the model runs. The default is None.
feed_var_name(str): This parameter represents the name of the input variable of
the feed operator. The default is "feed".
fetch_var_name(str): This parameter represents the name of the output variable of
......@@ -732,6 +922,13 @@ class Executor(object):
set :code:`return_merged` as False, which denotes that the fetched results will not be merged.
The default is True, but it is just for the compatibility, and may use False as default value
in the future version.
use_prune(bool): This parameter indicates whether the input :code:`Program` will be pruned.
If the parameter is True, the program will be pruned according to the given feed and fetch_list,
which means the operators and variables in the program that generate :code:`feed` and are not
needed to generate :code:`fetch_list` will be pruned. The default is False, which means the
program will not be pruned and all the operators and variables will be executed during running.
Note that if the tuple returned from :code:`Optimizer.minimize()` is passed to :code:`fetch_list`,
:code:`use_prune` will be overridden to True, and the program will be pruned.
Returns:
......@@ -844,6 +1041,7 @@ class Executor(object):
scope=scope,
return_numpy=return_numpy,
use_program_cache=use_program_cache,
use_prune=use_prune,
return_merged=return_merged)
except Exception as e:
if not isinstance(e, core.EOFException):
......@@ -853,7 +1051,7 @@ class Executor(object):
def _run_impl(self, program, feed, fetch_list, feed_var_name,
fetch_var_name, scope, return_numpy, use_program_cache,
return_merged):
return_merged, use_prune):
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
......@@ -877,7 +1075,9 @@ class Executor(object):
scope = global_scope()
if fetch_list is not None:
if isinstance(fetch_list, Variable) or isinstance(fetch_list, str):
if isinstance(fetch_list, Variable) or isinstance(
fetch_list, str) or isinstance(fetch_list,
six.string_types):
fetch_list = [fetch_list]
assert isinstance(fetch_list, tuple) or isinstance(fetch_list, list), \
"Currently , The fetch_list type only should be list or tuple, \n"\
......@@ -886,6 +1086,38 @@ class Executor(object):
else:
fetch_list = []
# use_prune can be overridden by putting optimize_ops in fetch_list
_origin_fetch_list = fetch_list
_origin_program = program
fetch_list, optimize_ops = self._split_optimize_ops_in_fetch_list(
fetch_list)
if optimize_ops:
use_prune = True
if use_prune:
cache_key = _get_strong_program_cache_key(program, feed,
_origin_fetch_list)
cached_pruned_program = self._get_pruned_program_cache(cache_key)
if cached_pruned_program is None:
if isinstance(program, compiler.CompiledProgram):
program_scope_cache = self._get_pruned_program_scope_cache(
str(id(_origin_program)))
# copy the original program, so it can be cached.
program = copy.copy(program)
# share the local scopes for same original CompiledProgram.
program._share_vars_from = program_scope_cache
if self._get_pruned_program_scope_cache(
str(id(_origin_program))) is None:
self._add_pruned_program_scope_cache(
str(id(_origin_program)), program)
pruned_program = self._prune_program(program, feed, fetch_list,
optimize_ops)
self._add_pruned_program_cache(cache_key, pruned_program)
else:
pruned_program = cached_pruned_program
feed = self._update_feed(pruned_program, feed)
program = pruned_program
compiled = isinstance(program, compiler.CompiledProgram)
# For backward compatibility, run directly.
......
......@@ -2172,6 +2172,15 @@ class Operator(object):
return attr_map
def _is_optimize_op(self):
op_maker = core.op_proto_and_checker_maker
OPTIMIZE = core.op_proto_and_checker_maker.OpRole.Optimize
op_role = self.desc.attr(op_maker.kOpRoleAttrName())
if op_role & int(OPTIMIZE):
return True
else:
return False
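
The check above is a bit-mask test on the op's role attribute; a standalone sketch with illustrative role values (the real values come from the C++ OpRole enum exposed via core.op_proto_and_checker_maker):

# Illustrative role bits; the actual values live in the OpRole enum.
FORWARD, BACKWARD, OPTIMIZE, LRSCHED = 0x0000, 0x0001, 0x0002, 0x0010

def is_optimize_op(op_role):
    # Any op whose role has the Optimize bit set counts, e.g. learning-rate
    # scheduling ops carry OPTIMIZE | LRSCHED.
    return bool(op_role & OPTIMIZE)

print(is_optimize_op(OPTIMIZE | LRSCHED))  # True
print(is_optimize_op(BACKWARD))            # False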
class Block(object):
"""
......@@ -2706,8 +2715,8 @@ class Block(object):
assert isinstance(p, Parameter)
v = self.vars.get(p.name, None)
if v is None:
raise ValueError("_copy_param_info_from should be invoked with "
"same topology")
# if the Parameter is pruned, v may be None
continue
assert isinstance(v, Variable)
new_p = None
if in_dygraph_mode():
......@@ -4056,52 +4065,13 @@ class Program(object):
directly. This API is in flux and not stable.
Args:
targets(list|Variable|Operator): A list of variables or operators
targets(list|Variable|Operator): A list of variables, operators, or variable names
that the pruning should preserve
Returns:
Program: A new, pruned program.
"""
#NOTE(zhiqiu): we sync the original program first, since its program may diff with
# its desc due to modifying desc in c++ space. E.g. save op will add kLookupTablePath in desc.
self._sync_with_cpp()
if not isinstance(targets, list):
targets = [targets]
targets_idx = []
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
# variable here.
t.op = None
global_block = self.global_block()
for idx, op in enumerate(global_block.ops):
if t.name in op.output_arg_names:
t.op = op
break
t = t.op
if t is None:
raise ValueError(
"The target variable must have an "
"associated operator that generates it.")
else:
raise ValueError("All targets of prune() can only be "
"Variable or Operator.")
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, set(), targets_idx)
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp()
return res
return self._prune_with_input([], targets)
def _prune_with_input(self, feeded_var_names, targets):
"""
......@@ -4115,7 +4085,7 @@ class Program(object):
Args:
feeded_var_names(list|str): A list of variable names from where
pruning starts. If it is set to [], this API works just like _prune()
targets(list|Variable|Operator): A list of variables or operators
targets(list|Variable|Operator): A list of variables, operators, or variable names
that the pruning should preserve
Returns:
......@@ -4140,33 +4110,47 @@ class Program(object):
for t in targets:
if not isinstance(t, Operator):
if isinstance(t, Variable):
# After transpiler processing, the op that output this
# variable maybe has been changed, so t.op is not reliable
# and we need to find the current op that generate this
# variable here.
t.op = None
global_block = self.global_block()
for idx, op in enumerate(global_block.ops):
if t.name in op.output_arg_names:
t.op = op
break
t = t.op
if t is None:
raise ValueError(
"The target variable must have an "
"associated operator that generates it.")
name = t.name
elif isinstance(t, six.string_types):
name = str(t)
else:
raise ValueError("All targets of prune() can only be "
"Variable or Operator.")
# After transpiler processing, the op that outputs this
# variable may have been changed, so t.op is not reliable
# and we need to find the current op that generates this
# variable here.
target_op = None
global_block = self.global_block()
for idx, op in enumerate(global_block.ops):
if name in op.output_arg_names:
# NOTE(zhiqiu): Find the op that generates the target name.
# Skip optimize ops unless they are given in targets,
# since optimize ops generate parameters.
if op._is_optimize_op() and op not in targets:
continue
else:
target_op = op
break
t = target_op
if t is None:
raise ValueError("The target variable must have an "
"associated operator that generates it.")
targets_idx.append([t.block.idx, t.idx])
res = Program()
res.desc = core.prune(self.desc, set(feeded_var_names), targets_idx)
res.desc, pruned_origin_block_id_map = core.prune(self.desc,
set(feeded_var_names),
targets_idx)
res.blocks = [
Block(res, i) for i in six.moves.range(res.desc.num_blocks())
]
res._sync_with_cpp()
res._copy_param_info_from(self)
res._copy_data_info_from(self, pruned_origin_block_id_map)
res._copy_dist_param_info_from(self)
return res
def _inference_optimize(self, prune_read_op=True):
......
......@@ -811,6 +811,9 @@ class Optimizer(object):
tuple: tuple (optimize_ops, params_grads), A list of operators appended
by minimize and a list of (param, grad) variable pairs, param is
``Parameter``, grad is the gradient value corresponding to the parameter.
The returned tuple can be passed to ``fetch_list`` in ``Executor.run()`` to
enable program pruning. If so, the program will be pruned by ``feed`` and
``fetch_list`` before running; see details in ``Executor``.
Examples:
Please refer to the example of current Optimizer.
......