Commit 278ac7be authored by F fengjiayi

Complete basic framework

Parent 61a7df2e
...@@ -240,13 +240,7 @@ void BindOpDesc(py::module &m) {
       .value("BLOCK", AttrType::BLOCK);
   py::class_<OpDescBind> op_desc(m, "OpDesc", "");
-  op_desc
-      .def("__init__",
-           [](OpDescBind &self, const std::string &type,
-              const VariableNameMap &inputs, const VariableNameMap &outputs,
-              const AttributeMap &attrs) {
-             new (&self) OpDescBind(type, inputs, outputs, attrs);
-           })
+  op_desc.def("__init__", [](OpDescBind &self) { new (&self) OpDescBind(); })
       .def("type", &OpDescBind::Type)
       .def("set_type", &OpDescBind::SetType)
       .def("input", &OpDescBind::Input)
...
...@@ -285,8 +285,8 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("get_grad_op_desc",
         [](const OpDescBind &op_desc,
            const std::unordered_set<std::string> &no_grad_set,
-           std::unordered_map<std::string, std::string> &grad_to_var,
            const std::vector<BlockDescBind *> &grad_sub_block) {
+          std::unordered_map<std::string, std::string> grad_to_var;
           std::vector<std::unique_ptr<OpDescBind>> grad_op_descs =
               framework::OpInfoMap::Instance()
                   .Get(op_desc.Type())
...@@ -297,7 +297,7 @@ All parameter, weight, gradient are variables in Paddle.
               grad_op_descs.begin(), grad_op_descs.end(),
               grad_op_desc_ptrs.begin(),
               [](std::unique_ptr<OpDescBind> &p) { return p.release(); });
-          return grad_op_desc_ptrs;
+          return std::make_pair(grad_op_desc_ptrs, grad_to_var);
         });
   m.def("prune", [](const ProgramDescBind &origin,
                     const std::vector<std::array<size_t, 2>> &targets) {
...
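For orientation (not part of the commit), here is a minimal Python-side sketch of how the rebound `get_grad_op_desc` is now expected to be called: instead of filling a `grad_to_var` map passed by reference, the binding returns the grad op descs together with the grad-to-var mapping as a pair. The wrapper `make_grad_ops` and its `fwd_op` argument are purely illustrative, and the import path is assumed from this repository's layout.

# Hedged sketch of consuming the modified binding; the real caller is
# backward_impl() in backward.py, shown further down in this diff.
import paddle.v2.framework.core as core  # assumed import path

def make_grad_ops(fwd_op, no_grad_set):
    # After this change the binding returns a (grad_op_descs, grad_to_var)
    # pair, rather than mutating an unordered_map supplied by the caller.
    grad_op_descs, grad_to_var = core.get_grad_op_desc(
        fwd_op.desc, no_grad_set, [])  # [] means no gradient sub-blocks
    return grad_op_descs, grad_to_var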
...@@ -6,7 +6,8 @@ import pdb
 __all__ = ['append_backward_ops']
-def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
+def _rename_arg_(op_desc_list, old_name, new_name, begin_idx=None,
+                 end_idx=None):
     if begin_idx is None:
         begin_idx = 0
     if end_idx is None:
...@@ -16,6 +17,21 @@ def rename_arg(op_desc_list, old_name, new_name, begin_idx=None, end_idx=None):
             op_desc_list[i].rename_output(old_name, new_name)
+def _create_op_desc_(op_type, inputs, outputs, attrs):
+    op_desc = core.OpDesc()
+    op_desc.set_type(op_type)
+    for para, args in inputs.iteritems():
+        op_desc.set_input(para, args)
+    for para, args in outputs.iteritems():
+        op_desc.set_output(para, args)
+    for name, val in attrs.iteritems():
+        if isinstance(val, framework.Block):
+            op_desc.set_block_attr(name, val.desc)
+        else:
+            op_desc.set_attr(name, val)
+    return op_desc
 def backward_impl(target,
                   block,
                   target_block,
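As a usage illustration only, the new `_create_op_desc_` helper replaces ad-hoc `core.OpDesc(...)` construction with explicit set_type/set_input/set_output/set_attr calls; the gradient names below are hypothetical, not taken from the diff.

# Hedged sketch of calling the helper added above; it mirrors the sum-op
# call sites introduced in backward_impl() later in this diff.
renamed_grads = ["x@GRAD@RENAME@0", "x@GRAD@RENAME@1"]  # hypothetical names
sum_desc = _create_op_desc_(
    op_type="sum_op",
    inputs={"X": renamed_grads},
    outputs={"Out": ["x@GRAD"]},
    attrs={})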
...@@ -23,9 +39,9 @@ def backward_impl(target,
                   grad_info_map,
                   callback=None):
     grad_op_descs = []
-    grad_to_var = {}
+    grad_to_var = dict()
     program = block.program
-    for each_op in block.ops:
+    for each_op in reversed(block.ops):
         grad_sub_block_list = []
         if each_op.has_attr("sub_block"):
             sub_block_idx = each_op.block_attr("sub_block")
...@@ -34,10 +50,10 @@ def backward_impl(target,
             backward_impl(target, sub_block, grad_sub_block, no_grad_set,
                           grad_info_map, callback)
             grad_sub_block_list.append(grad_sub_block)
-        grad_op_desc = core.get_grad_op_desc(each_op.desc,
-                                             no_grad_set[block.idx],
-                                             grad_to_var, grad_sub_block_list)
+        grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
+            each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
         grad_op_descs.append(grad_op_desc)
+        grad_to_var = dict(grad_to_var, **op_grad_to_var)
     # grad_op_descs = [[op1_g1, op1_g2], [op2_g], ...]
     # flatten grad_op_descs
     grad_op_descs = [op for sublist in grad_op_descs for op in sublist]  # ?????
...@@ -48,11 +64,10 @@ def backward_impl(target,
     for pos, op_desc in enumerate(grad_op_descs):
         for var_name in op_desc.input_arg_names():
             if len(var_inputs[var_name]) > 1:
-                pdb.set_trace()
-                pending_sum_ops.append((core.OpDesc(
-                    type="sum_op",
-                    inputs=var_inputs[var_name],
-                    output=[var_name],
-                    attrs={}), pos))
+                pending_sum_ops.append((_create_op_desc_(
+                    op_type="sum_op",
+                    inputs={"X": var_inputs[var_name]},
+                    outputs={"Out": [var_name]},
+                    attrs={}), pos))
                 var_inputs[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
...@@ -66,8 +81,8 @@ def backward_impl(target,
                 var_rename_count[var_name] = var_rename_count[var_name] + 1
                 # rename original var_name
                 var_inputs[var_name][0] = new_name
-                rename_arg(grad_op_descs, var_name, new_name, 0, pos)
-                rename_arg(pending_sum_ops, var_name, new_name)
+                _rename_arg_(grad_op_descs, var_name, new_name, 0, pos)
+                _rename_arg_(pending_sum_ops, var_name, new_name)
                 new_name = var_name + "@RENAME@" + \
                     str(var_rename_count[var_name])
...@@ -76,10 +91,11 @@ def backward_impl(target,
             var_inputs[var_name].append(new_name)
     for var_name, inputs in var_inputs.iteritems():
         if len(inputs) > 1:
-            pdb.set_trace()
-            pending_sum_ops.append((core.OpDesc("sum_op", {"X": inputs},
-                                                {"Out": var_name}, {}),
-                                    len(grad_op_descs)))
+            pending_sum_ops.append((_create_op_desc_(
+                op_type="sum_op",
+                inputs={"X": inputs},
+                outputs={"Out": [var_name]},
+                attrs={}), len(grad_op_descs)))
     # TODO: remove op in no grad set
     # Judging by the append order, pending_sum_ops is guaranteed to be sorted by the insertion position of the sum ops
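The hunks above implement a rename-then-sum scheme: when a gradient name is produced by more than one grad op, the colliding writes are given `name@RENAME@<k>` aliases and a `sum_op` is queued to merge them. A simplified, self-contained sketch of that bookkeeping follows (plain dicts and strings, no Paddle calls; unlike the real code it does not retroactively rename the first writer):

# Hedged sketch of the rename-and-sum bookkeeping, detached from OpDesc
# objects: `writes` lists, per grad op, the gradient names it produces.
from collections import defaultdict

def plan_renames(writes):
    rename_count = defaultdict(int)
    produced = defaultdict(list)   # grad name -> names actually written
    renames = []                   # (op index, old name, new name)
    for pos, names in enumerate(writes):
        for name in names:
            if produced[name]:
                # Another op already wrote this gradient: give the new write
                # a unique alias so a sum op can later combine all aliases.
                rename_count[name] += 1
                new_name = name + "@RENAME@" + str(rename_count[name])
                renames.append((pos, name, new_name))
                produced[name].append(new_name)
            else:
                produced[name].append(name)
    return renames, produced

renames, produced = plan_renames([["x@GRAD"], ["y@GRAD"], ["x@GRAD"]])
# renames  -> [(2, "x@GRAD", "x@GRAD@RENAME@1")]
# produced -> "x@GRAD" is written twice, so a sum op over both aliases is needed.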
...@@ -103,15 +119,22 @@ def backward_impl(target,
     target_block.desc.var(grad_target_name)
     grad_op_descs.insert(
         0,
-        core.OpDesc(u"fill_constant", {}, {
-            u"Out": [unicode(grad_target_name, "ascii")]
-        }, {u"shape": (1),
-            u"value": 1.0,
-            u"dtype": core.DataType.FP32}))
+        _create_op_desc_(
+            op_type="fill_constant",
+            inputs={},
+            outputs={"Out": [grad_target_name]},
+            attrs={
+                "shape": [1],
+                "value": 1.0,
+                "dtype": core.DataType.FP32
+            }))
     # insert backward operators to target_block
     for op_desc in grad_op_descs:
+        op_desc.infer_var_type(target_block.desc)
+        op_desc.infer_shape(target_block.desc)
         target_block.desc.append_allocated_op(op_desc)
-    pdb.set_trace()
     target_block.sync_with_cpp()
...@@ -147,6 +170,7 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
     grad_info_map = dict()
     root_block = loss.block.program.block(0)
+    pdb.set_trace()
     backward_impl(loss, root_block, root_block, no_grad_set, grad_info_map)
     pdb.set_trace()
     if parameter_list is not None:
...@@ -159,7 +183,7 @@ def append_backward_ops(loss, parameter_list=None, no_grad_set=None):
         if param not in grad_info_map:
             raise ValueError("param %s is not in map" % param)
         grad_info = grad_info_map[param]
-        grad_block = loss.block.program.block(grad_info[1])
+        grad_block = grad_info[1]
         if not grad_block.has_var(grad_info[0]):
             raise ValueError("grad block[{0}] did not have grad var {1}".format(
                 grad_info[1], grad_info[0]))
...
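Finally, a hedged sketch of the entry point these changes feed: `append_backward_ops` drives `backward_impl` from block 0 of the loss's program and fills `grad_info_map` so parameter gradients can be looked up afterwards. The import path and the `loss` variable are assumptions for illustration, not part of the diff.

# Hedged sketch of invoking the backward-pass entry point changed in this diff.
from paddle.v2.framework.backward import append_backward_ops  # assumed path

def add_backward(loss):
    # `loss` is assumed to be a framework Variable built by forward layers;
    # the call appends the generated gradient ops to loss.block.program.
    append_backward_ops(loss, parameter_list=None, no_grad_set=None)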