From 4658f9501efd05396b796297f81bf17de37bda9f Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 21 Dec 2017 20:07:54 +0800 Subject: [PATCH] fix delete ops --- paddle/framework/block_desc.cc | 15 +++++++++++++++ paddle/framework/block_desc.h | 2 ++ paddle/pybind/protobuf.cc | 1 + python/paddle/v2/fluid/distribute_transpiler.py | 10 +++++----- python/paddle/v2/fluid/framework.py | 12 ++++++++++-- 5 files changed, 33 insertions(+), 7 deletions(-) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 6a7a07d5cf..4707e48353 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -91,6 +91,21 @@ OpDescBind *BlockDescBind::PrependOp() { return ops_.front().get(); } +void BlockDescBind::RemoveOp(size_t s, size_t e) { + if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) { + return; + } + need_update_ = true; + for (auto it = ops_.begin() + s; it != ops_.begin() + e; it++) { + auto names = (*it)->InputArgumentNames(); + for (auto n : names) { + // TODO(typhoonzero): delete vars if no other op uses it. 
+ VLOG(3) << "deleting var " << n; + } + } + ops_.erase(ops_.begin() + s, ops_.begin() + e); +} + std::vector BlockDescBind::AllOps() const { std::vector res; for (const auto &op : ops_) { diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 8e967e5378..51b0e75c55 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -80,6 +80,8 @@ class BlockDescBind { OpDescBind *PrependOp(); + void RemoveOp(size_t s, size_t e); + std::vector AllOps() const; size_t OpSize() const { return ops_.size(); } diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 6e6cafafb9..119cae94fb 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -159,6 +159,7 @@ void BindBlockDesc(py::module &m) { py::return_value_policy::reference) .def("prepend_op", &BlockDescBind::PrependOp, py::return_value_policy::reference) + .def("remove_op", &BlockDescBind::RemoveOp) .def("var", [](BlockDescBind &self, py::bytes byte_name) { std::string name = byte_name; diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index 7dfbab4677..50364c64be 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -131,11 +131,6 @@ class DistributeTranspiler: def _optimize_distributed(self, optimize_ops, program, params_and_grads, **kwargs): - # remove optimize ops and add a send op to main_program - # FIXME(typhoonzero): delete_op only remove the first accurance, - # need to consider about multiple same optimize op? 
- for op in optimize_ops: - program.global_block().delete_op(op) if kwargs.has_key("split_method"): split_method = kwargs["split_method"] else: split_method = default_method @@ -159,6 +154,10 @@ class DistributeTranspiler: attrs={"endpoints": pserver_endpoints, "epmap": epmap}) + def get_trainer_program(optimize_ops, program): + # remove optimize ops and add a send op to main_program + program.global_block().delete_ops(optimize_ops) + def _create_var_for_trainers(self, block, var, trainers): var_list = [] for i in xrange(trainers): @@ -209,6 +208,7 @@ class DistributeTranspiler: if opt_op.inputs.has_key("Grad"): if opt_op.inputs["Grad"].name in grad_var_names: + print "appending ", opt_op.type, opt_op.inputs optimize_sub_program.global_block().append_op( type=opt_op.type, inputs=opt_op.inputs, diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py index 7990886417..a409b2aa94 100644 --- a/python/paddle/v2/fluid/framework.py +++ b/python/paddle/v2/fluid/framework.py @@ -579,6 +579,7 @@ class Block(object): self.vars = dict() # var_name --> var self.ops = collections.deque() # operator list self.program = program + self.removed_vars = dict() def __str__(self): return self.to_string(True) @@ -635,8 +636,15 @@ class Block(object): self.ops.append(op) return op - def delete_op(self, op): - self.ops.remove(op) + def delete_ops(self, ops): + # remove from cpp + # FIXME(typhoonzero): remove only the first occurrence. + try: + start = list(self.ops).index(ops[0]) + end = list(self.ops).index(ops[-1]) + except Exception, e: + raise e + self.desc.remove_op(start, end) def prepend_op(self, *args, **kwargs): op_desc = self.desc.prepend_op() -- GitLab