提交 4658f950 编写于 作者: T typhoonzero

fix delete ops

上级 70729270
......@@ -91,6 +91,21 @@ OpDescBind *BlockDescBind::PrependOp() {
return ops_.front().get();
}
void BlockDescBind::RemoveOp(size_t s, size_t e) {
if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
return;
}
need_update_ = true;
for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) {
auto names = (*it)->InputArgumentNames();
for (auto n : names) {
// TODO(typhoonzero): delete vars if no other op use it.
VLOG(3) << "deleting var " << n;
}
}
ops_.erase(ops_.begin() + s, ops_.begin() + e);
}
std::vector<OpDescBind *> BlockDescBind::AllOps() const {
std::vector<OpDescBind *> res;
for (const auto &op : ops_) {
......
......@@ -80,6 +80,8 @@ class BlockDescBind {
OpDescBind *PrependOp();
void RemoveOp(size_t s, size_t e);
std::vector<OpDescBind *> AllOps() const;
size_t OpSize() const { return ops_.size(); }
......
......@@ -159,6 +159,7 @@ void BindBlockDesc(py::module &m) {
py::return_value_policy::reference)
.def("prepend_op", &BlockDescBind::PrependOp,
py::return_value_policy::reference)
.def("remove_op", &BlockDescBind::RemoveOp)
.def("var",
[](BlockDescBind &self, py::bytes byte_name) {
std::string name = byte_name;
......
......@@ -131,11 +131,6 @@ class DistributeTranspiler:
def _optimize_distributed(self, optimize_ops, program, params_and_grads,
**kwargs):
# remove optimize ops and add a send op to main_program
# FIXME(typhoonzero): delete_op only remove the first accurance,
# need to consider about multiple same optimize op?
for op in optimize_ops:
program.global_block().delete_op(op)
if kwargs.has_key("split_method"):
split_method = kwargs["split_method"]
else:
......@@ -159,6 +154,10 @@ class DistributeTranspiler:
attrs={"endpoints": pserver_endpoints,
"epmap": epmap})
def get_trainer_program(optimize_ops, program):
# remove optimize ops and add a send op to main_program
program.global_block().delete_ops(optimize_ops)
def _create_var_for_trainers(self, block, var, trainers):
var_list = []
for i in xrange(trainers):
......@@ -209,6 +208,7 @@ class DistributeTranspiler:
if opt_op.inputs.has_key("Grad"):
if opt_op.inputs["Grad"].name in grad_var_names:
print "appending ", opt_op.type, opt_op.inputs
optimize_sub_program.global_block().append_op(
type=opt_op.type,
inputs=opt_op.inputs,
......
......@@ -579,6 +579,7 @@ class Block(object):
self.vars = dict() # var_name --> var
self.ops = collections.deque() # operator list
self.program = program
self.removed_vars = dict()
def __str__(self):
return self.to_string(True)
......@@ -635,8 +636,15 @@ class Block(object):
self.ops.append(op)
return op
def delete_op(self, op):
self.ops.remove(op)
def delete_ops(self, ops):
# remove from cpp
# FIXME(typhoonzero): remove only the first occuracy.
try:
start = list(self.ops).index(ops[0])
end = list(self.ops).index(ops[-1])
except Exception, e:
raise e
self.desc.remove_op(start, end)
def prepend_op(self, *args, **kwargs):
op_desc = self.desc.prepend_op()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册