未验证 提交 cda893fc 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Modify into core.ops.run_program (#33246)

* Modify into core.ops.run_program

* add DDout in core.ops.run_program

* fix typo

* add DOut

* fix typo

* put DOut last
上级 98f08177
......@@ -65,6 +65,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
{"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
{"run_program", {"X", "Params"}},
};
// NOTE(zhiqiu): Like op_ins_map.
......@@ -98,6 +99,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"run_program", {"DOut"}},
};
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
......@@ -148,6 +150,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},
{"run_program", {"Out", "DOut", "OutScope"}},
};
// NOTE(pangyoki): Tensor View Strategy.
......
......@@ -221,23 +221,15 @@ class PartialProgramLayer(layers.Layer):
def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
framework._dygraph_tracer().trace_op(
type='run_program',
inputs={
'X': valid_vars(in_vars),
'Params': valid_vars(self._params)
},
outputs={
'Out': valid_vars(out_vars),
'OutScope': tmp_scope_vec,
'DOut': valid_vars(self._double_grads)
},
attrs={
'global_block': self.program.desc.block(0),
'start_op_index': 0,
'end_op_index': self._infer_program.desc.block(0).op_size(),
'is_test': not self.training
})
attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
'is_test', not self.training)
core.ops.run_program(
valid_vars(in_vars),
valid_vars(self._params),
valid_vars(out_vars), tmp_scope_vec,
valid_vars(self._double_grads), *attrs)
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册