// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/tracer.h"

#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/framework/var_type_inference.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"

namespace paddle {
namespace imperative {

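// Creates the grad op descriptions for a forward op by invoking its
// registered GradOpMaker. The generated descs are appended to grad_op_descs
// and the mapping from forward variable names to their gradient variable
// names is recorded in grad_to_var. Ops without a grad op maker produce
// nothing.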
void CreateGradOp(const framework::OpDesc& op_desc,
                  const std::unordered_set<std::string>& no_grad_set,
                  const std::vector<framework::BlockDesc*>& grad_sub_block,
                  std::vector<framework::OpDesc*>* grad_op_descs,
                  std::unordered_map<std::string, std::string>* grad_to_var) {
  PADDLE_ENFORCE(grad_op_descs->empty());
  const framework::OpInfo& op_info =
      framework::OpInfoMap::Instance().Get(op_desc.Type());
  if (!op_info.grad_op_maker_) return;

  std::vector<std::unique_ptr<framework::OpDesc>> descs =
      op_info.GradOpMaker()(op_desc, no_grad_set, grad_to_var, grad_sub_block);
  for (auto& desc : descs) {
    grad_op_descs->emplace_back(desc.release());
  }
}

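// Lazily creates the gradient VarBase of `var` on the forward op's place.
// The grad var reuses the forward tensor's shape (with FP32 dtype), but its
// tensor buffer is left unallocated, as the name suggests.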
void CreateNoBuffuerGrad(std::shared_ptr<imperative::VarBase> var,
                         platform::DeviceContext* dev_ctx) {
  PADDLE_ENFORCE_NOT_NULL(var, "Could not get a valid var base");
  PADDLE_ENFORCE_NOT_NULL(
      dev_ctx, "Could not get a valid device context from the forward op");

  if (var->grads_ == nullptr) {
    auto& var_t = var->var_->Get<framework::LoDTensor>();
    var->grads_ = std::shared_ptr<imperative::VarBase>(
        new VarBase(var->GradName(), framework::proto::VarType::FP32,
                    framework::vectorize(var_t.dims()), dev_ctx->GetPlace(),
                    var->IsStopGradient(), false, false));
  }
}

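// Returns `place` after checking that every input tensor already lives on
// that place; throws if any input is on a different place.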
platform::Place GetExpectedPlace(platform::Place place, VarBasePtrMap inputs) {
  platform::Place result = place;
  for (const auto& it : inputs) {
    for (const std::shared_ptr<imperative::VarBase>& var : it.second) {
      platform::Place tmp_place =
          var->var_->Get<framework::LoDTensor>().place();
      if (!platform::is_same_place(tmp_place, result)) {
        PADDLE_THROW(
            "Input variables should be kept on the same place: %s, but got "
            "place %s for input %s instead",
            result, tmp_place, it.first);
      }
    }
  }

  return result;
}

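// Builds the framework::VariableNameMap for the op's inputs from the traced
// VarBases, so that a regular OpDesc/OperatorBase can be constructed below.
// Dispensable inputs that are not fed map to empty name lists.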
framework::VariableNameMap CreateInputVarNameMap(
    const OpBase* op, const VarBasePtrMap& varbase_map) {
  framework::VariableNameMap result;

  auto& info_map = framework::OpInfoMap::Instance();
  auto* op_info = info_map.GetNullable(op->Type());
  if (op_info == nullptr || op_info->proto_ == nullptr) {
    return result;
  }

  for (auto& in : op_info->Proto().inputs()) {
    auto it = varbase_map.find(in.name());
    if (it == varbase_map.end()) {
      PADDLE_ENFORCE(in.dispensable());
      result[in.name()] = {};
    } else {
      auto var_vector = it->second;
      std::vector<std::string> args;
      args.reserve(var_vector.size());
      for (std::shared_ptr<imperative::VarBase> var_base : var_vector) {
        args.emplace_back(var_base->Name());
      }
      result[in.name()] = args;
    }
  }
  return result;
}

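// Same as CreateInputVarNameMap, but for the op's outputs.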
framework::VariableNameMap CreateOutputVarNameMap(
    const OpBase* op, const VarBasePtrMap& varbase_map) {
  framework::VariableNameMap result;

  auto& info_map = framework::OpInfoMap::Instance();
  auto* op_info = info_map.GetNullable(op->Type());
  if (op_info == nullptr || op_info->proto_ == nullptr) {
    return result;
  }

  for (auto& out : op_info->Proto().outputs()) {
    auto it = varbase_map.find(out.name());
    if (it == varbase_map.end()) {
      PADDLE_ENFORCE(out.dispensable());
      result[out.name()] = {};
    } else {
      auto var_vector = it->second;
      std::vector<std::string> args;
      args.reserve(var_vector.size());
      for (const std::shared_ptr<imperative::VarBase>& var_base : var_vector) {
        args.emplace_back(var_base->Name());
      }
      result[out.name()] = args;
    }
  }
  return result;
}

Tracer::Tracer(framework::BlockDesc* root_block) : root_block_(root_block) {}

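// Runs `op` eagerly on `expected_place` with the given inputs/outputs and
// attributes. When stop_gradient is false, it also creates the corresponding
// grad op descs and wires up the grad input/output VarBases needed by the
// later backward pass.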
void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
                   VarBasePtrMap* outputs, framework::AttributeMap attrs_map,
                   const platform::Place expected_place,
                   const bool stop_gradient) {
  platform::RecordEvent record_event(op->type_);
  framework::VariableValueMap invars_map;
  framework::VariableValueMap outvars_map;

  // Construct input_vars_map and output_vars_map
  std::map<std::string, std::shared_ptr<imperative::VarBase>> current_vars_map;
  for (auto it : inputs) {
    auto& invars = invars_map[it.first];
    invars.reserve(it.second.size());
    for (std::shared_ptr<imperative::VarBase> inp : it.second) {
      PADDLE_ENFORCE_NOT_NULL(inp->var_, "op %s input %s nullptr", op->Type(),
                              inp->Name());

      invars.emplace_back(inp->var_.get());
      if (!stop_gradient) {
        current_vars_map[inp->Name()] = inp;
      }
      VLOG(3) << "input var name: " << inp->Name()
              << " inited: " << inp->var_->IsInitialized()
              << " stop_grad: " << inp->IsStopGradient();
    }
    op->TrackPreOp(it.first, it.second);
  }

  for (const auto& it : *outputs) {
    auto& outvars = outvars_map[it.first];
    const std::vector<std::shared_ptr<imperative::VarBase>>& outputs_tmp =
        it.second;
    outvars.reserve(outputs_tmp.size());
    for (size_t i = 0U; i < outputs_tmp.size(); ++i) {
      // Add weak_ptr to track outputs
      op->outputs_ref.emplace_back(outputs_tmp[i]);
      std::shared_ptr<imperative::VarBase> out = outputs_tmp[i];
      outvars.emplace_back(out->var_.get());
      out->TrackPreOp(op, it.first, i, stop_gradient);
      if (!stop_gradient) {
        current_vars_map[out->Name()] = out;
      }

      VLOG(3) << "output var name: " << out->Name()
              << " inited: " << out->var_->IsInitialized()
              << " stop_grad: " << out->IsStopGradient();
    }
  }

  // Check attrs and create op
  framework::VariableNameMap invars_name_map =
      CreateInputVarNameMap(op, inputs);
  framework::VariableNameMap outvars_name_map =
      CreateOutputVarNameMap(op, *outputs);

  auto& info = framework::OpInfoMap::Instance().Get(op->Type());
  if (info.Checker() != nullptr) {
    info.Checker()->Check(&attrs_map);
  }

  std::unique_ptr<framework::OperatorBase> op_base =
      framework::OpRegistry::CreateOp(op->Type(), invars_name_map,
                                      outvars_name_map, attrs_map);

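  // Infer output variable types (e.g. LoDTensor vs. SelectedRows) with the
  // op's registered var-type inference function, if it has one.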
  if (info.infer_var_type_) {
    RuntimeInferVarTypeContext infer_var_type_ctx(&inputs, outputs, &attrs_map);
    info.infer_var_type_(&infer_var_type_ctx);
  }

  // TODO(minqiyang): Support infer var type in imperative mode
  // Run forward op
  VLOG(3) << "tracer running " << op->Type();
  framework::RuntimeContext ctx(invars_map, outvars_map);

  // TODO(panyx0718): Cache p.
  framework::OperatorWithKernel* op_kernel =
      dynamic_cast<framework::OperatorWithKernel*>(op_base.get());
  PADDLE_ENFORCE_NOT_NULL(op_kernel, "only support op with kernel");

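  // Imperative mode keeps variables inside VarBases rather than in a scope;
  // this empty scope only satisfies the RuntimeInferShape/ExecutionContext
  // interfaces used below.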
  framework::Scope scope;
  op->place_ = GetExpectedPlace(expected_place, inputs);

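  // Prepare the kernel for the expected place, infer output shapes, then run
  // the kernel function.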
  PreparedOp prepared_op = PreparedOp::Prepare(ctx, *op_kernel, op->place_);
  prepared_op.op.RuntimeInferShape(scope, op->place_, ctx);
  prepared_op.func(
      framework::ExecutionContext(prepared_op.op, scope, *prepared_op.dev_ctx,
                                  prepared_op.ctx, prepared_op.kernel_configs));

  if (!stop_gradient) {
    VLOG(5) << "start construct backward op";

    // construct grad op descs
    op->attrs_ = attrs_map;
    std::unique_ptr<framework::OpDesc> fwd_op_desc(new framework::OpDesc(
        op->Type(), invars_name_map, outvars_name_map, attrs_map));
    std::unique_ptr<std::unordered_map<std::string, std::string>> grad_to_var(
        new std::unordered_map<std::string, std::string>());
    // NOTE(minqiyang): We don't support control flow ops in imperative mode
    // yet. Add grad_block_ when we want to support them.
    CreateGradOp(*fwd_op_desc, {}, {}, &op->grad_op_descs_, grad_to_var.get());

    VLOG(5) << "create grad op desc: " << op->grad_op_descs_[0]->Type();

    const size_t grad_op_count = op->grad_op_descs_.size();

    op->grad_input_vars_.resize(grad_op_count);
    op->grad_output_vars_.resize(grad_op_count);

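    // For each generated grad op, map its input/output argument names to
    // VarBases: names that refer to forward inputs/outputs are reused
    // directly, while gradient names get buffer-less grad VarBases created
    // on demand.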
    for (size_t i = 0; i < grad_op_count; ++i) {
      framework::OpDesc* grad_op_desc = op->grad_op_descs_[i];
      for (auto it : grad_op_desc->Inputs()) {
        auto& grad_in_vars = op->grad_input_vars_[i][it.first];
        grad_in_vars.reserve(it.second.size());
        for (const std::string& grad_invar : it.second) {
          auto var_it = grad_to_var->find(grad_invar);
          if (var_it == grad_to_var->end()) {
            auto fwd_var_it = current_vars_map.find(grad_invar);
            PADDLE_ENFORCE(fwd_var_it != current_vars_map.end());
            // Forward inputs or outputs.
            grad_in_vars.emplace_back(fwd_var_it->second);
          } else {
            std::shared_ptr<imperative::VarBase> var =
                current_vars_map[var_it->second];
            CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
            // Douts.
            var->grads_->SetPreOp(var->PreOp());
            grad_in_vars.emplace_back(var->grads_);
          }
        }
      }

      for (auto it : grad_op_desc->Outputs()) {
        auto& grad_out_vars = op->grad_output_vars_[i][it.first];
        for (const std::string& grad_outvar : it.second) {
          auto var_it = grad_to_var->find(grad_outvar);
          PADDLE_ENFORCE(var_it != grad_to_var->end(),
                         "Could not found the grad op output var, should this "
                         "operator %s's stop gradient be True",
                         op->Type());

          std::shared_ptr<imperative::VarBase> var =
              current_vars_map[var_it->second];
          CreateNoBuffuerGrad(var, prepared_op.GetDeviceContext());
          var->grads_->SetPreOp(var->PreOp());
          grad_out_vars.push_back(var->grads_);
          VLOG(3) << "grads output var name: " << var->name_;
        }
      }
    }
  }
}
}  // namespace imperative
}  // namespace paddle