ipu_compiler.cc 34.1 KB
Newer Older
J
jianghaicheng 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

A
Allen Guo 已提交
15
#include "paddle/fluid/platform/device/ipu/ipu_compiler.h"
J
jianghaicheng 已提交
16

A
Allen Guo 已提交
17 18 19 20
#include <popart/adam.hpp>
#include <popart/adaptive.hpp>
#include <popart/optimizer.hpp>
#include <popart/sgd.hpp>
A
Allen Guo 已提交
21

J
jianghaicheng 已提交
22
#include "paddle/fluid/framework/ir/graph_helper.h"
23 24
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
A
Allen Guo 已提交
25
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
J
jianghaicheng 已提交
26 27 28 29 30

namespace paddle {
namespace platform {
namespace ipu {

31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126
namespace {

// Copies one Paddle op attribute into the popart custom-op attribute map
// under `attr_name`. Value-type attributes are stored as-is. `BlockDesc` and
// `boost::blank` attributes have no popart equivalent and raise Unavailable.
// Because `emplace` is used, an attribute name already present in the map is
// left unchanged.
struct CustomOpAttrVisitor : public boost::static_visitor<void> {
  CustomOpAttrVisitor(std::map<std::string, popart::any>* attr,
                      const std::string& attr_name)
      : attrs_(attr), attr_name_(attr_name) {}

  // No `mutable` needed: the const call operators only modify the pointee,
  // which a (const) pointer member already permits.
  std::map<std::string, popart::any>* attrs_;
  std::string attr_name_;

  void operator()(int v) const { attrs_->emplace(attr_name_, v); }
  void operator()(float v) const { attrs_->emplace(attr_name_, v); }
  void operator()(const std::string& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(const std::vector<int>& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(const std::vector<float>& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(const std::vector<std::string>& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(bool v) const { attrs_->emplace(attr_name_, v); }
  void operator()(const std::vector<bool>& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(BlockDesc* desc) const {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method for `BlockDesc` type when extracting "
        "custom operator attributes."));
  }
  void operator()(const std::vector<BlockDesc*>& v) const {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method for `BlockDesc` type when extracting "
        "custom operator attributes."));
  }
  void operator()(int64_t v) const { attrs_->emplace(attr_name_, v); }
  void operator()(const std::vector<int64_t>& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(const std::vector<double>& v) const {
    attrs_->emplace(attr_name_, v);
  }
  void operator()(boost::blank) const {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method for `boost::blank` type when extracting "
        "custom operator attributes."));
  }
};

// Fills `tensor_` from the vector stored in a popart_constant op's "value"
// attribute. A float vector is narrowed to float16 when `dtype_` is FP16.
// Any attribute that is not a numeric/bool vector is rejected with
// InvalidArgument.
struct ConstantOpAttrVisitor : public boost::static_visitor<void> {
  ConstantOpAttrVisitor(framework::LoDTensor* tensor, VarType::Type dtype)
      : tensor_(tensor), dtype_(dtype) {}

  framework::LoDTensor* tensor_;
  VarType::Type dtype_;

  // --- Accepted vector payloads -------------------------------------------
  void operator()(const std::vector<float>& vec) const {
    if (dtype_ != VarType::FP16) {
      framework::TensorFromVector<float>(vec, tensor_);
      return;
    }
    std::vector<float16> half_values;
    half_values.reserve(vec.size());
    for (float value : vec) {
      half_values.push_back(float16(value));
    }
    framework::TensorFromVector<float16>(half_values, tensor_);
  }
  void operator()(const std::vector<int>& vec) const {
    framework::TensorFromVector<int>(vec, tensor_);
  }
  void operator()(const std::vector<int64_t>& vec) const {
    framework::TensorFromVector<int64_t>(vec, tensor_);
  }
  void operator()(const std::vector<double>& vec) const {
    framework::TensorFromVector<double>(vec, tensor_);
  }
  void operator()(const std::vector<bool>& vec) const {
    framework::TensorFromVector<bool>(vec, tensor_);
  }

  // --- Everything else cannot seed a constant tensor -----------------------
  void operator()(int) const { Reject(); }
  void operator()(float) const { Reject(); }
  void operator()(const std::string&) const { Reject(); }
  void operator()(const std::vector<std::string>&) const { Reject(); }
  void operator()(bool) const { Reject(); }
  void operator()(BlockDesc*) const { Reject(); }
  void operator()(const std::vector<BlockDesc*>&) const { Reject(); }
  void operator()(int64_t) const { Reject(); }
  void operator()(boost::blank) const { Reject(); }

 private:
  // Shared failure path for all non-vector attribute types.
  static void Reject() {
    PADDLE_THROW(
        platform::errors::InvalidArgument("Constant value must be a vector"));
  }
};

A
Allen Guo 已提交
127 128
// Maps the optimizer's `adam_mode` string onto popart::AdamMode.
// When `use_no_bias_optimizer` is true, "adam" and "lamb" select their
// NoBias variants; "adamax" is unaffected.
// Throws InvalidArgument for any other string.
popart::AdamMode AdamModeFromStr(const std::string& str,
                                 bool use_no_bias_optimizer) {
  if (str == "adam") {
    return use_no_bias_optimizer ? popart::AdamMode::AdamNoBias
                                 : popart::AdamMode::Adam;
  } else if (str == "adamax") {
    return popart::AdamMode::AdaMax;
  } else if (str == "lamb") {
    return use_no_bias_optimizer ? popart::AdamMode::LambNoBias
                                 : popart::AdamMode::Lamb;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Unknown AdamMode: %s, AdamMode must be one of these values: adam, "
        "adamax or lamb",
        str));
  }
}

// Maps the optimizer's `adaptive_mode` string onto popart::AdaptiveMode.
// Throws InvalidArgument for any string other than the four listed below.
popart::AdaptiveMode AdaptiveModeFromStr(const std::string& str) {
  if (str == "adadelta") {
    return popart::AdaptiveMode::AdaDelta;
  } else if (str == "adagrad") {
    return popart::AdaptiveMode::AdaGrad;
  } else if (str == "rmsprop") {
    return popart::AdaptiveMode::RMSProp;
  } else if (str == "centered_rmsprop") {
    return popart::AdaptiveMode::CenteredRMSProp;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Unknown AdaptiveMode: %s, AdaptiveMode must be one of these values: "
        "adadelta, adagrad, rmsprop or centered_rmsprop",
        str));
  }
}

