未验证 提交 cda893fc 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2Stat]Modify into core.ops.run_program (#33246)

* Modify into core.ops.run_program

* add DDout in core.ops.run_program

* fix typo

* add DOut

* fix typo

* put DOut last
上级 98f08177
...@@ -65,6 +65,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = { ...@@ -65,6 +65,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}}, {"box_coder", {"PriorBox", "PriorBoxVar", "TargetBox"}},
{"momentum", {"Param", "Grad", "Velocity", "LearningRate"}}, {"momentum", {"Param", "Grad", "Velocity", "LearningRate"}},
{"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}}, {"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
{"run_program", {"X", "Params"}},
}; };
// NOTE(zhiqiu): Like op_ins_map. // NOTE(zhiqiu): Like op_ins_map.
...@@ -98,6 +99,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = { ...@@ -98,6 +99,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{"rnn", {"DropoutState", "Reserve", "Out", "State"}}, {"rnn", {"DropoutState", "Reserve", "Out", "State"}},
{"lamb", {"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"run_program", {"DOut"}},
}; };
// NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are // NOTE(zhiqiu): Commonly, the outputs in auto-generated OP function are
...@@ -148,6 +150,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = { ...@@ -148,6 +150,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"lamb", {"lamb",
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}}, {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}}, {"rnn", {"DropoutState"}},
{"run_program", {"Out", "DOut", "OutScope"}},
}; };
// NOTE(pangyoki): Tensor View Strategy. // NOTE(pangyoki): Tensor View Strategy.
......
...@@ -221,23 +221,15 @@ class PartialProgramLayer(layers.Layer): ...@@ -221,23 +221,15 @@ class PartialProgramLayer(layers.Layer):
def forward(self, inputs): def forward(self, inputs):
in_vars, out_vars, tmp_scope_vec = self._prepare(inputs) in_vars, out_vars, tmp_scope_vec = self._prepare(inputs)
framework._dygraph_tracer().trace_op(
type='run_program', attrs = ('global_block', self.program.desc.block(0), 'start_op_index',
inputs={ 0, 'end_op_index', self._infer_program.desc.block(0).op_size(),
'X': valid_vars(in_vars), 'is_test', not self.training)
'Params': valid_vars(self._params) core.ops.run_program(
}, valid_vars(in_vars),
outputs={ valid_vars(self._params),
'Out': valid_vars(out_vars), valid_vars(out_vars), tmp_scope_vec,
'OutScope': tmp_scope_vec, valid_vars(self._double_grads), *attrs)
'DOut': valid_vars(self._double_grads)
},
attrs={
'global_block': self.program.desc.block(0),
'start_op_index': 0,
'end_op_index': self._infer_program.desc.block(0).op_size(),
'is_test': not self.training
})
restored_nest_out = self._restore_out(out_vars) restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out) return self._remove_no_value(restored_nest_out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册