提交 7bc67c31 编写于 作者: X Xin Pan

polish more

test=develop
上级 0c04cac4
...@@ -164,28 +164,30 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op, ...@@ -164,28 +164,30 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
const std::vector<VarBase*>& inputs, const std::vector<VarBase*>& inputs,
bool stop_gradient) { bool stop_gradient) {
VLOG(3) << "py_trace"; VLOG(3) << "py_trace";
op->input_vars_["X"] = inputs; op->input_vars_[PyLayer::kFwdInp] = inputs;
op->output_vars_["Out"] = PyLayer::Apply(op->forward_id_, inputs); op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs);
for (VarBase* inp : inputs) { for (VarBase* inp : inputs) {
if (inp->pre_op_) { if (inp->pre_op_) {
op->pre_ops_["X"].push_back(inp->pre_op_); op->pre_ops_[PyLayer::kFwdInp].push_back(inp->pre_op_);
op->pre_ops_out_idx_["X"].push_back(inp->pre_op_out_idx_); op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->pre_op_out_idx_);
} else { } else {
op->pre_ops_["X"].push_back(nullptr); op->pre_ops_[PyLayer::kFwdInp].push_back(nullptr);
} }
} }
auto& outputs = op->output_vars_["Out"]; auto& outputs = op->output_vars_[PyLayer::kFwdOut];
for (size_t i = 0; i < outputs.size(); ++i) { for (size_t i = 0; i < outputs.size(); ++i) {
VarBase* out = outputs[i]; VarBase* out = outputs[i];
out->stop_gradient_ = stop_gradient; out->stop_gradient_ = stop_gradient;
out->pre_op_ = op; out->pre_op_ = op;
out->pre_op_out_name_ = "Out"; out->pre_op_out_name_ = PyLayer::kFwdOut;
out->pre_op_out_idx_ = i; out->pre_op_out_idx_ = i;
} }
if (!stop_gradient) { if (!stop_gradient) {
auto& grad_input_vars = op->grad_input_vars_["X@GRAD"]; auto& grad_input_vars =
auto& grad_output_vars = op->grad_output_vars_["Out@GRAD"]; op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)];
auto& grad_output_vars =
op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)];
for (const VarBase* inp : inputs) { for (const VarBase* inp : inputs) {
grad_input_vars.push_back(inp->var_); grad_input_vars.push_back(inp->var_);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册