// Maps a weight-decay-mode string onto popart::WeightDecayMode.
// Throws InvalidArgument unless `str` is "decay" or "l2_regularization".
popart::WeightDecayMode WeightDecayModeFromStr(const std::string& str) {
  if (str == "decay") {
    return popart::WeightDecayMode::Decay;
  } else if (str == "l2_regularization") {
    return popart::WeightDecayMode::L2Regularization;
  } else {
    PADDLE_THROW(platform::errors::InvalidArgument(
        "Unknown WeightDecayMode: %s, WeightDecayMode must be decay or "
        "l2_regularization",
        str));
  }
}

A
Allen Guo 已提交
179 180 181 182 183 184 185 186 187 188 189
// Converts an accumulator type name from ipu_strategy ("FLOAT" or "FLOAT16")
// into popart::DataType; any other name raises Unimplemented.
popart::DataType DataTypeFromStr(const std::string& str) {
  if (str == "FLOAT") return popart::DataType::FLOAT;
  if (str == "FLOAT16") return popart::DataType::FLOAT16;
  PADDLE_THROW(
      platform::errors::Unimplemented("Unsupported DataType: %s", str));
}

J
jianghaicheng 已提交
190
template <typename T>
A
Allen Guo 已提交
191
T GetAttrAllowNull(std::string attr, OpDesc* op_desc) {
J
jianghaicheng 已提交
192 193 194 195 196 197 198 199
  if (op_desc->HasAttr(attr)) {
    return BOOST_GET_CONST(T, op_desc->GetAttr(attr));
  } else {
    return {};
  }
}

template <typename T>
A
Allen Guo 已提交
200
nonstd::optional<T> GetOptAttrAllowNull(std::string attr, OpDesc* op_desc) {
J
jianghaicheng 已提交
201 202 203 204 205 206 207
  if (op_desc->HasAttr(attr)) {
    return BOOST_GET_CONST(T, op_desc->GetAttr(attr));
  } else {
    return {};
  }
}

A
Allen Guo 已提交
208 209 210 211 212 213 214 215 216 217
// Reads attribute `attr` of `op_desc` stored as TI and static_casts it to
// the signature type TO; yields a value-initialized TO when absent.
template <typename TI, typename TO>
TO GetCastSigAttrAllowNull(std::string attr, OpDesc* op_desc) {
  if (!op_desc->HasAttr(attr)) {
    return {};
  }
  return static_cast<TO>(BOOST_GET_CONST(TI, op_desc->GetAttr(attr)));
}

218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245
// Helper for adding namescope info.
//
// The constructor pushes the op's namescope (attribute sOpNamescope) onto the
// popart builder. The destructor pops it again, so popart ops created while
// an instance is alive are built inside that scope.
struct NameScopeHelper {
  NameScopeHelper(const OpDesc* op, popart::Builder* builder);

  // The destructor pops whatever the constructor pushed, so a copy would pop
  // the same scope twice. Forbid copying.
  NameScopeHelper(const NameScopeHelper&) = delete;
  NameScopeHelper& operator=(const NameScopeHelper&) = delete;

  ~NameScopeHelper() {
    if (pushed_) {
      builder_->popNameScope();
    }
  }

  bool pushed_ = false;
  popart::Builder* builder_;
};

NameScopeHelper::NameScopeHelper(const OpDesc* op, popart::Builder* builder)
    : builder_(builder) {
  auto op_namescope = BOOST_GET_CONST(std::string, op->GetAttr(sOpNamescope));
  // The first and last characters are stripped before pushing. Presumably the
  // namescope is stored wrapped in slashes, e.g. "/a/b/" -- confirm.
  // Strings shorter than two characters (including "" and the root "/") have
  // nothing inside the wrapper. Stripping both ends of a one-character string
  // would also call erase() on an empty string.
  if (op_namescope.size() < 2) {
    return;
  }
  builder->pushNameScope(op_namescope.substr(1, op_namescope.size() - 2));
  pushed_ = true;
}

}  // namespace

A
Allen Guo 已提交
246 247 248 249 250 251 252 253 254 255 256 257 258
// Caches lookup tables over graph `g`:
// - the topologically sorted op nodes,
// - every node keyed by id,
// - variable nodes keyed by name,
// - the ids of all variable nodes in ascending order.
GraphHelper::GraphHelper(const Graph* g) {
  graph = g;
  sorted_ops = framework::ir::TopologySortOperations(*g);
  for (auto* n : g->Nodes()) {
    nodes_id_map[n->id()] = n;
    if (!n->IsVar()) {
      continue;
    }
    vars_name_map[n->Name()] = n;
    sorted_vars_id.push_back(n->id());
  }
  // Ascending ids make the variable visiting order independent of the order
  // in which g->Nodes() yields nodes.
  std::sort(sorted_vars_id.begin(), sorted_vars_id.end());
}

A
Allen Guo 已提交
259 260 261 262 263
Compiler::Compiler() {
  // Build the op-type -> lowering callback table that LowerBody dispatches on.
  RegisterOpFunc();
}

Compiler::~Compiler() {
  // Release explicitly and in this order: the popart builder first, then the
  // compiler resources. NOTE(review): the reason for the ordering is not
  // visible here -- keep it unless confirmed unnecessary.
  builder_ = nullptr;
  resources_ = nullptr;
}

