From 7bc67c31e52b3eafbf7827c302b63d4f3fdad8b8 Mon Sep 17 00:00:00 2001 From: Xin Pan Date: Mon, 14 Jan 2019 10:06:42 +0800 Subject: [PATCH] polish more test=develop --- paddle/fluid/imperative/tracer.cc | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index a01225cce..2878f5be8 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -164,28 +164,30 @@ std::vector Tracer::PyTrace(OpBase* op, const std::vector& inputs, bool stop_gradient) { VLOG(3) << "py_trace"; - op->input_vars_["X"] = inputs; - op->output_vars_["Out"] = PyLayer::Apply(op->forward_id_, inputs); + op->input_vars_[PyLayer::kFwdInp] = inputs; + op->output_vars_[PyLayer::kFwdOut] = PyLayer::Apply(op->forward_id_, inputs); for (VarBase* inp : inputs) { if (inp->pre_op_) { - op->pre_ops_["X"].push_back(inp->pre_op_); - op->pre_ops_out_idx_["X"].push_back(inp->pre_op_out_idx_); + op->pre_ops_[PyLayer::kFwdInp].push_back(inp->pre_op_); + op->pre_ops_out_idx_[PyLayer::kFwdInp].push_back(inp->pre_op_out_idx_); } else { - op->pre_ops_["X"].push_back(nullptr); + op->pre_ops_[PyLayer::kFwdInp].push_back(nullptr); } } - auto& outputs = op->output_vars_["Out"]; + auto& outputs = op->output_vars_[PyLayer::kFwdOut]; for (size_t i = 0; i < outputs.size(); ++i) { VarBase* out = outputs[i]; out->stop_gradient_ = stop_gradient; out->pre_op_ = op; - out->pre_op_out_name_ = "Out"; + out->pre_op_out_name_ = PyLayer::kFwdOut; out->pre_op_out_idx_ = i; } if (!stop_gradient) { - auto& grad_input_vars = op->grad_input_vars_["X@GRAD"]; - auto& grad_output_vars = op->grad_output_vars_["Out@GRAD"]; + auto& grad_input_vars = + op->grad_input_vars_[framework::GradVarName(PyLayer::kFwdInp)]; + auto& grad_output_vars = + op->grad_output_vars_[framework::GradVarName(PyLayer::kFwdOut)]; for (const VarBase* inp : inputs) { grad_input_vars.push_back(inp->var_); -- GitLab