Commit edba405d authored by F fengjiayi

Pass test_dyn_rnn.py

Parent dcc51da4
@@ -90,6 +90,14 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
   need_update_ = true;
 }
 
+void OpDescBind::CopyFrom(const OpDescBind &op_desc) {
+  desc_.set_type(op_desc.Type());
+  inputs_ = op_desc.inputs_;
+  outputs_ = op_desc.outputs_;
+  attrs_ = op_desc.attrs_;
+  need_update_ = true;
+}
+
 OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
     : desc_(desc), need_update_(false) {
   // restore inputs_
......
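`CopyFrom` copies only the cached in-memory members (type, inputs, outputs, attributes) and raises `need_update_`, so the underlying protobuf message is re-serialized lazily on the next `Proto()` call. A minimal runnable Python sketch of this copy-then-mark-dirty pattern (toy names, not PaddlePaddle's API):

```python
class FakeOpDesc(object):
    """Toy stand-in for OpDescBind: caches fields, syncs a 'proto' lazily."""

    def __init__(self, op_type="", inputs=None, outputs=None, attrs=None):
        self.type = op_type
        self.inputs = dict(inputs or {})
        self.outputs = dict(outputs or {})
        self.attrs = dict(attrs or {})
        self._proto = None
        self._need_update = True  # dirty flag, like need_update_

    def copy_from(self, other):
        # Copy only the cached members; the proto is rebuilt on demand.
        self.type = other.type
        self.inputs = dict(other.inputs)
        self.outputs = dict(other.outputs)
        self.attrs = dict(other.attrs)
        self._need_update = True

    def proto(self):
        if self._need_update:  # lazy re-serialization
            self._proto = {"type": self.type, "inputs": self.inputs,
                           "outputs": self.outputs, "attrs": self.attrs}
            self._need_update = False
        return self._proto


src = FakeOpDesc("sum", {"X": ["a", "b"]}, {"Out": ["c"]})
dst = FakeOpDesc()
dst.copy_from(src)
assert dst.proto()["type"] == "sum"
```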
@@ -35,6 +35,8 @@ class OpDescBind {
   OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
 
+  void CopyFrom(const OpDescBind &op_desc);
+
   OpDesc *Proto();
 
   std::string Type() const { return desc_.type(); }
......
@@ -72,7 +72,7 @@ const TensorDesc &VarDescBind::tensor_desc() const {
     case VarDesc::LOD_TENSOR_ARRAY:
       return desc_.tensor_array().tensor();
     default:
-      PADDLE_THROW("Unexpected branch.");
+      PADDLE_THROW("The type of var '", this->Name(), "' is unsupported.");
   }
 }
......
@@ -255,6 +255,7 @@ void BindOpDesc(py::module &m) {
   op_desc
       .def("__init__", [](OpDescBind &self) { new (&self) OpDescBind(); },
           py::return_value_policy::reference)
+      .def("copy_from", &OpDescBind::CopyFrom)
      .def("type", &OpDescBind::Type)
      .def("set_type", &OpDescBind::SetType)
      .def("input", &OpDescBind::Input)
......
@@ -32,6 +32,16 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
     return op_desc
 
+
+def _infer_var_data_type_(var_name, block):
+    grad_var = block.desc.find_var(var_name.encode("ascii"))
+    fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
+    if block.desc.has_var_recursive(fwd_name):
+        fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
+        grad_var.set_dtype(fwd_var.dtype())
+    else:
+        grad_var.set_dtype(core.DataType.FP32)
+
 def _is_all_in_set_(cands, s):
     for c in cands:
         if not c in s:
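`_infer_var_data_type_` gives a newly created gradient variable the dtype of its forward counterpart, found by stripping the gradient suffix and searching the block hierarchy, and falls back to FP32 when no forward variable exists. A self-contained sketch of the same lookup-or-default logic, assuming a `@GRAD` suffix and plain dicts in place of block descs:

```python
GRAD_SUFFIX = "@GRAD"  # assumed suffix; stands in for _strip_grad_suffix_

def strip_grad_suffix(name):
    pos = name.find(GRAD_SUFFIX)
    return name[:pos] if pos != -1 else name

def infer_grad_dtype(grad_name, block_chain, default="fp32"):
    """block_chain: list of {var_name: dtype} dicts, innermost block first."""
    fwd_name = strip_grad_suffix(grad_name)
    for block in block_chain:   # walk outward, like has_var_recursive
        if fwd_name in block:
            return block[fwd_name]
    return default              # no forward var found: default to FP32

assert infer_grad_dtype("x@GRAD", [{"y": "fp64"}, {"x": "fp16"}]) == "fp16"
assert infer_grad_dtype("tmp@GRAD", [{}]) == "fp32"
```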
@@ -64,7 +74,7 @@ def _backward_impl_(target,
             grad_sub_block = program.create_block(parent_idx=sub_block_idx)
             _backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
                             grad_info_map, callback)
-            grad_sub_block_list.append(grad_sub_block)
+            grad_sub_block_list.append(grad_sub_block.desc)
         grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
             each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
         grad_op_descs.append(grad_op_desc)
@@ -80,17 +90,18 @@ def _backward_impl_(target,
         for var_name in op_desc.input_arg_names():
             if len(var_inputs[var_name]) > 1:
                 pending_sum_ops.append((_create_op_desc_(
-                    op_type="sum_op",
-                    inputs=var_inputs[var_name],
-                    outputs=[var_name],
+                    op_type="sum",
+                    inputs={"X": var_inputs[var_name]},
+                    outputs={"Out": [var_name]},
                     attrs={}), idx))
                 var_inputs[var_name] = [var_name]
         for var_name in op_desc.output_arg_names():
-            if len(var_inputs[var_name]) == 0:
+            if var_name == core.empty_var_name() or len(var_inputs[
+                    var_name]) == 0:
                 # it's the first time we get the variable
                 var_inputs[var_name] = [var_name]
             else:
-                if len(var_inputs[var_name] == 1):
+                if len(var_inputs[var_name]) == 1:
                     new_name = var_name + "@RENAME@" + \
                         str(var_rename_count[var_name])
                     var_rename_count[var_name] = var_rename_count[var_name] + 1
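When several gradient ops write the same gradient variable, later writes are diverted to renamed copies (`name@RENAME@k`) and a `sum` op is queued to add the copies back into the original name; this hunk also fixes the op type (`sum`, not `sum_op`) and a misplaced parenthesis in the length check. A runnable sketch of that accumulate-by-rename bookkeeping (simplified: the real pass also renames the first writer's output and records where each pending sum op must be inserted):

```python
from collections import defaultdict

def accumulate_outputs(ops):
    """ops: list of (op_name, [output names]); duplicate writers are renamed,
    then one sum op per duplicated gradient adds the copies back together."""
    writers = defaultdict(list)       # grad name -> names actually written
    rename_count = defaultdict(int)
    rewritten, sum_ops = [], []
    for op_name, outputs in ops:
        new_outputs = []
        for name in outputs:
            if not writers[name]:     # first writer keeps the original name
                new_name = name
            else:                     # duplicate writer: divert to a rename
                new_name = "%s@RENAME@%d" % (name, rename_count[name])
                rename_count[name] += 1
            writers[name].append(new_name)
            new_outputs.append(new_name)
        rewritten.append((op_name, new_outputs))
    for name, parts in writers.items():
        if len(parts) > 1:            # queue one sum op per duplicated grad
            sum_ops.append(("sum", {"X": parts}, {"Out": [name]}))
    return rewritten, sum_ops

_, sums = accumulate_outputs([("g1", ["dx"]), ("g2", ["dx"])])
assert sums == [("sum", {"X": ["dx", "dx@RENAME@0"]}, {"Out": ["dx"]})]
```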
@@ -107,7 +118,7 @@ def _backward_impl_(target,
     for var_name, inputs in var_inputs.iteritems():
         if len(inputs) > 1:
             pending_sum_ops.append((_create_op_desc_(
-                op_type="sum_op",
+                op_type="sum",
                 inputs={"X": inputs},
                 outputs={"Out": var_name},
                 attrs={}), len(grad_op_descs)))
@@ -131,13 +142,15 @@ def _backward_impl_(target,
                 {})
             grad_op_descs.insert(ele[1], fill_zeros_like_op)
     # create new gradient variables in the target block desc
+    new_vars = set()
     for op_desc in grad_op_descs:
         for grad_var_name in op_desc.output_arg_names():
             grad_var_name = grad_var_name.encode("ascii")
-            if target_block.desc.has_var(
+            if target_block.desc.has_var_recursive(
                     grad_var_name) or grad_var_name == core.empty_var_name():
                 continue
             target_block.desc.var(grad_var_name)
+            new_vars.add(grad_var_name)
             if not grad_to_var.has_key(grad_var_name):
                 continue
             grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
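The switch to `has_var_recursive` matters for nested RNN blocks: a gradient variable may already live in an ancestor block, so only genuinely new variables are created in the target block and remembered in `new_vars` for the later dtype-inference step. A sketch of the parent-chain lookup, using a hypothetical `Block` with a `parent` pointer:

```python
class Block(object):
    def __init__(self, parent=None):
        self.vars = set()
        self.parent = parent

    def has_var(self, name):
        return name in self.vars

    def has_var_recursive(self, name):
        block = self
        while block is not None:        # walk up through ancestor blocks
            if block.has_var(name):
                return True
            block = block.parent
        return False

root = Block()
root.vars.add("x@GRAD")
sub = Block(parent=root)
assert not sub.has_var("x@GRAD")        # plain lookup misses ancestor vars
assert sub.has_var_recursive("x@GRAD")  # recursive lookup finds them
```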
@@ -160,7 +173,11 @@ def _backward_impl_(target,
     for op_desc in grad_op_descs:
         op_desc.infer_var_type(target_block.desc)
         op_desc.infer_shape(target_block.desc)
-        target_block.desc.append_allocated_op(op_desc)
+        for arg in op_desc.output_arg_names():
+            if arg in new_vars:
+                _infer_var_data_type_(arg, target_block)
+        new_op_desc = target_block.desc.append_op()
+        new_op_desc.copy_from(op_desc)
     target_block.sync_with_cpp()
......
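The last hunk replaces `append_allocated_op` with an append-then-copy idiom: the target block allocates and owns a fresh op descriptor via `append_op()`, and `copy_from` fills in its contents, so no externally allocated descriptor is ever spliced into the block. A toy, self-contained illustration of that ownership pattern (names are illustrative only):

```python
class ToyOp(object):
    def __init__(self, op_type=""):
        self.type = op_type

    def copy_from(self, other):
        self.type = other.type          # copy contents, not identity

class ToyBlock(object):
    def __init__(self):
        self.ops = []

    def append_op(self):
        op = ToyOp()                    # the block allocates and owns the op
        self.ops.append(op)
        return op

pending = [ToyOp("mul_grad"), ToyOp("sum")]
block = ToyBlock()
for src in pending:                     # descriptors built outside the block...
    block.append_op().copy_from(src)    # ...are copied into block-owned slots
assert [op.type for op in block.ops] == ["mul_grad", "sum"]
```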