A
Allen Guo 已提交
266
void Compiler::Prepare(const Graph* graph) {
A
Allen Guo 已提交
267 268
  builder_ = popart::Builder::create();
  resources_ = std::make_unique<CompilerResources>();
A
Allen Guo 已提交
269
  graph_helper_ = std::make_unique<GraphHelper>(graph);
A
Allen Guo 已提交
270 271 272 273 274 275 276 277 278 279 280
  // Set the flag of set_amp_for_all_
  for (auto* node : graph_helper_->sorted_ops) {
    auto* op_desc = node->Op();
    auto op_type = op_desc->Type();
    if (op_type == "popart_matmul") {
      if (op_desc->HasAttr(sAvailMemAttribute)) {
        set_amp_for_all_ = false;
        return;
      }
    }
  }
A
Allen Guo 已提交
281
}
J
jianghaicheng 已提交
282 283 284 285

// Fills name_function_, which maps each supported popart op type (the
// stringified FuncName of an OP_DECL entry) to a callback that lowers one
// OpDesc of that type through builder_ and then calls PostLower.
//
// The entries come from expanding OP_DECL over the declarations in
// supported_ops_autogen.h and supported_ops_custom.h. The type and argument
// macros defined below are presumably referenced by those declarations --
// check the headers before renaming any macro or any local inside OP_DECL.
void Compiler::RegisterOpFunc() {
  VLOG(10) << "enter Compiler::RegisterOpFunc";
  // Attribute type aliases available to the op declarations.
#define INT_VEC std::vector<std::int64_t>
#define INT32_VEC std::vector<std::int32_t>
#define FLOAT_VEC std::vector<float>
#define FLOAT float
#define INT std::int64_t
#define INT32 std::int32_t
#define BOOL bool
#define STRING std::string
// NOTE(review): STRING_VEC is a vector of std::string *pointers*, unlike the
// other *_VEC aliases -- confirm this is intended.
#define STRING_VEC std::vector<std::string*>
#define NONE

// Each *ARG macro expands to a leading comma plus one extra argument, so a
// declaration's Args can be appended directly after `inputs` in OP_DECL.
// ARG: value-initialized value when the attribute is absent.
// OPT_ARG: empty optional when absent.
// SIG_ARG: reads TI and static_casts it to TO.
#define ARG(Type, Name) , GetAttrAllowNull<Type>(#Name, op_desc)
#define OPT_ARG(Type, Name) , GetOptAttrAllowNull<Type>(#Name, op_desc)
#define SIG_ARG(TI, TO, Name) , GetCastSigAttrAllowNull<TI, TO>(#Name, op_desc)
#define POPART_CONST_ARG(Name) , const PopartConstant& Name
#define HOST_SIDE_CONST_ARG(Name) , const HostSideConstant& Name
#define POPART_ATTRIB_VEC_ARG(Name)
#define BODY_ARG(Name) NONE

  // NOTE(review): the callbacks capture by reference ([&]) and are stored in
  // a member. Their bodies appear to use only `op_desc` and members of this
  // Compiler, so they must not outlive it.
  name_function_ = {
#define OP_DECL(FuncName, OnnxImpl, Args)                     \
  {#FuncName, [&](OpDesc* op_desc) {                          \
     auto op_type = op_desc->Type();                          \
     VLOG(10) << "build op:" << op_type << " args " << #Args; \
     auto inputs = GetOpInputs(op_desc);                      \
     auto debug_context = BuildDebugContext(op_desc);         \
     auto aiGraphcoreOpset = builder_->aiGraphcoreOpset1();   \
     auto aiOnnxOpset = builder_->aiOnnxOpset11();            \
     NameScopeHelper ns_helper(op_desc, builder_.get());      \
     auto output_ids = OnnxImpl(inputs Args, debug_context);  \
     PostLower(output_ids, op_desc);                          \
   }},  // NOLINT
#include "paddle/fluid/platform/device/ipu/supported_ops_autogen.h"
#include "paddle/fluid/platform/device/ipu/supported_ops_custom.h"
  };

  // Remove every helper macro so none leaks into the rest of this file.
#undef OP_DECL
#undef BODY_ARG
#undef POPART_ATTRIB_VEC_ARG
#undef HOST_SIDE_CONST_ARG
#undef POPART_CONST_ARG
#undef SIG_ARG
#undef OPT_ARG
#undef ARG
#undef NONE
#undef STRING_VEC
#undef STRING
#undef BOOL
#undef INT32
#undef INT
#undef FLOAT
#undef FLOAT_VEC
#undef INT32_VEC
#undef INT_VEC
}

A
Allen Guo 已提交
341
void Compiler::InitInputs(const std::vector<std::string>& feed_list) {
J
jianghaicheng 已提交
342
  for (const auto& feed_name : feed_list) {
A
Allen Guo 已提交
343 344 345
    auto* node = graph_helper_->vars_name_map[feed_name];
    auto* var_desc = node->Var();
    VLOG(10) << "feed_name= " << var_desc->Name();
346
    auto data_type = VarType2PopartDType(var_desc->GetDataType());
A
Allen Guo 已提交
347 348 349 350 351 352 353
    popart::TensorInfo input_info{data_type, var_desc->GetShape()};
    VLOG(10) << "popart input_info = " << input_info;
    popart::TensorId tensor_id =
        builder_->addInputTensor(input_info, feed_name);
    VLOG(10) << "popart input tensor id = " << tensor_id;
    resources_->inputs.push_back(tensor_id);
    resources_->tensors.emplace(var_desc->Name(), tensor_id);
J
jianghaicheng 已提交
354 355 356 357 358
  }
}

// Marks the popart tensor backing every name in `fetch_list` as a model
// output and records it in resources_->outputs. Throws NotFound when a fetch
// name has no lowered tensor.
void Compiler::InitOutputs(const std::vector<std::string>& fetch_list) {
  for (const auto& fetch_name : fetch_list) {
    auto found = resources_->tensors.find(fetch_name);
    PADDLE_ENFORCE_NE(
        found, resources_->tensors.end(),
        platform::errors::NotFound(
            "Output tensor %s is not found, please check the model.",
            fetch_name));
    const auto& popart_id = found->second;
    VLOG(10) << "fetch_name= " << fetch_name;
    VLOG(10) << "popart output tensor id = " << popart_id;
    builder_->addOutputTensor(popart_id);
    resources_->outputs.push_back(popart_id);
  }
}

