Commit edba405d authored by F fengjiayi

Pass test_dyn_rnn.py

Parent dcc51da4
......@@ -90,6 +90,14 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
need_update_ = true;
}
void OpDescBind::CopyFrom(const OpDescBind &op_desc) {
desc_.set_type(op_desc.Type());
inputs_ = op_desc.inputs_;
outputs_ = op_desc.outputs_;
attrs_ = op_desc.attrs_;
need_update_ = true;
}
OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
: desc_(desc), need_update_(false) {
// restore inputs_
......
......@@ -35,6 +35,8 @@ class OpDescBind {
OpDescBind(const OpDesc &desc, ProgramDescBind *prog);
void CopyFrom(const OpDescBind &op_desc);
OpDesc *Proto();
std::string Type() const { return desc_.type(); }
......
......@@ -72,7 +72,7 @@ const TensorDesc &VarDescBind::tensor_desc() const {
case VarDesc::LOD_TENSOR_ARRAY:
return desc_.tensor_array().tensor();
default:
PADDLE_THROW("Unexpected branch.");
PADDLE_THROW("The type of var '", this->Name(), "' is unsupported.");
}
}
......
......@@ -255,6 +255,7 @@ void BindOpDesc(py::module &m) {
op_desc
.def("__init__", [](OpDescBind &self) { new (&self) OpDescBind(); },
py::return_value_policy::reference)
.def("copy_from", &OpDescBind::CopyFrom)
.def("type", &OpDescBind::Type)
.def("set_type", &OpDescBind::SetType)
.def("input", &OpDescBind::Input)
......
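The new CopyFrom duplicates an op desc's type, inputs, outputs, and attributes, and the copy_from binding just above makes it callable from Python. A minimal sketch of the append-then-copy pattern this enables, assuming target_block is a Python Block wrapper as used in the backward code below and reusing the _create_op_desc_ helper from the Python diff; the "mean" op type and the variable names are only placeholders:

# hypothetical sketch: build a standalone op desc, then copy it into the block
op_desc = _create_op_desc_(
    op_type="mean", inputs={"X": ["cost"]}, outputs={"Out": ["avg_cost"]}, attrs={})
new_op_desc = target_block.desc.append_op()  # allocate an empty op desc owned by the block
new_op_desc.copy_from(op_desc)               # copy type, inputs, outputs and attrs into it
target_block.sync_with_cpp()                 # refresh the Python-side Block wrapper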
......@@ -32,6 +32,16 @@ def _create_op_desc_(op_type, inputs, outputs, attrs):
return op_desc
def _infer_var_data_type_(var_name, block):
grad_var = block.desc.find_var(var_name.encode("ascii"))
fwd_name = _strip_grad_suffix_(var_name.encode("ascii"))
if block.desc.has_var_recursive(fwd_name):
fwd_var = block.desc.find_var_recursive(fwd_name.encode("ascii"))
grad_var.set_dtype(fwd_var.dtype())
else:
grad_var.set_dtype(core.DataType.FP32)
def _is_all_in_set_(cands, s):
for c in cands:
if not c in s:
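The new _infer_var_data_type_ helper copies the forward variable's dtype onto its gradient variable and falls back to FP32 when no forward counterpart can be found. It relies on _strip_grad_suffix_, which is defined elsewhere in this file; a rough, hypothetical sketch of the name mapping that helper is assumed to perform, using the usual "@GRAD" suffix convention:

# illustrative only: how a gradient name is assumed to map back to its forward name
GRAD_SUFFIX = "@GRAD"  # assumption; the real helper presumably reads the suffix from core

def strip_grad_suffix(name):
    pos = name.find(GRAD_SUFFIX)
    return name[:pos] if pos != -1 else name

assert strip_grad_suffix("fc_0.w_0@GRAD") == "fc_0.w_0"
assert strip_grad_suffix("fc_0.w_0@GRAD@RENAME@1") == "fc_0.w_0"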
......@@ -64,7 +74,7 @@ def _backward_impl_(target,
grad_sub_block = program.create_block(parent_idx=sub_block_idx)
_backward_impl_(target, sub_block, grad_sub_block, no_grad_set,
grad_info_map, callback)
grad_sub_block_list.append(grad_sub_block)
grad_sub_block_list.append(grad_sub_block.desc)
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
each_op.desc, no_grad_set[block.idx], grad_sub_block_list)
grad_op_descs.append(grad_op_desc)
......@@ -80,17 +90,18 @@ def _backward_impl_(target,
for var_name in op_desc.input_arg_names():
if len(var_inputs[var_name]) > 1:
pending_sum_ops.append((_create_op_desc_(
op_type="sum_op",
inputs=var_inputs[var_name],
outputs=[var_name],
op_type="sum",
inputs={"X": var_inputs[var_name]},
outputs={"Out": [var_name]},
attrs={}), idx))
var_inputs[var_name] = [var_name]
for var_name in op_desc.output_arg_names():
if len(var_inputs[var_name]) == 0:
if var_name == core.empty_var_name() or len(var_inputs[
var_name]) == 0:
# it's the first time we get the variable
var_inputs[var_name] = [var_name]
else:
if len(var_inputs[var_name] == 1):
if len(var_inputs[var_name]) == 1:
new_name = var_name + "@RENAME@" + \
str(var_rename_count[var_name])
var_rename_count[var_name] = var_rename_count[var_name] + 1
......@@ -107,7 +118,7 @@ def _backward_impl_(target,
for var_name, inputs in var_inputs.iteritems():
if len(inputs) > 1:
pending_sum_ops.append((_create_op_desc_(
op_type="sum_op",
op_type="sum",
inputs={"X": inputs},
outputs={"Out": var_name},
attrs={}), len(grad_op_descs)))
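Both hunks above switch the accumulation op from "sum_op" to "sum", the name under which the operator is actually registered, and pass it proper input/output maps. The surrounding logic handles gradient variables with several writers: each write is renamed with an "@RENAME@k" suffix and a single "sum" op adds the copies back into the original name. A hypothetical example of the op desc this produces:

# x@GRAD is written by two different grad ops, so the writers are renamed and summed
sum_op_desc = _create_op_desc_(
    op_type="sum",
    inputs={"X": ["x@GRAD@RENAME@0", "x@GRAD@RENAME@1"]},
    outputs={"Out": ["x@GRAD"]},
    attrs={})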
......@@ -131,13 +142,15 @@ def _backward_impl_(target,
{})
grad_op_descs.insert(ele[1], fill_zeros_like_op)
# create new gradient variables in the target block desc
new_vars = set()
for op_desc in grad_op_descs:
for grad_var_name in op_desc.output_arg_names():
grad_var_name = grad_var_name.encode("ascii")
if target_block.desc.has_var(
if target_block.desc.has_var_recursive(
grad_var_name) or grad_var_name == core.empty_var_name():
continue
target_block.desc.var(grad_var_name)
new_vars.add(grad_var_name)
if not grad_to_var.has_key(grad_var_name):
continue
grad_info_map[grad_to_var[grad_var_name]] = (grad_var_name,
......@@ -160,7 +173,11 @@ def _backward_impl_(target,
for op_desc in grad_op_descs:
op_desc.infer_var_type(target_block.desc)
op_desc.infer_shape(target_block.desc)
target_block.desc.append_allocated_op(op_desc)
for arg in op_desc.output_arg_names():
if arg in new_vars:
_infer_var_data_type_(arg, target_block)
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op_desc)
target_block.sync_with_cpp()
......