Commit ea0cf6f3 authored by Luo Tao

rewrite inference_transpiler in Python end

Parent 16e31343
paddle/fluid/framework/block_desc.cc
@@ -13,11 +13,10 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/framework/block_desc.h"
+#include <queue>
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include <queue>

 namespace paddle {
 namespace framework {
@@ -147,52 +146,7 @@ void BlockDesc::RemoveOp(size_t s, size_t e) {
   if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
     return;
   }
-  auto get_vars = [](std::deque<std::unique_ptr<OpDesc>>::iterator &op,
-                     std::vector<std::string> &v) {
-    auto in_names = (*op)->InputArgumentNames();
-    v.insert(v.end(), in_names.begin(), in_names.end());
-    auto out_names = (*op)->OutputArgumentNames();
-    v.insert(v.end(), out_names.begin(), out_names.end());
-    std::sort(v.begin(), v.end());
-    auto last = std::unique(v.begin(), v.end());
-    v.erase(last, v.end());
-  };
-
-  need_update_ = true;
-  for (size_t i = s; i < e; i++) {
-    // since remove op one by one, every time remove the first op.
-    auto op = ops_.begin() + s;
-    // collect input and output variables from current delete op
-    std::vector<std::string> cur_vars;
-    get_vars(op, cur_vars);
-    // remove current op
-    ops_.erase(ops_.begin() + s);
-    // collect input and output variables from other ops
-    std::vector<std::string> other_vars;
-    for (auto it = ops_.begin(); it != ops_.end(); it++) {
-      get_vars(it, other_vars);
-    }
-    // variables should be deleted
-    std::vector<std::string> delete_vars;
-    // delete_vars = cur_vars - cur_vars ^ other_input_vars
-    std::set_difference(cur_vars.begin(), cur_vars.end(), other_vars.begin(),
-                        other_vars.end(),
-                        std::inserter(delete_vars, delete_vars.end()));
-    // remove variables
-    for (size_t i = 0; i < delete_vars.size(); i++) {
-      auto name = delete_vars[i];
-      auto it = vars_.find(name);
-      PADDLE_ENFORCE(it != vars_.end(),
-                     "%s is not in variable list, it should not be deleted",
-                     name);
-      vars_.erase(it);
-      VLOG(3) << "deleting variable " << name;
-    }
-  }
+  ops_.erase(ops_.begin() + s, ops_.begin() + e);
 }

 std::vector<OpDesc *> BlockDesc::AllOps() const {
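The bookkeeping deleted here is not lost: the commit moves op and variable cleanup to the Python side (see `framework.py` and `inference_transpiler.py` below). The heart of the removed logic is an ordinary set difference; a minimal Python sketch of the same idea, with illustrative names only:

```python
def vars_to_delete(removed_op_vars, remaining_ops_vars):
    """Variables referenced only by the removed op are safe to delete."""
    return sorted(set(removed_op_vars) - set(remaining_ops_vars))

# Removing an op that touches var1/var2/var3 while the remaining ops still
# use var1 and var4 leaves only var2 and var3 to delete:
print(vars_to_delete(["var1", "var2", "var3"], ["var1", "var4"]))
# -> ['var2', 'var3']
```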
python/paddle/fluid/framework.py
@@ -818,6 +818,11 @@ class Block(object):
         del self.vars[name]
         self.sync_with_cpp()

+    def remove_var(self, name):
+        self.sync_with_cpp()
+        self.desc.remove_var(name)
+        del self.vars[name]
+
     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
         param = Parameter(global_block, *args, **kwargs)
@@ -838,6 +843,11 @@ class Block(object):
         self.ops.insert(index, op)
         return op

+    def remove_op(self, index):
+        self.sync_with_cpp()
+        self.desc.remove_op(index, index + 1)
+        del self.ops[index]
+
     def delete_ops(self, ops):
         # remove from cpp
         # FIXME(typhoonzero): remove only the first occurrence.
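A minimal usage sketch of the two new `Block` methods, assuming the fluid Python API of this era (the `scale` op and the variable names are illustrative):

```python
import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
x = block.create_var(name="x", shape=[1], dtype="float32")
y = block.create_var(name="y", shape=[1], dtype="float32")
block.append_op(
    type="scale", inputs={"X": x}, outputs={"Out": y}, attrs={"scale": 2.0})

block.remove_op(0)     # erases ops [0, 1) from the C++ desc and the Python list
block.remove_var("y")  # variable cleanup is now an explicit, separate call
```

Unlike the old C++ `RemoveOp`, removing an op no longer deletes its variables implicitly; the caller decides which variables are actually unused.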
@@ -846,6 +856,7 @@ class Block(object):
             end = list(self.ops).index(ops[-1])
         except Exception, e:
             raise e
+
         self.desc.remove_op(start, end + 1)

     def slice_ops(self, start, end):
@@ -920,15 +931,6 @@ class Block(object):
                 ops_in_cpp_index += 1
                 ops_in_python_index += 1

-        # sync ops inserted from c++ end
-        if len(self.ops) != len(ops_in_cpp) and start_index == 0 and len(
-                self.ops) == end_index:
-            del self.ops[:]
-            for index in range(len(ops_in_cpp)):
-                op_desc = ops_in_cpp[index]
-                op = Operator(self, op_desc)
-                self.ops.append(op)
-
         assert len(self.ops) == len(ops_in_cpp)
         for index in range(len(self.ops)):
             assert self.ops[index].desc == ops_in_cpp[index]
python/paddle/fluid/inference_transpiler.py
@@ -61,30 +61,26 @@ class InferenceTranspiler:
         '''
         self.scope = scope
         self.place = place
-        self.block_desc = program.get_desc().block(0)
+        self.block = program.block(0)
         i = 0
-        while i < self.block_desc.op_size():
-            current_op = self.block_desc.op(i)
+        while i < len(self.block.ops):
+            current_op = self.block.ops[i]
             # TODO(luotao1): consider only conv2d now. fc would be delt later.
-            if current_op.type() in ['conv2d']:
-                next_op = self.block_desc.op(i + 1)
+            if current_op.type in ['conv2d']:
+                next_op = self.block.ops[i + 1]
                 # TODO(luotao1): consider only conv2d without bias now.
                 # If conv2d with bias, the next_op.type is elementwise_add.
-                if (next_op.type() == 'batch_norm'):
+                if (next_op.type == 'batch_norm'):
                     # insert bias op
                     bias_op = self._insert_bias_op(i + 1, current_op, next_op)
-                    program.sync_with_cpp()
                     # fuse batch_norm
                     self._fuse_param(current_op, next_op, bias_op)
                     # remove batch_norm_op
-                    self.block_desc.remove_op(i + 2, i + 3)
-                    program.sync_with_cpp()
+                    self.block.remove_op(i + 2)
                     i = i + 1
             i = i + 1
         self._remove_unused_var()
-        program.sync_with_cpp()

         return program

    # ====================== private transpiler functions =====================
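For orientation, a hypothetical driver for this pass is sketched below; the entry-point name `transpile`, its argument order, and the `fluid.InferenceTranspiler` export are assumptions (only `program`, `scope`, and `place` are visible in the hunk above):

```python
import paddle.fluid as fluid

# Build a tiny conv2d + batch_norm network, the pattern this pass targets.
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    img = fluid.layers.data(name='img', shape=[1, 8, 8], dtype='float32')
    conv = fluid.layers.conv2d(
        input=img, num_filters=2, filter_size=3,
        bias_attr=False)  # conv2d without bias, per the TODO above
    fluid.layers.batch_norm(input=conv, is_test=True)

place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)  # initialize the parameters the pass will fuse

t = fluid.InferenceTranspiler()                        # export name assumed
main = t.transpile(main, fluid.global_scope(), place)  # signature assumed
```

After the pass, the block should contain conv2d followed by elementwise_add, with the batch_norm op and its now-unused variables gone.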
@@ -102,14 +98,19 @@ class InferenceTranspiler:
         :return: bias_op
         :rtype: Operator
         '''
-        bias_op = self.block_desc.insert_op(index)
-        bias_op.set_type("elementwise_add")
         # The input of bias_op is current_op's output and Bias of bn_op
         # The output of bias_op is bn_op's output
-        bias_op.set_input("X", current_op.output("Output"))
-        bias_op.set_input("Y", bn_op.input("Bias"))
-        bias_op.set_output("Out", bn_op.output("Y"))
-        bias_op.set_attr('axis', 1)  # dim_start=1
+        x_var = self.block.var(current_op.output("Output")[0])
+        y_var = self.block.var(bn_op.input("Bias")[0])
+        out_var = self.block.var(bn_op.output("Y")[0])
+
+        bias_op = self.block.insert_op(
+            index,
+            type="elementwise_add",
+            inputs={"X": x_var,
+                    "Y": y_var},
+            outputs={"Out": out_var},
+            attrs={"axis": 1})  # dim_start=1
         return bias_op

     def _fuse_param(self, current_op, bn_op, bias_op):
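`_fuse_param` itself is unchanged by this hunk, but for reference, folding a batch_norm into the preceding layer follows the standard identity z = scale * (y - mean) / sqrt(var + eps) + offset, which collapses into a per-channel rescale of the conv weights plus a bias. A numpy sketch of that arithmetic (function name and eps are illustrative, not taken from this diff):

```python
import numpy as np

def fold_bn(conv_w, scale, offset, mean, var, eps=1e-5):
    """Fold z = scale*(y-mean)/sqrt(var+eps)+offset into conv weight/bias."""
    std = np.sqrt(var + eps)                             # per output channel
    new_w = conv_w * (scale / std).reshape(-1, 1, 1, 1)  # conv_w: [oc, ic, kh, kw]
    new_b = offset - scale * mean / std                  # feeds the new elementwise_add
    return new_w, new_b
```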
@@ -160,15 +161,15 @@ class InferenceTranspiler:
     def _remove_unused_var(self):
         '''
-        remove unused varibles in program desc
+        remove unused varibles in program
         '''
         args = []
-        for i in xrange(0, self.block_desc.op_size()):
-            current_op = self.block_desc.op(i)
-            args += current_op.input_arg_names()
-            args += current_op.output_arg_names()
+        for i in range(len(self.block.ops)):
+            current_op = self.block.ops[i]
+            args += current_op.input_arg_names
+            args += current_op.output_arg_names
         args = list(set(args))  # unique the input and output arguments

-        for var in self.block_desc.all_vars():
-            if var.name() not in args:
-                self.block_desc.remove_var(var.name())
+        for var in self.block.vars.keys():
+            if var not in args:
+                self.block.remove_var(var)
python/paddle/fluid/tests/book/test_image_classification.py
@@ -236,8 +236,8 @@ def infer(use_cuda, save_dirname=None):
         assert len(results[0]) == len(transpiler_results[0])
         for i in range(len(results[0])):
-            np.testing.assert_almost_equal(results[0][i],
-                                           transpiler_results[0][i])
+            np.testing.assert_almost_equal(
+                results[0][i], transpiler_results[0][i], decimal=6)

         print("infer results: ", results[0])
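The looser `decimal=6` tolerance (numpy's default is 7) is needed because folding batch_norm into the conv weights reassociates floating-point arithmetic. A tiny illustration of that kind of drift (the exact low-order bits are platform-dependent):

```python
import numpy as np

x = np.random.RandomState(0).rand(1000).astype(np.float32)
# Summing the same values in a different order can disagree in the last
# decimal places -- exactly the kind of mismatch decimal=6 tolerates.
print(x.sum(), x[::-1].sum())
```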
python/paddle/fluid/tests/unittests/test_protobuf_descs.py
@@ -201,24 +201,6 @@ class TestBlockDesc(unittest.TestCase):
         op1.set_type("test")
         op2.set_type("test")

-        var0 = block.var("var0")
-        var1 = block.var("var1")
-        var2 = block.var("var2")
-        var3 = block.var("var3")
-        var4 = block.var("var4")
-        var5 = block.var("var5")
-        op0.set_input("X", ["var0"])
-        op0.set_output("Y", ["var0"])
-        op1.set_input("X", ["var1", "var2"])
-        op1.set_output("Y", ["var3", "var4"])
-        op2.set_input("X", ["var1"])
-        op2.set_output("Y", ["var4", "var5"])
-        program.sync_with_cpp()
-
-        # remove op1, its input var2 and output var3 will be removed at the same time,
-        # but its input var1 and output var4 will not be removed since they are used for op2.
         block.remove_op(1, 2)
         program.sync_with_cpp()

@@ -226,8 +208,6 @@
         for idx in xrange(0, block.op_size()):
             all_ops.append(block.op(idx))
         self.assertEqual(all_ops, [op0, op2])
-
-        all_vars = block.all_vars()
-        self.assertEqual(set(all_vars), {var0, var1, var4, var5})

 if __name__ == '__main__':