A
Allen Guo 已提交
372
void Compiler::LowerConstants(const Scope* scope) {
A
Allen Guo 已提交
373 374
  auto& kid_scope = scope->NewScope();
  VLOG(10) << "enter Compiler::LowerConstants";
A
Allen Guo 已提交
375
  for (auto* node : graph_helper_->sorted_ops) {
A
Allen Guo 已提交
376 377 378 379 380 381
    auto* op_desc = node->Op();
    auto op_type = op_desc->Type();
    if (op_type == "popart_constant") {
      auto shape =
          BOOST_GET_CONST(std::vector<int64_t>, op_desc->GetAttr("dims"));
      auto dtype_ = BOOST_GET_CONST(int, op_desc->GetAttr("dtype"));
382 383 384
      auto dtype = PopartDType2VarType(
          OnnxDType2PopartType(static_cast<ONNXDataType>(dtype_)));
      auto tensor_name = GetOpOutputs(op_desc).front();
A
Allen Guo 已提交
385 386 387 388 389 390
      auto* var = kid_scope.Var(tensor_name);
      VLOG(10) << "lowering constant: " << tensor_name;
      auto* tensor = var->GetMutable<framework::LoDTensor>();
      ConstantOpAttrVisitor visitor(tensor, dtype);
      auto value = op_desc->GetAttr("value");
      boost::apply_visitor(visitor, value);
391
      auto ddim = phi::make_ddim(shape);
A
Allen Guo 已提交
392 393 394
      tensor->Resize(ddim);

      auto const_data = std::unique_ptr<popart::ConstVoidData>();
395
      popart::TensorInfo tensor_info(PhiDType2PopartDType(tensor->dtype()),
A
Allen Guo 已提交
396
                                     shape);
A
Allen Guo 已提交
397
      const_data.reset(new popart::ConstVoidData(tensor->data(), tensor_info));
A
Allen Guo 已提交
398
      NameScopeHelper ns_helper(op_desc, builder_.get());
A
Allen Guo 已提交
399
      popart::TensorId result = builder_->aiOnnxOpset11().constant(*const_data);
A
Allen Guo 已提交
400
      PostLower(result, op_desc);
A
Allen Guo 已提交
401 402
      resources_->tensors.emplace(tensor_name, result);
    }
J
jianghaicheng 已提交
403
  }
A
Allen Guo 已提交
404
  VLOG(10) << "leave Compiler::LowerConstants";
J
jianghaicheng 已提交
405 406
}

A
Allen Guo 已提交
407
void Compiler::LowerWeights(const Scope* scope) {
A
Allen Guo 已提交
408
  VLOG(10) << "enter Compiler::LowerWeights";
A
Allen Guo 已提交
409
  // At this step, the graph doesn't contains optimizer related states
A
Allen Guo 已提交
410 411
  for (auto id : graph_helper_->sorted_vars_id) {
    auto* node = graph_helper_->nodes_id_map[id];
A
Allen Guo 已提交
412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
    // Weights are var node and Persistable
    if (node->IsVar() && !node->IsCtrlVar() && node->Var() &&
        node->Var()->Persistable()) {
      // Weights are Parameter in training mode
      if (ipu_strategy_->is_training && !node->Var()->IsParameter()) {
        continue;
      }
      auto var_name = node->Var()->Name();
      // Some op has same input and output tensor, like batchnorm
      if (resources_->tensors.count(var_name) != 0) {
        VLOG(10) << "found existed one, skip lowering Weight: " << var_name;
        continue;
      }
      VLOG(10) << "lowering weight: " << var_name;
      auto var = scope->FindVar(var_name);
      PADDLE_ENFORCE_NOT_NULL(
          var, platform::errors::NotFound("Tensor %s is not found in the scope",
                                          var_name));
      auto tensor = var->Get<framework::LoDTensor>();
431
      auto dtype = PhiDType2PopartDType(tensor.dtype());
A
Allen Guo 已提交
432 433 434 435 436 437 438 439 440 441 442 443 444
      auto shape = std::vector<int64_t>();
      for (size_t i = 0; i < tensor.dims().size(); ++i) {
        shape.push_back(tensor.dims().at(i));
      }
      popart::TensorInfo tensor_info(dtype, shape);
      popart::ConstVoidData const_data{tensor.data(), tensor_info};
      if (!node->outputs.empty()) {
        auto op_node = node->outputs[0];
        NameScopeHelper ns_helper(op_node->Op(), builder_.get());
        popart::TensorId result =
            builder_->addInitializedInputTensor(const_data, var_name);
        resources_->tensors.emplace(var_name, result);
        resources_->weights.push_back(var_name);
J
jianghaicheng 已提交
445 446 447
      }
    }
  }
A
Allen Guo 已提交
448 449 450
  VLOG(10) << "leave Compiler::LowerWeights";
}

