tracer.h
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/engine.h"
#include "paddle/fluid/imperative/layer.h"

namespace paddle {
namespace imperative {

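// Builds the gradient OpDesc for a forward op via the op's registered
// GradOpMaker. Only a single grad op is supported; ownership of the result
// is handed to the caller through *grad_op_desc, and grad_to_var receives
// the mapping from gradient variable names to forward variable names.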
void CreateGradOp(const framework::OpDesc& op_desc,
                  const std::unordered_set<std::string>& no_grad_set,
                  const std::vector<framework::BlockDesc*>& grad_sub_block,
                  framework::OpDesc** grad_op_desc,
                  std::unordered_map<std::string, std::string>* grad_to_var) {
  std::vector<std::unique_ptr<framework::OpDesc>> grad_op_descs =
      framework::OpInfoMap::Instance()
          .Get(op_desc.Type())
          .GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
  PADDLE_ENFORCE(grad_op_descs.size() == 1, "Only one grad op is supported now.");
  // TODO(panyx0718): Leak?
  *grad_op_desc = grad_op_descs[0].release();
}

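// Allocates the gradient tensor with the same shape as the forward variable
// and zero-initializes it (float, CPU only).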
void InitVar(framework::Variable* var, framework::Variable* grad_var) {
  auto& var_t = var->Get<framework::LoDTensor>();
  float* data =
      grad_var->GetMutable<framework::LoDTensor>()->mutable_data<float>(
          var_t.dims(), platform::CPUPlace());
  std::fill(data, data + var_t.numel(), 0.0);
}

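// Tracer eagerly executes forward ops and records, on each OpBase/VarBase,
// the bookkeeping needed to construct and run the corresponding backward ops
// later.
//
// Typical usage (sketch; in practice the Python imperative layer drives
// these calls):
//   Tracer tracer(root_block, startup_block);
//   tracer.Trace(op, inputs, outputs, block);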
class Tracer {
 public:
  explicit Tracer(framework::BlockDesc* root_block,
                  framework::BlockDesc* startup_block)
      : root_block_(root_block), startup_block_(startup_block) {}

  virtual ~Tracer() {}

  void Trace(OpBase* op,
             const std::map<std::string, std::vector<VarBase*>>& inputs,
             const std::map<std::string, std::vector<VarBase*>>& outputs,
             framework::BlockDesc* block) {
    std::map<std::string, VarBase*> vars;

    framework::OpDesc* op_desc = op->op_desc_;
    VLOG(3) << "tracer tracing " << op_desc->Type();
    op_desc->InferShape(*block);
    op_desc->InferVarType(block);
    std::unique_ptr<framework::OperatorBase> op_base =
        framework::OpRegistry::CreateOp(*op_desc);

    framework::VariableValueMap invars_map;
    framework::VariableValueMap outvars_map;

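    // Collect forward input variables and remember each input's producing
    // op (if any) so the autograd chain can be walked backwards later.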
    op->input_vars_ = inputs;
    for (auto it : op->input_vars_) {
      auto& invars = invars_map[it.first];
      for (VarBase* inp : it.second) {
        PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr",
                                op->op_desc_->Type(), inp->var_desc_->Name());

        invars.push_back(inp->var_);
        vars[inp->var_desc_->Name()] = inp;
        if (inp->pre_op_) {
          (*op->pre_ops_)[it.first].push_back(inp->pre_op_);
          (*op->pre_ops_out_idx_)[it.first].push_back(inp->pre_op_out_idx_);
        } else {
          (*op->pre_ops_)[it.first].push_back(nullptr);
        }
        VLOG(3) << "input vname " << inp->var_desc_->Name() << " "
                << inp->var_->IsInitialized();
      }
    }

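    // Collect forward output variables, allocate their tensors if needed,
    // and mark this op as their producer.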
    op->output_vars_ = outputs;
    for (auto it : op->output_vars_) {
      auto& outvars = outvars_map[it.first];
      const std::vector<VarBase*>& outputs = it.second;
      for (size_t i = 0; i < outputs.size(); ++i) {
        VarBase* out = outputs[i];
        outvars.push_back(out->var_);
        vars[out->var_desc_->Name()] = out;

        framework::VarDesc* var_desc = block->FindVar(out->var_desc_->Name());
        if (var_desc->GetType() == framework::proto::VarType::LOD_TENSOR) {
          out->var_->GetMutable<framework::LoDTensor>();
        } else {
          LOG(ERROR) << "tracer doesn't support this output variable type yet";
        }
        out->pre_op_ = op;
        out->pre_op_out_name_ = it.first;
        out->pre_op_out_idx_ = i;

        VLOG(3) << "output vname " << out->var_desc_->Name() << " "
                << out->var_->IsInitialized();
      }
    }

    VLOG(3) << "tracer running " << op_desc->Type();
    framework::RuntimeContext ctx(invars_map, outvars_map);
    op_base->Run(ctx, platform::CPUPlace());

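    // Ops in the startup block need no gradient. For every other op, build
    // the grad OpDesc now and wire its inputs/outputs to the variables
    // recorded above.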
    if (block == startup_block_) {
      op->grad_op_desc_ = nullptr;
      op->grad_to_var_ = nullptr;
    } else {
      framework::OpDesc* grad_op_desc;
      auto grad_to_var = new std::unordered_map<std::string, std::string>();
      CreateGradOp(*op_desc, {}, {block}, &grad_op_desc, grad_to_var);
      op->grad_op_desc_ = grad_op_desc;
      op->grad_to_var_ = grad_to_var;

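      // A grad op input is either a forward variable (reused as-is) or a
      // gradient variable, which is created and zero-initialized on first
      // use.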
      for (auto it : grad_op_desc->Inputs()) {
        auto& grad_in_vars = op->grad_input_vars_[it.first];
        for (const std::string& grad_invar : it.second) {
          block->FindRecursiveOrCreateVar(grad_invar);
          auto var_it = op->grad_to_var_->find(grad_invar);
          if (var_it == op->grad_to_var_->end()) {
            auto fwd_var_it = vars.find(grad_invar);
            PADDLE_ENFORCE(fwd_var_it != vars.end());
            grad_in_vars.push_back(fwd_var_it->second->var_);
          } else {
            VarBase* var = vars[var_it->second];
            if (!var->grads_->IsInitialized()) {
              InitVar(var->var_, var->grads_);
            }
            grad_in_vars.push_back(var->grads_);
          }
        }
      }
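      // Grad op outputs must be gradients of recorded forward variables;
      // zero-initialize any that have not been touched yet.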
      for (auto it : grad_op_desc->Outputs()) {
        auto& grad_out_vars = op->grad_output_vars_[it.first];
        for (const std::string& grad_outvar : it.second) {
          block->FindRecursiveOrCreateVar(grad_outvar);
          auto var_it = op->grad_to_var_->find(grad_outvar);
          PADDLE_ENFORCE(var_it != op->grad_to_var_->end());
          VarBase* var = vars[var_it->second];
          if (!var->grads_->IsInitialized()) {
            InitVar(var->var_, var->grads_);
          }
          grad_out_vars.push_back(var->grads_);
        }
      }
    }
    op->block_ = block;
  }

 private:
  framework::BlockDesc* root_block_;
  framework::BlockDesc* startup_block_;
};

}  // namespace imperative
}  // namespace paddle