Unverified · Commit db67d60e authored by Wu Yi, committed by GitHub

Remove block api (#12107)

* remove block api

* remove clone_variable

* hide block inner apis

* update

* fix tests
Parent 866fcb0c
...@@ -210,7 +210,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader, ...@@ -210,7 +210,7 @@ def train_parallel(avg_loss, infer_prog, optimizer, train_reader, test_reader,
# generate fake: # generate fake:
if args.use_fake_data: if args.use_fake_data:
for var in feed_var_list: for var in feed_var_list:
v = startup_prog.global_block().clone_variable(var) v = startup_prog.global_block()._clone_variable(var)
var.persistable = True var.persistable = True
v.persistable = True v.persistable = True
......
...@@ -98,13 +98,13 @@ class Block(objects): ...@@ -98,13 +98,13 @@ class Block(objects):
def append_operator(self, ...): def append_operator(self, ...):
self.ops.append(Operator(self, ...)) self.ops.append(Operator(self, ...))
def prepend_operator(self, ...): # Parameter's ctor prepends initialize operators. def _prepend_operator(self, ...): # Parameter's ctor prepends initialize operators.
self.ops.prepend(Operator(self, ...)) self.ops.prepend(Operator(self, ...))
``` ```
`create_parameter` is necessary because parameters are global variables, defined in the global block, but can be created in some sub-blocks. For example, an FC layer in the step block of an RNN operator. `create_parameter` is necessary because parameters are global variables, defined in the global block, but can be created in some sub-blocks. For example, an FC layer in the step block of an RNN operator.
`prepend_operator` is necessary because the constructor of `Parameter` needs to create the initialize (or load) operator of the parameter, and would like to put it in the *preamble* of the global block. `_prepend_operator` is necessary because the constructor of `Parameter` needs to create the initialize (or load) operator of the parameter, and would like to put it in the *preamble* of the global block.
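To make the *preamble* idea concrete, here is a minimal, self-contained sketch (illustrative only, not PaddlePaddle's actual `Block` class) of a block that appends compute operators to the end of its op list but prepends initialization operators to the front, so initializers always run before the operators that read the parameters:

```python
class ToyOp(object):
    # Stand-in for an operator; only carries a type name.
    def __init__(self, op_type):
        self.op_type = op_type

    def __repr__(self):
        return self.op_type


class ToyBlock(object):
    # Toy block holding an ordered op list.
    def __init__(self):
        self.ops = []

    def append_op(self, op_type):
        # Compute operators added by layers go to the end of the block.
        op = ToyOp(op_type)
        self.ops.append(op)
        return op

    def _prepend_op(self, op_type):
        # Initialization (or load) operators go to the block's preamble.
        op = ToyOp(op_type)
        self.ops.insert(0, op)
        return op


block = ToyBlock()
block.append_op("mul")               # added by an FC-like layer
block._prepend_op("fill_constant")   # prepended by a Parameter's constructor
print(block.ops)                     # [fill_constant, mul]
```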
### Operator ### Operator
......
...@@ -78,7 +78,7 @@ def error_clip_callback(block, context): ...@@ -78,7 +78,7 @@ def error_clip_callback(block, context):
op_desc = block.desc.op(block.desc.op_size() - 1) op_desc = block.desc.op(block.desc.op_size() - 1)
for grad_n in filter(lambda n: grad_to_var.has_key(n), for grad_n in filter(lambda n: grad_to_var.has_key(n),
op_desc.output_arg_names()): op_desc.output_arg_names()):
fwd_var = block.var_recursive(grad_to_var[grad_n]) fwd_var = block._var_recursive(grad_to_var[grad_n])
error_clip = getattr(fwd_var, "error_clip", None) error_clip = getattr(fwd_var, "error_clip", None)
if not (error_clip is None or isinstance(error_clip, if not (error_clip is None or isinstance(error_clip,
BaseErrorClipAttr)): BaseErrorClipAttr)):
......
...@@ -118,7 +118,7 @@ class Float16Transpiler: ...@@ -118,7 +118,7 @@ class Float16Transpiler:
for var in self.block.vars.keys(): for var in self.block.vars.keys():
if var not in args: if var not in args:
self.block.remove_var(var) self.block._remove_var(var)
def _modify_feed_fetch(self): def _modify_feed_fetch(self):
''' '''
...@@ -165,7 +165,7 @@ class Float16Transpiler: ...@@ -165,7 +165,7 @@ class Float16Transpiler:
dtype=core.VarDesc.VarType.FP16, dtype=core.VarDesc.VarType.FP16,
shape=var.shape, shape=var.shape,
persistable=var.persistable) persistable=var.persistable)
self.block.insert_op( self.block._insert_op(
i + 1, i + 1,
type="cast", type="cast",
inputs={"X": var}, inputs={"X": var},
...@@ -188,7 +188,7 @@ class Float16Transpiler: ...@@ -188,7 +188,7 @@ class Float16Transpiler:
persistable=var.persistable) persistable=var.persistable)
find_op(var) find_op(var)
var.op.rename_output(var_name, tmp_var_name) var.op.rename_output(var_name, tmp_var_name)
self.block.insert_op( self.block._insert_op(
i, i,
type="cast", type="cast",
inputs={"X": tmp_var}, inputs={"X": tmp_var},
...@@ -253,4 +253,4 @@ class Float16Transpiler: ...@@ -253,4 +253,4 @@ class Float16Transpiler:
# old var will be replaced by the fp16 var in program desc # old var will be replaced by the fp16 var in program desc
self.input_map[var.name] = fp16_var_name self.input_map[var.name] = fp16_var_name
self.block.remove_var(var.name) self.block._remove_var(var.name)
...@@ -145,14 +145,14 @@ void BindBlockDesc(pybind11::module *m) { ...@@ -145,14 +145,14 @@ void BindBlockDesc(pybind11::module *m) {
.def_property_readonly("id", &pd::BlockDesc::ID) .def_property_readonly("id", &pd::BlockDesc::ID)
.def_property_readonly("parent", &pd::BlockDesc::Parent) .def_property_readonly("parent", &pd::BlockDesc::Parent)
.def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID) .def("get_forward_block_idx", &pd::BlockDesc::ForwardBlockID)
.def("set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID) .def("_set_forward_block_idx", &pd::BlockDesc::SetForwardBlockID)
.def("append_op", &pd::BlockDesc::AppendOp, .def("append_op", &pd::BlockDesc::AppendOp,
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("prepend_op", &pd::BlockDesc::PrependOp, .def("_prepend_op", &pd::BlockDesc::PrependOp,
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("insert_op", &pd::BlockDesc::InsertOp, .def("_insert_op", &pd::BlockDesc::InsertOp,
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("remove_op", &pd::BlockDesc::RemoveOp) .def("_remove_op", &pd::BlockDesc::RemoveOp)
.def("var", .def("var",
[](pd::BlockDesc &self, pybind11::bytes byte_name) { [](pd::BlockDesc &self, pybind11::bytes byte_name) {
std::string name = byte_name; std::string name = byte_name;
...@@ -165,7 +165,7 @@ void BindBlockDesc(pybind11::module *m) { ...@@ -165,7 +165,7 @@ void BindBlockDesc(pybind11::module *m) {
return self.HasVar(name); return self.HasVar(name);
}, },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("rename_var", .def("_rename_var",
[](pd::BlockDesc &self, const pybind11::bytes &byte_name, [](pd::BlockDesc &self, const pybind11::bytes &byte_name,
const pybind11::bytes &byte_name_new) { const pybind11::bytes &byte_name_new) {
std::string name = byte_name; std::string name = byte_name;
...@@ -189,7 +189,7 @@ void BindBlockDesc(pybind11::module *m) { ...@@ -189,7 +189,7 @@ void BindBlockDesc(pybind11::module *m) {
return self.FindVarRecursive(name); return self.FindVarRecursive(name);
}, },
pybind11::return_value_policy::reference) pybind11::return_value_policy::reference)
.def("remove_var", .def("_remove_var",
[](pd::BlockDesc &self, pybind11::bytes byte_name) { [](pd::BlockDesc &self, pybind11::bytes byte_name) {
std::string name = byte_name; std::string name = byte_name;
return self.RemoveVar(name); return self.RemoveVar(name);
......
...@@ -328,7 +328,7 @@ def _append_backward_ops_(block, ...@@ -328,7 +328,7 @@ def _append_backward_ops_(block,
if op.has_attr("sub_block"): if op.has_attr("sub_block"):
sub_block = program.block(op.block_attr("sub_block")) sub_block = program.block(op.block_attr("sub_block"))
grad_sub_block = program.create_block() grad_sub_block = program.create_block()
grad_sub_block.set_forward_block_idx(sub_block.idx) grad_sub_block._set_forward_block_idx(sub_block.idx)
cb = _callback_lookup_(op) cb = _callback_lookup_(op)
if cb is not None: if cb is not None:
if callbacks is None: if callbacks is None:
...@@ -571,7 +571,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, ...@@ -571,7 +571,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
_append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map) _append_backward_vars_(root_block, fwd_op_num, grad_to_var, grad_info_map)
program.current_block_idx = current_block_idx program.current_block_idx = current_block_idx
program.sync_with_cpp() program._sync_with_cpp()
# FIXME(zcd): prevent loss.grad optimized by mem_opt. # FIXME(zcd): prevent loss.grad optimized by mem_opt.
loss.block.var(_append_grad_suffix_(loss.name)).persistable = True loss.block.var(_append_grad_suffix_(loss.name)).persistable = True
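Since `sync_with_cpp` is now the private `_sync_with_cpp`, user code is not expected to drive backward construction through these helpers directly; it reaches `append_backward` through `Optimizer.minimize`. A hedged usage sketch (network shape and names are illustrative only):

```python
# Typical user-level path: Optimizer.minimize calls append_backward, which in
# turn calls the now-private program._sync_with_cpp(); nothing here touches
# Block or Program internals directly.
import paddle.fluid as fluid

x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
pred = fluid.layers.fc(input=x, size=1)
cost = fluid.layers.square_error_cost(input=pred, label=y)
avg_cost = fluid.layers.mean(cost)

sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(avg_cost)
```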
...@@ -744,7 +744,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None): ...@@ -744,7 +744,7 @@ def calc_gradient(targets, inputs, target_gradients=None, no_grad_set=None):
_rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map) _rename_grad_(block, fwd_op_num, grad_to_var, target_grad_map)
_append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map) _append_backward_vars_(block, fwd_op_num, grad_to_var, grad_info_map)
prog.sync_with_cpp() prog._sync_with_cpp()
grad_vars = [] grad_vars = []
for input_var in inputs: for input_var in inputs:
......
...@@ -82,7 +82,7 @@ def error_clip_callback(block, context): ...@@ -82,7 +82,7 @@ def error_clip_callback(block, context):
op_desc = block.desc.op(block.desc.op_size() - 1) op_desc = block.desc.op(block.desc.op_size() - 1)
for grad_n in filter(lambda n: grad_to_var.has_key(n), for grad_n in filter(lambda n: grad_to_var.has_key(n),
op_desc.output_arg_names()): op_desc.output_arg_names()):
fwd_var = block.var_recursive(grad_to_var[grad_n]) fwd_var = block._var_recursive(grad_to_var[grad_n])
error_clip = getattr(fwd_var, "error_clip", None) error_clip = getattr(fwd_var, "error_clip", None)
if not (error_clip is None or isinstance(error_clip, if not (error_clip is None or isinstance(error_clip,
BaseErrorClipAttr)): BaseErrorClipAttr)):
......
...@@ -69,8 +69,10 @@ class Go(BlockGuard): ...@@ -69,8 +69,10 @@ class Go(BlockGuard):
parent_block.append_op( parent_block.append_op(
type='go', type='go',
inputs={ inputs={
'X': 'X': [
[parent_block.var_recursive(x_name) for x_name in x_name_list] parent_block._var_recursive(x_name)
for x_name in x_name_list
]
}, },
outputs={}, outputs={},
attrs={'sub_block': go_block}) attrs={'sub_block': go_block})
...@@ -259,7 +261,7 @@ class Select(BlockGuard): ...@@ -259,7 +261,7 @@ class Select(BlockGuard):
if var_name in intermediate if var_name in intermediate
] ]
X = [select_block.var_recursive(x_name) for x_name in params] X = [select_block._var_recursive(x_name) for x_name in params]
# Needs to be used by `equal` inside the cases block. # Needs to be used by `equal` inside the cases block.
X.append(self.case_to_execute) X.append(self.case_to_execute)
......
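The `go`, `while` and `select` constructs resolve their input names with `Block._var_recursive`, which searches the current block and then its ancestors. A minimal sketch of that lookup (not the PaddlePaddle implementation; it only uses the public `has_var`, `var`, `idx` and `parent_idx` members that appear elsewhere in this diff):

```python
def find_var_recursive(block, name):
    # Walk up the parent-block chain until the variable is found; the global
    # block has index 0 and no parent, so failing there means the name is
    # unknown anywhere on the chain.
    while True:
        if block.has_var(name):
            return block.var(name)
        if block.idx == 0:
            raise ValueError("Var {0} is not found recursively".format(name))
        block = block.program.block(block.parent_idx)
```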
...@@ -309,7 +309,7 @@ class Executor(object): ...@@ -309,7 +309,7 @@ class Executor(object):
if not has_feed_operators(global_block, feed, feed_var_name): if not has_feed_operators(global_block, feed, feed_var_name):
for i, name in enumerate(feed): for i, name in enumerate(feed):
out = global_block.var(name) out = global_block.var(name)
global_block.prepend_op( global_block._prepend_op(
type='feed', type='feed',
inputs={'X': [feed_var]}, inputs={'X': [feed_var]},
outputs={'Out': [out]}, outputs={'Out': [out]},
......
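From the user's side, the feed/fetch plumbing stays behind `Executor.run`, which inserts these operators itself, so the rename of `prepend_op` is invisible outside the framework. A hedged sketch (it assumes the `avg_cost` network from the earlier sketch; shapes and names are illustrative):

```python
# Executor.run adds the feed/fetch operators on demand, so user code never
# calls global_block()._prepend_op directly.
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

out, = exe.run(fluid.default_main_program(),
               feed={'x': np.random.rand(8, 13).astype('float32'),
                     'y': np.random.rand(8, 1).astype('float32')},
               fetch_list=[avg_cost])
```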
...@@ -32,7 +32,6 @@ except Exception, e: ...@@ -32,7 +32,6 @@ except Exception, e:
import unique_name import unique_name
__all__ = [ __all__ = [
'Block',
'Variable', 'Variable',
'Program', 'Program',
'Operator', 'Operator',
...@@ -447,7 +446,7 @@ class Operator(object): ...@@ -447,7 +446,7 @@ class Operator(object):
Notes: Notes:
The constructor of operator should not be invoked directly. Use The constructor of operator should not be invoked directly. Use
Block.append_op or Block.prepend_op instead. Block.append_op or Block._prepend_op instead.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -870,7 +869,7 @@ class Block(object): ...@@ -870,7 +869,7 @@ class Block(object):
def forward_block_idx(self): def forward_block_idx(self):
return self.desc.get_forward_block_idx() return self.desc.get_forward_block_idx()
def set_forward_block_idx(self, idx): def _set_forward_block_idx(self, idx):
""" """
Set the forward block Idx. Set the forward block Idx.
...@@ -880,7 +879,7 @@ class Block(object): ...@@ -880,7 +879,7 @@ class Block(object):
Returns: Returns:
None None
""" """
self.desc.set_forward_block_idx(idx) self.desc._set_forward_block_idx(idx)
@property @property
def idx(self): def idx(self):
...@@ -909,7 +908,7 @@ class Block(object): ...@@ -909,7 +908,7 @@ class Block(object):
raise ValueError("var %s not in this block" % name) raise ValueError("var %s not in this block" % name)
return v return v
def var_recursive(self, name): def _var_recursive(self, name):
""" """
Get a Variable by name from this block recursively. Get a Variable by name from this block recursively.
...@@ -951,9 +950,9 @@ class Block(object): ...@@ -951,9 +950,9 @@ class Block(object):
raise ValueError("Var {0} is not found recursively".format(name)) raise ValueError("Var {0} is not found recursively".format(name))
def all_parameters(self): def all_parameters(self):
return list(self.iter_parameters()) return list(self._iter_parameters())
def iter_parameters(self): def _iter_parameters(self):
return (item[1] for item in self.vars.iteritems() return (item[1] for item in self.vars.iteritems()
if isinstance(item[1], Parameter)) if isinstance(item[1], Parameter))
...@@ -966,7 +965,7 @@ class Block(object): ...@@ -966,7 +965,7 @@ class Block(object):
def has_var(self, name): def has_var(self, name):
return name in self.vars return name in self.vars
def rename_var(self, name, new_name): def _rename_var(self, name, new_name):
""" """
Rename variable in vars and ops' inputs and outputs Rename variable in vars and ops' inputs and outputs
...@@ -1000,8 +999,8 @@ class Block(object): ...@@ -1000,8 +999,8 @@ class Block(object):
else: else:
raise ValueError("unsupported var type: %s", type(v)) raise ValueError("unsupported var type: %s", type(v))
orig_var_type = v.type orig_var_type = v.type
self.desc.rename_var(name, new_name) self.desc._rename_var(name, new_name)
# NOTE: v is destroyed by C++ after calling rename_var. # NOTE: v is destroyed by C++ after calling _rename_var.
d = self.desc.find_var(new_name) d = self.desc.find_var(new_name)
if var_type == "Parameter": if var_type == "Parameter":
var = Parameter( var = Parameter(
...@@ -1024,16 +1023,16 @@ class Block(object): ...@@ -1024,16 +1023,16 @@ class Block(object):
error_clip=error_clip, error_clip=error_clip,
stop_gradient=stop_gradient) stop_gradient=stop_gradient)
# rename the python side, sync_with_cpp will only add # rename the python side, _sync_with_cpp will only add
# new vars/ops to python side. # new vars/ops to python side.
self.vars[new_name] = var self.vars[new_name] = var
del self.vars[name] del self.vars[name]
self.sync_with_cpp() self._sync_with_cpp()
return var return var
def remove_var(self, name): def _remove_var(self, name):
self.sync_with_cpp() self._sync_with_cpp()
self.desc.remove_var(name) self.desc._remove_var(name)
del self.vars[name] del self.vars[name]
def create_parameter(self, *args, **kwargs): def create_parameter(self, *args, **kwargs):
...@@ -1055,7 +1054,7 @@ class Block(object): ...@@ -1055,7 +1054,7 @@ class Block(object):
self.ops.append(op) self.ops.append(op)
return op return op
def insert_op(self, index, *args, **kwargs): def _insert_op(self, index, *args, **kwargs):
""" """
Insert an Operator according to the given arguments. Insert an Operator according to the given arguments.
...@@ -1065,13 +1064,13 @@ class Block(object): ...@@ -1065,13 +1064,13 @@ class Block(object):
Returns: Returns:
Operator: the inserted Operator. Operator: the inserted Operator.
""" """
self.sync_with_cpp() self._sync_with_cpp()
op_desc = self.desc.insert_op(index) op_desc = self.desc._insert_op(index)
op = Operator(block=self, desc=op_desc, *args, **kwargs) op = Operator(block=self, desc=op_desc, *args, **kwargs)
self.ops.insert(index, op) self.ops.insert(index, op)
return op return op
def remove_op(self, index): def _remove_op(self, index):
""" """
Remove the specific position operator. Remove the specific position operator.
...@@ -1081,11 +1080,11 @@ class Block(object): ...@@ -1081,11 +1080,11 @@ class Block(object):
Returns: Returns:
None None
""" """
self.sync_with_cpp() self._sync_with_cpp()
self.desc.remove_op(index, index + 1) self.desc._remove_op(index, index + 1)
del self.ops[index] del self.ops[index]
def slice_ops(self, start, end): def _slice_ops(self, start, end):
""" """
Return the Operator between start and end. Return the Operator between start and end.
...@@ -1098,13 +1097,13 @@ class Block(object): ...@@ -1098,13 +1097,13 @@ class Block(object):
""" """
return self.ops[start:end] return self.ops[start:end]
def prepend_op(self, *args, **kwargs): def _prepend_op(self, *args, **kwargs):
op_desc = self.desc.prepend_op() op_desc = self.desc._prepend_op()
op = Operator(self, op_desc, *args, **kwargs) op = Operator(self, op_desc, *args, **kwargs)
self.ops.insert(0, op) self.ops.insert(0, op)
return op return op
def sync_with_cpp(self): def _sync_with_cpp(self):
""" """
Sync from the desc on the c++ end. This method is used to synchronize Sync from the desc on the c++ end. This method is used to synchronize
the c++ desc instance generated by backward. the c++ desc instance generated by backward.
...@@ -1170,7 +1169,7 @@ class Block(object): ...@@ -1170,7 +1169,7 @@ class Block(object):
for index in range(len(self.ops)): for index in range(len(self.ops)):
assert self.ops[index].desc == ops_in_cpp[index] assert self.ops[index].desc == ops_in_cpp[index]
def copy_param_info_from(self, other): def _copy_param_info_from(self, other):
""" """
Copy the information of parameters from the other block. Copy the information of parameters from the other block.
...@@ -1185,12 +1184,13 @@ class Block(object): ...@@ -1185,12 +1184,13 @@ class Block(object):
None None
""" """
if not isinstance(other, Block): if not isinstance(other, Block):
raise TypeError("copy_param_info_from should be invoked with Block") raise TypeError(
for p in other.iter_parameters(): "_copy_param_info_from should be invoked with Block")
for p in other._iter_parameters():
assert isinstance(p, Parameter) assert isinstance(p, Parameter)
v = self.vars.get(p.name, None) v = self.vars.get(p.name, None)
if v is None: if v is None:
raise ValueError("copy_param_info_from should be invoked with " raise ValueError("_copy_param_info_from should be invoked with "
"same topology") "same topology")
assert isinstance(v, Variable) assert isinstance(v, Variable)
new_p = Parameter( new_p = Parameter(
...@@ -1208,7 +1208,7 @@ class Block(object): ...@@ -1208,7 +1208,7 @@ class Block(object):
name=v.name) name=v.name)
self.vars[new_p.name] = new_p self.vars[new_p.name] = new_p
def clone_variable(self, var): def _clone_variable(self, var):
""" """
Clone a variable into current block. Clone a variable into current block.
...@@ -1484,9 +1484,9 @@ class Program(object): ...@@ -1484,9 +1484,9 @@ class Program(object):
p = Program() p = Program()
p.desc = core.ProgramDesc(self.desc) p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())] p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp() p._sync_with_cpp()
p.copy_param_info_from(self) p._copy_param_info_from(self)
p.copy_data_info_from(self) p.copy_data_info_from(self)
return p return p
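`Program.clone` now stitches together the private `_sync_with_cpp` and `_copy_param_info_from`; code outside `framework.py` should keep relying on the public `clone`. A hedged sketch (assuming the `for_test` flag available on `clone` in this codebase):

```python
# Public path: clone the main program for evaluation instead of calling the
# private sync/copy helpers that clone wraps.
import paddle.fluid as fluid

main_prog = fluid.default_main_program()
test_prog = main_prog.clone(for_test=True)
```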
...@@ -1536,7 +1536,7 @@ class Program(object): ...@@ -1536,7 +1536,7 @@ class Program(object):
res = Program() res = Program()
res.desc = core.prune(self.desc, targets_idx) res.desc = core.prune(self.desc, targets_idx)
res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())] res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
res.sync_with_cpp() res._sync_with_cpp()
return res return res
def inference_optimize(self): def inference_optimize(self):
...@@ -1562,7 +1562,7 @@ class Program(object): ...@@ -1562,7 +1562,7 @@ class Program(object):
if op.has_attr('is_test'): if op.has_attr('is_test'):
op.set_attr('is_test', True) op.set_attr('is_test', True)
res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())] res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
res.sync_with_cpp() res._sync_with_cpp()
return res return res
@staticmethod @staticmethod
...@@ -1582,7 +1582,7 @@ class Program(object): ...@@ -1582,7 +1582,7 @@ class Program(object):
p = Program() p = Program()
p.desc = core.ProgramDesc(binary_str) p.desc = core.ProgramDesc(binary_str)
p.blocks = [Block(p, i) for i in xrange(p.desc.num_blocks())] p.blocks = [Block(p, i) for i in xrange(p.desc.num_blocks())]
p.sync_with_cpp() p._sync_with_cpp()
return p return p
@property @property
...@@ -1662,7 +1662,7 @@ class Program(object): ...@@ -1662,7 +1662,7 @@ class Program(object):
""" """
self.current_block_idx = self.current_block().parent_idx self.current_block_idx = self.current_block().parent_idx
def sync_with_cpp(self): def _sync_with_cpp(self):
""" """
Synchronize Python instance to its binding C++ object instance. Synchronize Python instance to its binding C++ object instance.
If the program is modified in C++ space, this method should be invoked. If the program is modified in C++ space, this method should be invoked.
...@@ -1676,9 +1676,9 @@ class Program(object): ...@@ -1676,9 +1676,9 @@ class Program(object):
for block_idx in range(len(self.blocks), self.desc.num_blocks()): for block_idx in range(len(self.blocks), self.desc.num_blocks()):
self.blocks.append(Block(self, block_idx)) self.blocks.append(Block(self, block_idx))
for block in self.blocks: for block in self.blocks:
block.sync_with_cpp() block._sync_with_cpp()
def copy_param_info_from(self, other): def _copy_param_info_from(self, other):
""" """
Copy the information of parameters from other program. Copy the information of parameters from other program.
...@@ -1692,13 +1692,13 @@ class Program(object): ...@@ -1692,13 +1692,13 @@ class Program(object):
None None
""" """
if not isinstance(other, Program): if not isinstance(other, Program):
raise TypeError("copy_param_info_from should be invoked with " raise TypeError("_copy_param_info_from should be invoked with "
"Program") "Program")
if len(self.blocks) != len(other.blocks): if len(self.blocks) != len(other.blocks):
raise ValueError("copy_param_info_from should be invoked with two " raise ValueError("_copy_param_info_from should be invoked with two "
"program, with represent the same topology") "program, with represent the same topology")
self.global_block().copy_param_info_from(other.global_block()) self.global_block()._copy_param_info_from(other.global_block())
def copy_data_info_from(self, other): def copy_data_info_from(self, other):
""" """
...@@ -1714,11 +1714,11 @@ class Program(object): ...@@ -1714,11 +1714,11 @@ class Program(object):
None None
""" """
if not isinstance(other, Program): if not isinstance(other, Program):
raise TypeError("copy_param_info_from should be invoked with " raise TypeError("_copy_param_info_from should be invoked with "
"Program") "Program")
if len(self.blocks) != len(other.blocks): if len(self.blocks) != len(other.blocks):
raise ValueError("copy_param_info_from should be invoked with two " raise ValueError("_copy_param_info_from should be invoked with two "
"program, with represent the same topology") "program, with represent the same topology")
for var in other.global_block().vars.itervalues(): for var in other.global_block().vars.itervalues():
if var.is_data: if var.is_data:
......
...@@ -148,7 +148,7 @@ class ConstantInitializer(Initializer): ...@@ -148,7 +148,7 @@ class ConstantInitializer(Initializer):
assert isinstance(var, framework.Variable) assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block) assert isinstance(block, framework.Block)
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
op = block.prepend_op( op = block._prepend_op(
type="fill_constant", type="fill_constant",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
...@@ -202,7 +202,7 @@ class UniformInitializer(Initializer): ...@@ -202,7 +202,7 @@ class UniformInitializer(Initializer):
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
op = block.prepend_op( op = block._prepend_op(
type="uniform_random", type="uniform_random",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
...@@ -256,7 +256,7 @@ class NormalInitializer(Initializer): ...@@ -256,7 +256,7 @@ class NormalInitializer(Initializer):
# Initialization Ops should be prepended and not appended # Initialization Ops should be prepended and not appended
if self._seed == 0: if self._seed == 0:
self._seed = block.program.random_seed self._seed = block.program.random_seed
op = block.prepend_op( op = block._prepend_op(
type="gaussian_random", type="gaussian_random",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
...@@ -346,7 +346,7 @@ class XavierInitializer(Initializer): ...@@ -346,7 +346,7 @@ class XavierInitializer(Initializer):
if self._uniform: if self._uniform:
limit = np.sqrt(6.0 / float(fan_in + fan_out)) limit = np.sqrt(6.0 / float(fan_in + fan_out))
op = block.prepend_op( op = block._prepend_op(
type="uniform_random", type="uniform_random",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
...@@ -359,7 +359,7 @@ class XavierInitializer(Initializer): ...@@ -359,7 +359,7 @@ class XavierInitializer(Initializer):
else: else:
std = np.sqrt(2.0 / float(fan_in + fan_out)) std = np.sqrt(2.0 / float(fan_in + fan_out))
op = block.prepend_op( op = block._prepend_op(
type="gaussian_random", type="gaussian_random",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
...@@ -444,7 +444,7 @@ class MSRAInitializer(Initializer): ...@@ -444,7 +444,7 @@ class MSRAInitializer(Initializer):
if self._uniform: if self._uniform:
limit = np.sqrt(6.0 / float(fan_in)) limit = np.sqrt(6.0 / float(fan_in))
op = block.prepend_op( op = block._prepend_op(
type="uniform_random", type="uniform_random",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
...@@ -457,7 +457,7 @@ class MSRAInitializer(Initializer): ...@@ -457,7 +457,7 @@ class MSRAInitializer(Initializer):
else: else:
std = np.sqrt(2.0 / float(fan_in)) std = np.sqrt(2.0 / float(fan_in))
op = block.prepend_op( op = block._prepend_op(
type="gaussian_random", type="gaussian_random",
outputs={"Out": var}, outputs={"Out": var},
attrs={ attrs={
......
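All built-in initializers now prepend their ops through `Block._prepend_op`, but users continue to pick initializers at the layer level; the block mutation happens inside the framework. A hedged usage sketch (layer sizes are illustrative):

```python
# Initializers are selected via ParamAttr on a layer; the framework translates
# that into the prepended fill_constant / uniform_random / gaussian_random ops.
import paddle.fluid as fluid

x = fluid.layers.data(name='img', shape=[784], dtype='float32')
hidden = fluid.layers.fc(
    input=x,
    size=128,
    param_attr=fluid.ParamAttr(
        initializer=fluid.initializer.Xavier(uniform=True)),
    bias_attr=fluid.ParamAttr(
        initializer=fluid.initializer.Constant(value=0.0)))
```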
...@@ -523,7 +523,7 @@ def prepend_feed_ops(inference_program, ...@@ -523,7 +523,7 @@ def prepend_feed_ops(inference_program,
for i, name in enumerate(feed_target_names): for i, name in enumerate(feed_target_names):
out = global_block.var(name) out = global_block.var(name)
global_block.prepend_op( global_block._prepend_op(
type='feed', type='feed',
inputs={'X': [feed_var]}, inputs={'X': [feed_var]},
outputs={'Out': [out]}, outputs={'Out': [out]},
...@@ -625,7 +625,7 @@ def save_inference_model(dirname, ...@@ -625,7 +625,7 @@ def save_inference_model(dirname,
for i, op in enumerate(global_block.ops): for i, op in enumerate(global_block.ops):
op.desc.set_is_target(False) op.desc.set_is_target(False)
if op.type == "feed" or op.type == "fetch": if op.type == "feed" or op.type == "fetch":
global_block.remove_op(i) global_block._remove_op(i)
copy_program.desc.flush() copy_program.desc.flush()
pruned_program = copy_program.prune(targets=target_vars) pruned_program = copy_program.prune(targets=target_vars)
...@@ -874,7 +874,7 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -874,7 +874,7 @@ def get_test_program(filelist, program=None, startup_program=None):
main_block = program.global_block() main_block = program.global_block()
for var in main_block.vars.values(): for var in main_block.vars.values():
if var.type == core.VarDesc.VarType.READER: if var.type == core.VarDesc.VarType.READER:
main_block.rename_var( main_block._rename_var(
str(var.name), str(_get_test_reader_name(var.name))) str(var.name), str(_get_test_reader_name(var.name)))
for op in main_block.ops: for op in main_block.ops:
...@@ -883,7 +883,7 @@ def get_test_program(filelist, program=None, startup_program=None): ...@@ -883,7 +883,7 @@ def get_test_program(filelist, program=None, startup_program=None):
if op.type == "create_multi_pass_reader": if op.type == "create_multi_pass_reader":
test_op.set_attr("pass_num", 1) test_op.set_attr("pass_num", 1)
startup_program.sync_with_cpp() startup_program._sync_with_cpp()
program.sync_with_cpp() program._sync_with_cpp()
return program return program
...@@ -730,8 +730,10 @@ class While(object): ...@@ -730,8 +730,10 @@ class While(object):
parent_block.append_op( parent_block.append_op(
type='while', type='while',
inputs={ inputs={
'X': 'X': [
[parent_block.var_recursive(x_name) for x_name in x_name_list], parent_block._var_recursive(x_name)
for x_name in x_name_list
],
'Condition': [self.cond_var] 'Condition': [self.cond_var]
}, },
outputs={'Out': out_vars, outputs={'Out': out_vars,
...@@ -1259,7 +1261,7 @@ class ConditionalBlock(object): ...@@ -1259,7 +1261,7 @@ class ConditionalBlock(object):
input_set = set([ipt.name for ipt in self.inputs]) input_set = set([ipt.name for ipt in self.inputs])
param_list = [ param_list = [
parent_block.var_recursive(each_name) for each_name in params parent_block._var_recursive(each_name) for each_name in params
if each_name not in input_set if each_name not in input_set
] ]
......
...@@ -4367,7 +4367,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1): ...@@ -4367,7 +4367,7 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
helper.set_variable_initializer( helper.set_variable_initializer(
counter, initializer=Constant( counter, initializer=Constant(
value=begin - 1, force_cpu=True)) value=begin - 1, force_cpu=True))
helper.main_program.global_block().prepend_op( helper.main_program.global_block()._prepend_op(
type='increment', type='increment',
inputs={'X': [counter]}, inputs={'X': [counter]},
outputs={'Out': [counter]}, outputs={'Out': [counter]},
......
...@@ -240,7 +240,7 @@ class Optimizer(object): ...@@ -240,7 +240,7 @@ class Optimizer(object):
self._finish_update(loss.block, parameters_and_grads) self._finish_update(loss.block, parameters_and_grads)
end = len(global_block.ops) end = len(global_block.ops)
return global_block.slice_ops(start, end) return global_block._slice_ops(start, end)
def minimize(self, def minimize(self,
loss, loss,
......
...@@ -152,7 +152,7 @@ class ParallelExecutor(object): ...@@ -152,7 +152,7 @@ class ParallelExecutor(object):
self.executor = core.ParallelExecutor( self.executor = core.ParallelExecutor(
self._places, self._places,
set([ set([
p.name for p in main.global_block().iter_parameters() p.name for p in main.global_block()._iter_parameters()
if not p.stop_gradient if not p.stop_gradient
]), ]),
set(self.persistable_vars), main.desc, loss_name set(self.persistable_vars), main.desc, loss_name
......
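`iter_parameters` becomes private, but the `all_parameters()` wrapper shown above stays public, so code outside the framework can still enumerate a block's parameters. A short sketch:

```python
# Enumerate trainable parameters through the public Block.all_parameters()
# rather than the private _iter_parameters().
import paddle.fluid as fluid

main_prog = fluid.default_main_program()
trainable = [p.name for p in main_prog.global_block().all_parameters()
             if not p.stop_gradient]
```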
...@@ -181,13 +181,13 @@ class TestBlockDesc(unittest.TestCase): ...@@ -181,13 +181,13 @@ class TestBlockDesc(unittest.TestCase):
self.assertIsNotNone(block) self.assertIsNotNone(block)
op1 = block.append_op() op1 = block.append_op()
op2 = block.append_op() op2 = block.append_op()
op0 = block.prepend_op() op0 = block._prepend_op()
all_ops = [] all_ops = []
for idx in xrange(0, block.op_size()): for idx in xrange(0, block.op_size()):
all_ops.append(block.op(idx)) all_ops.append(block.op(idx))
self.assertEqual(all_ops, [op0, op1, op2]) self.assertEqual(all_ops, [op0, op1, op2])
def test_remove_op(self): def test__remove_op(self):
program = Program() program = Program()
program_desc = program.desc program_desc = program.desc
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
...@@ -201,8 +201,8 @@ class TestBlockDesc(unittest.TestCase): ...@@ -201,8 +201,8 @@ class TestBlockDesc(unittest.TestCase):
op1.set_type("test") op1.set_type("test")
op2.set_type("test") op2.set_type("test")
block.remove_op(1, 2) block._remove_op(1, 2)
program.sync_with_cpp() program._sync_with_cpp()
all_ops = [] all_ops = []
for idx in xrange(0, block.op_size()): for idx in xrange(0, block.op_size()):
......
...@@ -17,10 +17,10 @@ def delete_ops(block, ops): ...@@ -17,10 +17,10 @@ def delete_ops(block, ops):
try: try:
start = list(block.ops).index(ops[0]) start = list(block.ops).index(ops[0])
end = list(block.ops).index(ops[-1]) end = list(block.ops).index(ops[-1])
[block.remove_op(start) for _ in xrange(end - start + 1)] [block._remove_op(start) for _ in xrange(end - start + 1)]
except Exception, e: except Exception, e:
raise e raise e
block.program.sync_with_cpp() block.program._sync_with_cpp()
def find_op_by_input_arg(block, arg_name): def find_op_by_input_arg(block, arg_name):
......
...@@ -243,7 +243,7 @@ class DistributeTranspiler(object): ...@@ -243,7 +243,7 @@ class DistributeTranspiler(object):
AssertionError("Can not insert the send op by original " AssertionError("Can not insert the send op by original "
"variable name :", orig_varname) "variable name :", orig_varname)
program.global_block().insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="send", type="send",
inputs={"X": splited_vars}, inputs={"X": splited_vars},
...@@ -429,7 +429,7 @@ class DistributeTranspiler(object): ...@@ -429,7 +429,7 @@ class DistributeTranspiler(object):
# clone vars # clone vars
for var in origin_block.vars: for var in origin_block.vars:
new_sub_block.clone_variable(var) new_sub_block._clone_variable(var)
# clone ops # clone ops
for origin_op in origin_block.ops: for origin_op in origin_block.ops:
...@@ -525,7 +525,7 @@ class DistributeTranspiler(object): ...@@ -525,7 +525,7 @@ class DistributeTranspiler(object):
outputs={}, outputs={},
attrs=attrs) attrs=attrs)
pserver_program.sync_with_cpp() pserver_program._sync_with_cpp()
return pserver_program return pserver_program
def get_startup_program(self, endpoint, pserver_program): def get_startup_program(self, endpoint, pserver_program):
...@@ -557,7 +557,7 @@ class DistributeTranspiler(object): ...@@ -557,7 +557,7 @@ class DistributeTranspiler(object):
pserver_vars = pserver_program.global_block().vars pserver_vars = pserver_program.global_block().vars
created_var_map = dict() created_var_map = dict()
for _, var in pserver_vars.iteritems(): for _, var in pserver_vars.iteritems():
tmpvar = s_prog.global_block().clone_variable(var) tmpvar = s_prog.global_block()._clone_variable(var)
created_var_map[var.name] = tmpvar created_var_map[var.name] = tmpvar
# 2. rename op outputs # 2. rename op outputs
...@@ -760,7 +760,7 @@ class DistributeTranspiler(object): ...@@ -760,7 +760,7 @@ class DistributeTranspiler(object):
self.all_prefetch_output_vars.append(prefetch_output_vars) self.all_prefetch_output_vars.append(prefetch_output_vars)
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block()._insert_op(
index=lookup_table_op_index, index=lookup_table_op_index,
type="split_ids", type="split_ids",
inputs={ inputs={
...@@ -772,7 +772,7 @@ class DistributeTranspiler(object): ...@@ -772,7 +772,7 @@ class DistributeTranspiler(object):
outputs={"Out": prefetch_input_vars}) outputs={"Out": prefetch_input_vars})
# insert prefetch_op # insert prefetch_op
program.global_block().insert_op( program.global_block()._insert_op(
index=lookup_table_op_index + 1, index=lookup_table_op_index + 1,
type="prefetch", type="prefetch",
inputs={'X': prefetch_input_vars}, inputs={'X': prefetch_input_vars},
...@@ -783,7 +783,7 @@ class DistributeTranspiler(object): ...@@ -783,7 +783,7 @@ class DistributeTranspiler(object):
}) })
# insert concat_op # insert concat_op
program.global_block().insert_op( program.global_block()._insert_op(
index=lookup_table_op_index + 2, index=lookup_table_op_index + 2,
type="merge_ids", type="merge_ids",
inputs={ inputs={
...@@ -814,14 +814,14 @@ class DistributeTranspiler(object): ...@@ -814,14 +814,14 @@ class DistributeTranspiler(object):
if table_grad_name in op.output_arg_names: if table_grad_name in op.output_arg_names:
op_index = list(all_ops).index(op) op_index = list(all_ops).index(op)
# insert split_ids_op # insert split_ids_op
program.global_block().insert_op( program.global_block()._insert_op(
index=op_index + 1, index=op_index + 1,
type="split_ids", type="split_ids",
inputs={ inputs={
'Ids': [program.global_block().vars[table_grad_name]] 'Ids': [program.global_block().vars[table_grad_name]]
}, },
outputs={"Out": self.trainer_side_table_grad_list}) outputs={"Out": self.trainer_side_table_grad_list})
program.global_block().insert_op( program.global_block()._insert_op(
index=op_index + 2, index=op_index + 2,
type="send", type="send",
inputs={'X': self.trainer_side_table_grad_list}, inputs={'X': self.trainer_side_table_grad_list},
...@@ -880,7 +880,7 @@ class DistributeTranspiler(object): ...@@ -880,7 +880,7 @@ class DistributeTranspiler(object):
persistable=True) persistable=True)
# parameter must be selected rows # parameter must be selected rows
param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS) param_var.desc.set_type(core.VarDesc.VarType.SELECTED_ROWS)
grad_var = pserver_program.global_block().clone_variable( grad_var = pserver_program.global_block()._clone_variable(
self.origin_program.global_block().vars[grad_var_name( self.origin_program.global_block().vars[grad_var_name(
self.table_name)]) self.table_name)])
...@@ -920,7 +920,7 @@ class DistributeTranspiler(object): ...@@ -920,7 +920,7 @@ class DistributeTranspiler(object):
if not splited_grad_name.startswith(origin_grad_name): if not splited_grad_name.startswith(origin_grad_name):
raise ValueError("origin_grad_var: " + splited_grad_name + raise ValueError("origin_grad_var: " + splited_grad_name +
" grad_var:" + grad_var.name) " grad_var:" + grad_var.name)
grad_var = pserver_program.global_block().rename_var( grad_var = pserver_program.global_block()._rename_var(
origin_grad_name, splited_grad_name) origin_grad_name, splited_grad_name)
lr_var = pserver_program.global_block().vars[table_opt_op.input( lr_var = pserver_program.global_block().vars[table_opt_op.input(
...@@ -996,7 +996,7 @@ class DistributeTranspiler(object): ...@@ -996,7 +996,7 @@ class DistributeTranspiler(object):
if self.sync_mode and add_trainer_suffix: if self.sync_mode and add_trainer_suffix:
new_var_name = "%s.trainer_%d" % \ new_var_name = "%s.trainer_%d" % \
(orig_var.name, self.trainer_id) (orig_var.name, self.trainer_id)
program.global_block().rename_var(varname, new_var_name) program.global_block()._rename_var(varname, new_var_name)
var_mapping[varname] = \ var_mapping[varname] = \
[program.global_block().var(new_var_name)] [program.global_block().var(new_var_name)]
else: else:
...@@ -1030,8 +1030,7 @@ class DistributeTranspiler(object): ...@@ -1030,8 +1030,7 @@ class DistributeTranspiler(object):
type=orig_var.type, type=orig_var.type,
shape=splited_shape) # flattend splited var shape=splited_shape) # flattend splited var
var_mapping[varname].append(var) var_mapping[varname].append(var)
program.global_block().sync_with_cpp() program.global_block()._sync_with_cpp()
return var_mapping return var_mapping
def create_splited_vars(self, source_var, block, tag): def create_splited_vars(self, source_var, block, tag):
...@@ -1059,7 +1058,7 @@ class DistributeTranspiler(object): ...@@ -1059,7 +1058,7 @@ class DistributeTranspiler(object):
height_sections = [] height_sections = []
for v in splited_vars: for v in splited_vars:
height_sections.append(v.shape[0]) height_sections.append(v.shape[0])
program.global_block().insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="split_selected_rows", type="split_selected_rows",
inputs={"X": orig_var}, inputs={"X": orig_var},
...@@ -1069,7 +1068,7 @@ class DistributeTranspiler(object): ...@@ -1069,7 +1068,7 @@ class DistributeTranspiler(object):
sections = [] sections = []
for v in splited_vars: for v in splited_vars:
sections.append(v.shape[0]) sections.append(v.shape[0])
program.global_block().insert_op( program.global_block()._insert_op(
index=index + 1, index=index + 1,
type="split_byref", type="split_byref",
inputs={"X": orig_var}, inputs={"X": orig_var},
...@@ -1258,7 +1257,7 @@ class DistributeTranspiler(object): ...@@ -1258,7 +1257,7 @@ class DistributeTranspiler(object):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
if var not in program.global_block().vars: if var not in program.global_block().vars:
block.clone_variable(var) block._clone_variable(var)
outputs = self._get_output_map_from_op( outputs = self._get_output_map_from_op(
self.origin_program.global_block().vars, op) self.origin_program.global_block().vars, op)
...@@ -1267,7 +1266,7 @@ class DistributeTranspiler(object): ...@@ -1267,7 +1266,7 @@ class DistributeTranspiler(object):
varlist = [varlist] varlist = [varlist]
for var in varlist: for var in varlist:
if var not in program.global_block().vars: if var not in program.global_block().vars:
block.clone_variable(var) block._clone_variable(var)
return block.append_op( return block.append_op(
type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs) type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
...@@ -1305,7 +1304,7 @@ class DistributeTranspiler(object): ...@@ -1305,7 +1304,7 @@ class DistributeTranspiler(object):
if grad_block: if grad_block:
outputs[key] = grad_block outputs[key] = grad_block
elif not program.global_block().vars.has_key(var.name): elif not program.global_block().vars.has_key(var.name):
program.global_block().clone_variable(var) program.global_block()._clone_variable(var)
return optimize_block.append_op( return optimize_block.append_op(
type=opt_op.type, type=opt_op.type,
......
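The distributed transpiler is one of the heaviest internal users of `_insert_op`, `_clone_variable` and `_rename_var`; its own public entry points remain the supported interface. A hedged usage sketch (assuming `DistributeTranspiler` is exported at the `paddle.fluid` level; the endpoint and trainer count are made-up example values):

```python
import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(trainer_id=0, pservers="127.0.0.1:6174", trainers=2)

pserver_endpoint = "127.0.0.1:6174"
pserver_prog = t.get_pserver_program(pserver_endpoint)
startup_prog = t.get_startup_program(pserver_endpoint, pserver_prog)
```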
...@@ -95,7 +95,7 @@ class InferenceTranspiler(object): ...@@ -95,7 +95,7 @@ class InferenceTranspiler(object):
# modify bnorm OP to include relu # modify bnorm OP to include relu
current_op.set_attr("fuse_with_relu", True) current_op.set_attr("fuse_with_relu", True)
# remove relu OP # remove relu OP
self.block.remove_op(i + 1) self.block._remove_op(i + 1)
i = i + 1 i = i + 1
self._remove_unused_var() self._remove_unused_var()
...@@ -171,7 +171,7 @@ class InferenceTranspiler(object): ...@@ -171,7 +171,7 @@ class InferenceTranspiler(object):
# fuse batch_norm # fuse batch_norm
self._fuse_param(current_op, next_op, bias_op, 0) self._fuse_param(current_op, next_op, bias_op, 0)
# remove batch_norm_op # remove batch_norm_op
self.block.remove_op(i + 2) self.block._remove_op(i + 2)
i = i + 1 i = i + 1
# conv2d with bias, the next_op.type is elementwise_add # conv2d with bias, the next_op.type is elementwise_add
elif (next_op.type == 'elementwise_add'): elif (next_op.type == 'elementwise_add'):
...@@ -180,7 +180,7 @@ class InferenceTranspiler(object): ...@@ -180,7 +180,7 @@ class InferenceTranspiler(object):
# fuse batch_norm # fuse batch_norm
self._fuse_param(current_op, next_next_op, next_op, 1) self._fuse_param(current_op, next_next_op, next_op, 1)
# remove batch_norm_op # remove batch_norm_op
self.block.remove_op(i + 2) self.block._remove_op(i + 2)
i = i + 1 i = i + 1
i = i + 1 i = i + 1
...@@ -212,7 +212,7 @@ class InferenceTranspiler(object): ...@@ -212,7 +212,7 @@ class InferenceTranspiler(object):
y_var = self.block.var(bn_op.input("Bias")[0]) y_var = self.block.var(bn_op.input("Bias")[0])
out_var = self.block.var(bn_op.output("Y")[0]) out_var = self.block.var(bn_op.output("Y")[0])
bias_op = self.block.insert_op( bias_op = self.block._insert_op(
index, index,
type="elementwise_add", type="elementwise_add",
inputs={"X": x_var, inputs={"X": x_var,
...@@ -307,4 +307,4 @@ class InferenceTranspiler(object): ...@@ -307,4 +307,4 @@ class InferenceTranspiler(object):
for var in self.block.vars.keys(): for var in self.block.vars.keys():
if var not in args: if var not in args:
self.block.remove_var(var) self.block._remove_var(var)
...@@ -177,7 +177,7 @@ class ControlFlowGraph(object): ...@@ -177,7 +177,7 @@ class ControlFlowGraph(object):
in_diff) in_diff)
if can_optimize: if can_optimize:
index = i + fwd_id + 1 if is_forward else i - self._forward_num + bwd_id + 1 index = i + fwd_id + 1 if is_forward else i - self._forward_num + bwd_id + 1
delete_op = block_desc.insert_op(index) delete_op = block_desc._insert_op(index)
delete_op.set_type("delete_var") delete_op.set_type("delete_var")
delete_op.set_input("X", can_optimize) delete_op.set_input("X", can_optimize)
if is_forward: if is_forward:
......