A
Allen Guo 已提交
451 452 453 454 455 456 457 458 459 460 461 462 463
void Compiler::LowerBody() {
  VLOG(10) << "enter Compiler::LowerBody";
  for (auto* node : graph_helper_->sorted_ops) {
    auto* op_desc = node->Op();
    auto op_type = op_desc->Type();
    VLOG(10) << "lowering op: " << op_type;

    if (op_type == "popart_constant") {
      // pass
    } else if (op_type == "popart_optimizer") {
      // pass
    } else if (op_type == "popart_checkpointoutput") {
      auto inputs = GetOpInputs(op_desc);
A
Allen Guo 已提交
464
      NameScopeHelper ns_helper(op_desc, builder_.get());
A
Allen Guo 已提交
465
      auto output_ids = builder_->checkpointOutput(inputs);
A
Allen Guo 已提交
466
      PostLower(output_ids, op_desc);
A
Allen Guo 已提交
467 468 469 470 471 472 473 474 475 476 477 478 479
    } else if (op_type == "popart_custom_op") {
      auto inputs = GetOpInputs(op_desc);
      auto outputs = GetOpOutputs(op_desc);
      auto debug_context = BuildDebugContext(op_desc);
      auto attributes = std::map<std::string, popart::any>{};
      for (auto& attr : op_desc->GetAttrMap()) {
        CustomOpAttrVisitor visitor(&attributes, attr.first);
        boost::apply_visitor(visitor, attr.second);
      }
      auto __op_type =
          BOOST_GET_CONST(std::string, op_desc->GetAttr("__op_type"));
      VLOG(10) << "Build graph from custom op: " << __op_type;
      auto it = custom_ops_.find(__op_type);
A
Allen Guo 已提交
480
      NameScopeHelper ns_helper(op_desc, builder_.get());
A
Allen Guo 已提交
481 482 483
      auto output_ids =
          builder_->customOp(it->second.popart_op, it->second.popart_op.version,
                             inputs, outputs.size(), attributes, debug_context);
A
Allen Guo 已提交
484
      PostLower(output_ids, op_desc);
A
Allen Guo 已提交
485 486 487 488 489 490
    } else if (op_type == "popart_printtensor") {
      auto inputs = GetOpInputs(op_desc);
      auto debug_context = BuildDebugContext(op_desc);
      auto print_gradient =
          BOOST_GET_CONST(int64_t, op_desc->GetAttr("print_gradient"));
      auto title = BOOST_GET_CONST(std::string, op_desc->GetAttr("title"));
A
Allen Guo 已提交
491
      NameScopeHelper ns_helper(op_desc, builder_.get());
A
Allen Guo 已提交
492 493
      auto output_ids = builder_->aiGraphcoreOpset1().printtensor(
          inputs, print_gradient, debug_context, title);
A
Allen Guo 已提交
494
      PostLower(output_ids, op_desc);
A
Allen Guo 已提交
495 496 497 498 499 500 501 502 503 504
    } else {
      auto itr = name_function_.find(op_type);
      if (itr != name_function_.end()) {
        itr->second(node->Op());
      } else {
        PADDLE_THROW(platform::errors::NotFound(
            "%s is not registered, please check for unsupported operators for "
            "running on IPU",
            op_type));
      }
A
Allen Guo 已提交
505
    }
A
Allen Guo 已提交
506 507 508
  }
  VLOG(10) << "leave Compiler::LowerBody";
}
A
Allen Guo 已提交
509

A
Allen Guo 已提交
510 511
void Compiler::LowerOptimizer(const Scope* scope) {
  for (auto* node : graph_helper_->sorted_ops) {
A
Allen Guo 已提交
512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532
    auto* op_desc = node->Op();
    auto op_type = op_desc->Type();
    if (op_type == "popart_optimizer") {
      auto raw_type =
          BOOST_GET_CONST(std::string, op_desc->GetAttr("raw_type"));
      resources_->optimizer_type = raw_type;
      auto loss_var =
          BOOST_GET_CONST(std::string, op_desc->GetAttr("loss_var"));
      resources_->loss_var = resources_->tensors[loss_var];
      resources_->with_lr_sched =
          BOOST_GET_CONST(bool, op_desc->GetAttr("with_lr_sched"));
      if (op_desc->HasAttr("lr_var")) {
        auto lr_var = BOOST_GET_CONST(std::string, op_desc->GetAttr("lr_var"));
        resources_->lr_var = lr_var;
        resources_->lr = GetSingleVarFromScope<float>(scope, lr_var);
      } else {
        // adadelta has no lr
        resources_->lr = 0.01f;
        resources_->with_lr_sched = false;
      }
      VLOG(10) << "Set initial lr: " << resources_->lr;
A
Allen Guo 已提交
533 534

      // Get the type of optimizer
A
Allen Guo 已提交
535
      auto type = BOOST_GET_CONST(std::string, op_desc->GetAttr("type"));
A
Allen Guo 已提交
536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557
      // Set weight decay by tensor names for Lamb
      auto weight_decay_vars = BOOST_GET_CONST(
          std::vector<std::string>, op_desc->GetAttr("weight_decay_vars"));
      auto weight_decay_values = BOOST_GET_CONST(
          std::vector<float>, op_desc->GetAttr("weight_decay_values"));
      // Get the maximum permissible value for gradient clipping
      std::vector<popart::ClipNormSettings> clip_norm_settings = {};
      if (op_desc->HasAttr("clip_norm")) {
        auto clip_norm = BOOST_GET_CONST(float, op_desc->GetAttr("clip_norm"));
        clip_norm_settings.push_back(
            popart::ClipNormSettings::clipAllWeights(clip_norm));
        VLOG(10) << "Set the global gradient clipping with the maximum "
                    "permissible value: "
                 << clip_norm;
      }

      // Values from ipu_strategy
      auto loss_scaling = ipu_strategy_->loss_scaling;
      auto accl1_type = DataTypeFromStr(ipu_strategy_->accl1_type);
      auto accl2_type = DataTypeFromStr(ipu_strategy_->accl2_type);
      auto accl3_type = DataTypeFromStr(ipu_strategy_->accl3_type);

A
Allen Guo 已提交
558 559 560 561 562 563 564
      if (type == "sgd") {
        auto weight_decay =
            BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
        auto momentum = BOOST_GET_CONST(float, op_desc->GetAttr("momentum"));
        resources_->optimizer_fn = [=](float lr) {
          return std::make_unique<popart::SGD>(
              popart::OptimizerValue(lr, false),
A
Allen Guo 已提交
565
              popart::OptimizerValue(weight_decay, false),
A
Allen Guo 已提交
566 567 568
              popart::OptimizerValue(momentum, true),
              popart::SGD::getUnsetDampening(),
              popart::SGD::getUnsetVelocityScaling(),
A
Allen Guo 已提交
569
              popart::OptimizerValue(loss_scaling, true), clip_norm_settings);
A
Allen Guo 已提交
570
        };
A
Allen Guo 已提交
571 572 573 574 575 576
        resources_->eval_optimizer = std::make_unique<popart::SGD>(
            popart::OptimizerValue(0.0, false),
            popart::OptimizerValue(0.0, false),
            popart::OptimizerValue(0.0, true), popart::SGD::getUnsetDampening(),
            popart::SGD::getUnsetVelocityScaling(),
            popart::OptimizerValue(loss_scaling, true), clip_norm_settings);
A
Allen Guo 已提交
577 578 579 580 581 582 583 584 585 586
      } else if (type == "adam") {
        auto weight_decay =
            BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
        auto beta1 = BOOST_GET_CONST(float, op_desc->GetAttr("beta1"));
        auto beta2 = BOOST_GET_CONST(float, op_desc->GetAttr("beta2"));
        auto eps = BOOST_GET_CONST(float, op_desc->GetAttr("eps"));
        auto mwn = ipu_strategy_->max_weight_norm;
        VLOG(10) << "set max_weight_norm: " << mwn;
        auto adam_mode_ =
            BOOST_GET_CONST(std::string, op_desc->GetAttr("adam_mode"));
A
Allen Guo 已提交
587 588 589
        auto adam_mode =
            AdamModeFromStr(adam_mode_, ipu_strategy_->use_no_bias_optimizer);
        auto weight_decay_mode_ = ipu_strategy_->weight_decay_mode;
A
Allen Guo 已提交
590
        auto scaled_optimizer_state_ = ipu_strategy_->scaled_optimizer_state;
A
Allen Guo 已提交
591 592 593 594
        if (weight_decay_mode_.empty()) {
          weight_decay_mode_ = BOOST_GET_CONST(
              std::string, op_desc->GetAttr("weight_decay_mode"));
        }
A
Allen Guo 已提交
595 596
        auto weight_decay_mode = WeightDecayModeFromStr(weight_decay_mode_);
        resources_->optimizer_fn = [=](float lr) {
A
Allen Guo 已提交
597 598 599 600 601 602 603 604 605 606 607 608
          if (adam_mode == popart::AdamMode::Lamb ||
              adam_mode == popart::AdamMode::LambNoBias) {
            const std::map<std::string, std::pair<float, bool>>
                optimizer_value = {{"defaultLearningRate", {lr, false}},
                                   {"defaultBeta1", {beta1, false}},
                                   {"defaultBeta2", {beta2, false}},
                                   {"defaultEps", {eps, true}},
                                   {"lossScaling", {loss_scaling, true}},
                                   {"defaultMaxWeightNorm", {mwn, true}}};
            auto optimizer_instance = std::make_unique<popart::Adam>(
                optimizer_value, adam_mode, weight_decay_mode,
                popart::DataType::UNDEFINED, accl1_type, accl2_type,
A
Allen Guo 已提交
609
                clip_norm_settings, scaled_optimizer_state_);
A
Allen Guo 已提交
610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627
            for (int i = 0; i < weight_decay_vars.size(); i++) {
              optimizer_instance->insertSpecific(
                  weight_decay_vars[i],
                  {{"weightDecay", {weight_decay_values[i], false}}});
              VLOG(10) << "Set Tensor " << weight_decay_vars[i]
                       << " weight decay as " << weight_decay_values[i];
            }
            return optimizer_instance;
          } else {
            return std::make_unique<popart::Adam>(
                popart::OptimizerValue(lr, false),
                popart::OptimizerValue(weight_decay, false),
                popart::OptimizerValue(beta1, false),
                popart::OptimizerValue(beta2, false),
                popart::OptimizerValue(eps, true),
                popart::OptimizerValue(loss_scaling, true),
                popart::OptimizerValue(mwn, true), adam_mode, weight_decay_mode,
                popart::DataType::UNDEFINED, accl1_type, accl2_type,
A
Allen Guo 已提交
628
                clip_norm_settings, scaled_optimizer_state_);
A
Allen Guo 已提交
629 630
          }
        };
A
Allen Guo 已提交
631
        if (adam_mode == popart::AdamMode::Lamb) {
A
Allen Guo 已提交
632 633 634 635 636 637 638 639 640 641
          const std::map<std::string, std::pair<float, bool>> optimizer_value =
              {{"defaultLearningRate", {0.0, false}},
               {"defaultBeta1", {beta1, false}},
               {"defaultBeta2", {beta2, false}},
               {"defaultEps", {eps, true}},
               {"lossScaling", {loss_scaling, true}},
               {"defaultMaxWeightNorm", {mwn, true}}};
          auto eval_optimizer = std::make_unique<popart::Adam>(
              optimizer_value, adam_mode, weight_decay_mode,
              popart::DataType::UNDEFINED, popart::DataType::FLOAT,
A
Allen Guo 已提交
642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661
              popart::DataType::FLOAT, clip_norm_settings,
              scaled_optimizer_state_);
          for (int i = 0; i < weight_decay_vars.size(); i++) {
            eval_optimizer->insertSpecific(weight_decay_vars[i],
                                           {{"weightDecay", {0.0, false}}});
          }
          resources_->eval_optimizer = std::move(eval_optimizer);
        } else if (adam_mode == popart::AdamMode::LambNoBias) {
          const std::map<std::string, std::pair<float, bool>> optimizer_value =
              {{"defaultLearningRate", {0.0, false}},
               {"defaultBeta1", {1.0, false}},
               {"defaultBeta2", {1.0, false}},
               {"defaultEps", {eps, true}},
               {"lossScaling", {loss_scaling, true}},
               {"defaultMaxWeightNorm", {mwn, true}}};
          auto eval_optimizer = std::make_unique<popart::Adam>(
              optimizer_value, adam_mode, weight_decay_mode,
              popart::DataType::UNDEFINED, popart::DataType::FLOAT,
              popart::DataType::FLOAT, clip_norm_settings,
              scaled_optimizer_state_);
A
Allen Guo 已提交
662 663 664 665 666 667 668 669 670 671 672
          for (int i = 0; i < weight_decay_vars.size(); i++) {
            eval_optimizer->insertSpecific(weight_decay_vars[i],
                                           {{"weightDecay", {0.0, false}}});
          }
          resources_->eval_optimizer = std::move(eval_optimizer);
        } else {
          resources_->eval_optimizer = std::make_unique<popart::Adam>(
              popart::OptimizerValue(0.0, false),
              popart::OptimizerValue(0.0, false),
              popart::OptimizerValue(beta1, false),
              popart::OptimizerValue(beta2, false),
A
Allen Guo 已提交
673 674 675 676
              popart::OptimizerValue(eps, true),
              popart::OptimizerValue(loss_scaling, true),
              popart::OptimizerValue(mwn, true), adam_mode, weight_decay_mode,
              popart::DataType::UNDEFINED, popart::DataType::FLOAT,
A
Allen Guo 已提交
677 678
              popart::DataType::FLOAT, clip_norm_settings,
              scaled_optimizer_state_);
A
Allen Guo 已提交
679
        }
A
Allen Guo 已提交
680 681 682 683 684 685 686 687 688
      } else if (type == "adaptive") {
        auto alpha = BOOST_GET_CONST(float, op_desc->GetAttr("alpha"));
        auto momentum = BOOST_GET_CONST(float, op_desc->GetAttr("momentum"));
        auto eps = BOOST_GET_CONST(float, op_desc->GetAttr("eps"));
        auto weight_decay =
            BOOST_GET_CONST(float, op_desc->GetAttr("weight_decay"));
        auto adaptive_mode_ =
            BOOST_GET_CONST(std::string, op_desc->GetAttr("adaptive_mode"));
        auto adaptive_mode = AdaptiveModeFromStr(adaptive_mode_);
A
Allen Guo 已提交
689 690 691 692 693
        auto weight_decay_mode_ = ipu_strategy_->weight_decay_mode;
        if (weight_decay_mode_.empty()) {
          weight_decay_mode_ = BOOST_GET_CONST(
              std::string, op_desc->GetAttr("weight_decay_mode"));
        }
A
Allen Guo 已提交
694 695 696 697
        auto weight_decay_mode = WeightDecayModeFromStr(weight_decay_mode_);
        resources_->optimizer_fn = [=](float lr) {
          return std::make_unique<popart::Adaptive>(
              popart::OptimizerValue(lr, false),
A
Allen Guo 已提交
698
              popart::OptimizerValue(weight_decay, false),
A
Allen Guo 已提交
699 700 701 702
              popart::OptimizerValue(alpha, true),
              popart::OptimizerValue(momentum, true),
              popart::OptimizerValue(eps, true),
              popart::OptimizerValue(loss_scaling, true), adaptive_mode,
A
Allen Guo 已提交
703 704
              weight_decay_mode, popart::DataType::UNDEFINED, accl1_type,
              accl2_type, accl3_type);
A
Allen Guo 已提交
705
        };
A
Allen Guo 已提交
706 707 708 709 710 711 712 713 714 715
        resources_->eval_optimizer = std::make_unique<popart::Adaptive>(
            popart::OptimizerValue(0.0, false),
            popart::OptimizerValue(0.0, false),
            popart::OptimizerValue(alpha, true),
            popart::OptimizerValue(momentum, true),
            popart::OptimizerValue(eps, true),
            popart::OptimizerValue(loss_scaling, true), adaptive_mode,
            weight_decay_mode, popart::DataType::UNDEFINED,
            popart::DataType::FLOAT, popart::DataType::FLOAT,
            popart::DataType::UNDEFINED);
A
Allen Guo 已提交
716 717 718 719 720 721
      } else {
        PADDLE_THROW(platform::errors::Unimplemented(
            "optimizer %s is not implemented", type));
      }
    }
  }
J
jianghaicheng 已提交
722 723
}

A
Allen Guo 已提交
724 725 726 727 728
void Compiler::PostLower(const std::vector<std::string>& tensor_ids,
                         const OpDesc* op_desc) {
  // Set pipline
  // Due to the limitation of popart, if an op has multiple outputs,
  // pipline settings needs to be set at the same time
J
jianghaicheng 已提交
729 730 731 732 733 734 735 736 737 738
  auto tensor_ids_set =
      std::set<std::string>(tensor_ids.begin(), tensor_ids.end());
  if (op_desc->HasAttr(sIpuIndexAttr)) {
    auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
    builder_->virtualGraph(tensor_ids_set, ipu_index);
    VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index
             << " for op: " << op_desc->Type();
    if (op_desc->HasAttr(sIpuStageAttr)) {
      auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
      builder_->pipelineStage(tensor_ids_set, ipu_stage);
A
Allen Guo 已提交
739
      VLOG(10) << "set " << sIpuStageAttr << " = " << ipu_stage
J
jianghaicheng 已提交
740 741 742
               << " for op: " << op_desc->Type();
    }
  }
743 744 745 746 747 748 749 750
  // Record output tensors
  auto pd_outs = GetOpOutputs(op_desc);
  PADDLE_ENFORCE_EQ(
      pd_outs.size(), tensor_ids.size(),
      platform::errors::Fatal("paddle and popart op have different outputs"));
  for (int i = 0; i < tensor_ids.size(); ++i) {
    resources_->tensors.emplace(pd_outs[i], tensor_ids[i]);
  }
A
Allen Guo 已提交
751 752 753
  for (auto& tensor_id : tensor_ids) {
    PostLower(tensor_id, op_desc, true);
  }
J
jianghaicheng 已提交
754 755
}

A
Allen Guo 已提交
756
void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc) {
757 758 759 760 761 762
  // Record output tensor
  auto pd_outs = GetOpOutputs(op_desc);
  PADDLE_ENFORCE_EQ(
      pd_outs.size(), 1,
      platform::errors::Fatal("paddle and popart op have different outputs"));
  resources_->tensors.emplace(pd_outs[0], tensor_id);
A
Allen Guo 已提交
763 764
  PostLower(tensor_id, op_desc, false);
}
J
jianghaicheng 已提交
765

A
Allen Guo 已提交
766 767 768 769
void Compiler::PostLower(const std::string& tensor_id, const OpDesc* op_desc,
                         bool skip_pipline) {
  // Set pipline
  if (!skip_pipline && op_desc->HasAttr(sIpuIndexAttr)) {
J
jianghaicheng 已提交
770 771 772 773 774 775 776
    auto ipu_index = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuIndexAttr));
    builder_->virtualGraph(tensor_id, ipu_index);
    VLOG(10) << "set " << sIpuIndexAttr << " = " << ipu_index
             << " for op: " << op_desc->Type();
    if (op_desc->HasAttr(sIpuStageAttr)) {
      auto ipu_stage = BOOST_GET_CONST(int, op_desc->GetAttr(sIpuStageAttr));
      builder_->pipelineStage(tensor_id, ipu_stage);
A
Allen Guo 已提交
777
      VLOG(10) << "set " << sIpuStageAttr << " = " << ipu_stage
J
jianghaicheng 已提交
778 779 780
               << " for op: " << op_desc->Type();
    }
  }
A
Allen Guo 已提交
781
  // Set amp
A
Allen Guo 已提交
782
  if (op_desc->Type() == "popart_matmul") {
A
Allen Guo 已提交
783 784 785 786
    if (set_amp_for_all_) {
      auto amp = ipu_strategy_->available_memory_proportion;
      if (amp < 0.0f || amp > 1.0) {
        PADDLE_THROW(platform::errors::InvalidArgument(
A
Allen Guo 已提交
787 788
            "AvailableMemoryProportion %f is invalid, which should be in "
            "range [0.0, 1.0]",
A
Allen Guo 已提交
789 790 791 792 793 794 795 796 797 798
            amp));
      }
      if (amp > 0.0f) {
        builder_->setAvailableMemoryProportion(tensor_id, amp);
      }
    } else {
      if (op_desc->HasAttr(sAvailMemAttribute)) {
        auto amp = BOOST_GET_CONST(float, op_desc->GetAttr(sAvailMemAttribute));
        if (amp < 0.0f || amp > 1.0) {
          PADDLE_THROW(platform::errors::InvalidArgument(
A
Allen Guo 已提交
799 800
              "AvailableMemoryProportion %f is invalid, which should be in "
              "range [0.0, 1.0]",
A
Allen Guo 已提交
801 802 803 804 805 806 807 808
              amp));
        }
        if (amp > 0.0f) {
          builder_->setAvailableMemoryProportion(tensor_id, amp);
          VLOG(10) << "set available_memory_proportion for tensor: "
                   << tensor_id << " as " << amp;
        }
      }
A
Allen Guo 已提交
809
    }
A
Allen Guo 已提交
810
    // Set serialize matmul
A
Allen Guo 已提交
811 812 813 814 815 816 817 818
    if (op_desc->HasAttr(sMatmulSerializeFactor)) {
      auto factor =
          BOOST_GET_CONST(int, op_desc->GetAttr(sMatmulSerializeFactor));
      std::string mode = "output_channels";
      if (op_desc->HasAttr(sMatmulSerializeMode)) {
        mode = BOOST_GET_CONST(std::string,
                               op_desc->GetAttr(sMatmulSerializeMode));
      }
A
Allen Guo 已提交
819
      builder_->setSerializeMatMul({tensor_id}, mode, factor, true);
A
Allen Guo 已提交
820 821 822
    }
  }
}
J
jianghaicheng 已提交
823

A
Allen Guo 已提交
824 825 826 827 828 829 830 831
// Registers the given custom op identifiers, keyed by their paddle op type.
// As with std::map/unordered_map emplace, an already-registered paddle op
// keeps its existing entry.
void Compiler::SetCustomOps(
    const std::vector<IpuCustomOpIdentifier>& custom_ops) {
  // Iterate by const reference: the previous by-value loop copied every
  // identifier once before copying it again into custom_ops_.
  for (const auto& x : custom_ops) {
    custom_ops_.emplace(x.paddle_op, x);
  }
}

std::string Compiler::GetFP16ModelProto() {
J
jianghaicheng 已提交
832 833
  popart::GraphTransformer graph_transformer(builder_->getModelProto());
  graph_transformer.convertFloatsToHalfs();
A
Allen Guo 已提交
834
  return graph_transformer.getModelProto();
J
jianghaicheng 已提交
835 836
}

837
// Returns the serialized model proto currently held by the popart builder.
std::string Compiler::GetModelProto() {
  std::string model_proto = builder_->getModelProto();
  return model_proto;
}
J
jianghaicheng 已提交
838 839 840 841 842 843 844 845 846 847 848 849

void Compiler::SaveModelProto(const std::string& path) {
  builder_->saveModelProto(path);
}

void Compiler::SaveModelProtoNoCheck(const std::string& path) {
  auto proto = GetModelProto();
  std::ofstream onnxfile(path, std::ios_base::binary);
  onnxfile.write(proto.data(), proto.size());
  onnxfile.close();
}

A
Allen Guo 已提交
850
std::vector<std::string> Compiler::GetOpInputs(const OpDesc* op) {
J
jianghaicheng 已提交
851 852 853
  auto ins = op->Input("__inputs__");
  std::vector<std::string> inputs;
  for (const auto& in : ins) {
A
Allen Guo 已提交
854 855
    if (resources_->tensors.find(in) != resources_->tensors.end()) {
      inputs.push_back(resources_->tensors[in]);
J
jianghaicheng 已提交
856 857 858 859 860 861 862
    } else {
      inputs.push_back(in);
    }
  }
  return inputs;
}

A
Allen Guo 已提交
863
const std::vector<std::string>& Compiler::GetOpOutputs(const OpDesc* op) {
J
jianghaicheng 已提交
864 865 866
  return op->Output("__outputs__");
}

A
Allen Guo 已提交
867
// Builds a popart debug context named after the op's sOpIdentifyIdAttr
// attribute. The id is also logged at VLOG level 10.
popart::DebugContext Compiler::BuildDebugContext(const OpDesc* op) {
  const std::string op_identify_id =
      BOOST_GET_CONST(std::string, op->GetAttr(sOpIdentifyIdAttr));
  VLOG(10) << "op_identify_id of op: " << op->Type() << " is "
           << op_identify_id;
  return popart::DebugContext(op_identify_id);
}

}  // namespace ipu
}  // namespace platform
}  // namespace paddle