/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_desc.h"

#include <algorithm>
#include <functional>
#include <mutex>  // NOLINT
#include <string>
#include <unordered_map>
#include <utility>

#include "glog/logging.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_call_stack.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type_inference.h"

namespace paddle {
namespace framework {

class OpDesc;
class BlockDesc;

// Compile-time implementation of InferShapeContext: shapes and types are
// read from and written to the VarDescs of a BlockDesc rather than to
// runtime tensors.
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

  bool HasInputs(const std::string &name) const override;

  bool HasOutputs(const std::string &name) const override;

  AttrReader Attrs() const override;

  const std::vector<std::string> &Inputs(
      const std::string &name) const override;

  const std::vector<std::string> &Outputs(
      const std::string &name) const override;

  void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
                size_t j = 0) override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    const std::string &input_n = Inputs(in)[i];
    const std::string &output_n = Outputs(out)[j];

    PADDLE_ENFORCE(input_n != framework::kEmptyVarName, "The %s[%d] is @EMPTY@",
                   in, i);
    PADDLE_ENFORCE(output_n != framework::kEmptyVarName,
                   "The %s[%d] is @EMPTY@", out, j);

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

    PADDLE_ENFORCE(in_var->GetType() == out_var->GetType(),
                   "The type of %s and %s is not the same.", input_n, output_n);

    SetDim(output_n, GetDim(input_n));
  }

  void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                size_t j = 0) const override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
                   "The %s[%d] is @EMPTY@", in, i);
    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
                   "The %s[%d] is @EMPTY@", out, j);
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
      VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
      return;
    }
    out_var->SetLoDLevel(in_var->GetLoDLevel());
  }

  void DecreaseLoDLevel(const std::string &in, const std::string &out,
                        size_t i = 0, size_t j = 0) const override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    PADDLE_ENFORCE(Inputs(in)[i] != framework::kEmptyVarName,
                   "The %s[%d] is @EMPTY@", in, i);
    PADDLE_ENFORCE(Outputs(out)[j] != framework::kEmptyVarName,
                   "The %s[%d] is @EMPTY@", out, j);
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
    PADDLE_ENFORCE(out_var->GetType() == proto::VarType::LOD_TENSOR_ARRAY ||
                       out_var->GetType() == proto::VarType::LOD_TENSOR,
                   "The output %s should be LoDTensorArray or LoDTensor.",
                   out_var->Name());
    PADDLE_ENFORCE(in_var->GetType() == proto::VarType::LOD_TENSOR,
                   "The input %s should be LoDTensor.", in_var->Name());
    if (in_var->GetLoDLevel() > 0) {
      out_var->SetLoDLevel(in_var->GetLoDLevel() - 1);
    }
  }

  std::vector<InferShapeVarPtr> GetInputVarPtrs(
      const std::string &name) override {
    const std::vector<std::string> arg_names = Inputs(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(arg_names.size());
    std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
      const std::string &name) override {
    const std::vector<std::string> arg_names = Outputs(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(arg_names.size());
    std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
                      "Input(%s) should hold one element, but now it holds %d",
                      name, arg_names.size());
    return this->GetDim(arg_names[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    return GetDims(arg_names);
  }

  bool IsRuntime() const override;

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Inputs(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Outputs(name));
  }

  void SetOutputDim(const std::string &name, const DDim &dim) override {
    auto &arg_names = Outputs(name);
    PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
                      "Output(%s) should hold one element, but now it holds %d",
                      name, arg_names.size());
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
    auto &names = Outputs(name);
    SetDims(names, dims);
  }

 protected:
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
        names.begin(), names.end(), retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this,
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;

  DDim GetDim(const std::string &name) const {
    auto var = block_.FindVarRecursive(name);
    PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
    DDim res;
    try {
      auto shape = var->GetShape();
      res = shape.empty() ? make_ddim({0UL}) : make_ddim(shape);
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      std::rethrow_exception(std::current_exception());
    }
    return res;
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
        names.begin(), names.end(), std::back_inserter(ret),
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }

  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
    PADDLE_ENFORCE_EQ(length, dims.size());
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }

  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;

  const OpDesc &op_;
  const BlockDesc &block_;
};

OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs) {
  desc_.set_type(type);
  inputs_ = inputs;
  outputs_ = outputs;
  attrs_ = attrs;
  need_update_ = true;
  block_ = nullptr;
}
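
// A hedged usage sketch of the constructor above (the op type and argument
// names are illustrative, not taken from this file):
//   OpDesc op("mul", /*inputs=*/{{"X", {"x"}}, {"Y", {"y"}}},
//             /*outputs=*/{{"Out", {"out"}}}, /*attrs=*/{});
//   op.CheckAttrs();  // validate attrs against the registered checker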

OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  CopyFrom(other);
  block_ = block;
  need_update_ = true;
}

void OpDesc::CopyFrom(const OpDesc &op_desc) {
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
  need_update_ = true;
}

OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // restore inputs_
  int input_size = desc_.inputs_size();
  for (int i = 0; i < input_size; ++i) {
    const proto::OpDesc::Var &var = desc_.inputs(i);
    std::vector<std::string> &args = inputs_[var.parameter()];
    int argu_size = var.arguments_size();
    args.reserve(argu_size);
    for (int j = 0; j < argu_size; ++j) {
      args.push_back(var.arguments(j));
    }
  }
  // restore outputs_
  int output_size = desc_.outputs_size();
  for (int i = 0; i < output_size; ++i) {
    const proto::OpDesc::Var &var = desc_.outputs(i);
    std::vector<std::string> &args = outputs_[var.parameter()];
    int argu_size = var.arguments_size();
    args.reserve(argu_size);
    for (int j = 0; j < argu_size; ++j) {
      args.push_back(var.arguments(j));
    }
  }
  // restore attrs_
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    std::string attr_name = attr.name();
    // The sub_block referred to by a BLOCK attr hasn't been added to the
    // ProgramDesc yet, so we skip setting BLOCK/BLOCKS attrs here.
    if (attr.type() != proto::AttrType::BLOCK &&
        attr.type() != proto::AttrType::BLOCKS) {
      attrs_[attr_name] = GetAttrValue(attr);
    }
  }
  this->block_ = block;
}

proto::OpDesc *OpDesc::Proto() {
  // Sync pending changes in inputs_/outputs_/attrs_ into desc_ first.
  Flush();
  return &desc_;
}

const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
  auto it = inputs_.find(name);
  PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
                 Type());
  return it->second;
}

std::vector<std::string> OpDesc::InputArgumentNames() const {
  std::vector<std::string> retv;
  for (auto &ipt : this->inputs_) {
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
  need_update_ = true;
  inputs_[param_name] = args;
}

const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
  auto it = outputs_.find(name);
  PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
                 name, Type());
  return it->second;
}

std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> retv;
  for (auto &ipt : this->outputs_) {
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
  need_update_ = true;
  this->outputs_[param_name] = args;
}

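// Checks whether the op's registered OpProto declares an attribute named
// `name`. Unlike GetAttr, this consults the proto definition, not attrs_.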
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info = OpInfoMap::Instance();
  if (op_info.Has(desc_.type())) {
    auto op_info_ptr = op_info.Get(desc_.type());
    if (op_info_ptr.HasOpProtoAndChecker()) {
      const proto::OpProto &proto = op_info_ptr.Proto();
      for (int i = 0; i != proto.attrs_size(); ++i) {
        const proto::OpProto::Attr &attr = proto.attrs(i);
        if (attr.name() == name) {
          return true;
        }
      }
    }
  }
  return false;
}

proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  // Attribute is a boost::variant whose first alternative is boost::blank,
  // so which() - 1 maps the variant index onto the proto::AttrType enum.
  return static_cast<proto::AttrType>(it->second.which() - 1);
}

std::vector<std::string> OpDesc::AttrNames() const {
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
    retv.push_back(attr.first);
  }
  return retv;
}

void OpDesc::RemoveAttr(const std::string &name) {
  attrs_.erase(name);
  need_update_ = true;
}

void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
  // NOTICE(minqiyang): pybind11 passes an empty Python list to C++ as an
  // empty std::vector<int>, so we have to correct the attribute's type here
  // when that happens.
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
  if (attr_type == proto::AttrType::INTS &&
      boost::get<std::vector<int>>(v).size() == 0u) {
    // Find current attr via attr name and set the correct attribute value
    const proto::OpProto::Attr &attr = GetProtoAttr(name);
    switch (attr.type()) {
      case proto::AttrType::BOOLEANS: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
        this->attrs_[name] = std::vector<bool>();
        break;
      }
      case proto::AttrType::INTS: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
        this->attrs_[name] = std::vector<int>();
        break;
      }
      case proto::AttrType::LONGS: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to LONGS";
        this->attrs_[name] = std::vector<int64_t>();
        break;
      }
      case proto::AttrType::FLOATS: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
        this->attrs_[name] = std::vector<float>();
        break;
      }
      case proto::AttrType::STRINGS: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
        this->attrs_[name] = std::vector<std::string>();
        break;
      }
      case proto::AttrType::BLOCKS: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
        return;
      }
      default:
        PADDLE_THROW("Wrong attr type %d", attr.type());
    }
    need_update_ = true;
    return;
  }

  this->attrs_[name] = v;
  need_update_ = true;
}
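
// A hedged illustration of the coercion above (the attr name and op are
// hypothetical): if Python passes [] for an attribute registered as
// BOOLEANS, pybind11 delivers std::vector<int>{}, and SetAttr stores
// std::vector<bool>{} instead, based on the registered proto type:
//   op.SetAttr("stop_flags", std::vector<int>{});  // stored as BOOLEANS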

void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
  need_update_ = true;
}

void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = blocks;
  need_update_ = true;
}

void OpDesc::SetAttrMap(
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

Attribute OpDesc::GetAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  return it->second;
}

const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  for (int i = 0; i != proto.attrs_size(); ++i) {
    const proto::OpProto::Attr &attr = proto.attrs(i);
    if (attr.name() == name) {
      return attr;
    }
  }

  PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
}

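// Unlike GetAttr, returns a default-constructed Attribute (boost::blank)
// instead of throwing when the attribute is absent.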
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  if (it != attrs_.end()) {
    return it->second;
  } else {
    return Attribute();
  }
}

std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  auto blocks = boost::get<std::vector<BlockDesc *>>(it->second);

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
  return boost::get<BlockDesc *>(it->second)->ID();
}

const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
  return attrs_;
}

void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
  need_update_ = true;
}

void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
  for (auto &output : outputs_) {
    std::replace(output.second.begin(), output.second.end(), old_name,
                 new_name);
  }

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
    auto &op_vars = boost::get<std::vector<std::string>>(it->second);
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

  need_update_ = true;
}

void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
    auto &op_vars = boost::get<std::vector<std::string>>(it->second);
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

  need_update_ = true;
}

// Serializes each alternative of the Attribute variant into the matching
// field of a proto::OpDesc::Attr message; applied by OpDesc::Flush().
struct SetAttrDescVisitor : public boost::static_visitor<void> {
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  mutable proto::OpDesc::Attr *attr_;
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
  void operator()(const std::string &v) const { attr_->set_s(v); }

  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }

  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
      blocks_idx.push_back(blk->ID());
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }

  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

  void operator()(int64_t v) const { attr_->set_l(v); }

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

  void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};

void OpDesc::Flush() {
  if (need_update_) {
    this->desc_.mutable_inputs()->Clear();
    for (auto &ipt : inputs_) {
      auto *input = desc_.add_inputs();
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

    this->desc_.mutable_outputs()->Clear();
    for (auto &opt : outputs_) {
      auto *output = desc_.add_outputs();
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

    this->desc_.mutable_attrs()->Clear();
    for (auto &attr : attrs_) {
      auto *attr_desc = desc_.add_attrs();
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
          static_cast<proto::AttrType>(attr.second.which() - 1));
      SetAttrDescVisitor visitor(attr_desc);
      boost::apply_visitor(visitor, attr.second);
    }

    need_update_ = false;
  }
}

static std::once_flag init_infer_shape_funcs;

/**
 * NOTE(paddle-dev): Very tricky code here. Maybe we should find a
 * better, gentler way to register the compile-time InferShape method.
 *
 * Normally, we register a class derived from InferShapeBase, so that the
 * `infer_shape_` field of OpInfo is set when the op is registered.
 *
 * However, there is another way to set `infer_shape_` inside OpInfo:
 * overriding the InferShape method of OperatorWithKernel. After
 * InitInferShapeFuncs below has run, `infer_shape_` is set to the InferShape
 * method of OperatorWithKernel. That is to say, we borrow the run-time
 * InferShape method of OperatorWithKernel as the compile-time InferShape
 * method.
 *
 * At compile time, however, we may not know the inputs, outputs, and attrs
 * of a run-time OperatorWithKernel. So the code below creates a fake
 * OperatorWithKernel object, which is why the info_ field of OperatorBase
 * is null.
 */
static void InitInferShapeFuncs() {
  std::call_once(init_infer_shape_funcs, [] {
    auto &map = OpInfoMap::Instance();
    auto &info_map = *map.mutable_map();

    for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
      auto op_type = kern_pair.first;
      auto it = info_map.find(op_type);
      PADDLE_ENFORCE(it != info_map.end(), "%s has not been registered",
                     op_type);
      auto &op_info = it->second;
      if (op_info.infer_shape_) {  // infer_shape has been registered.
        continue;
      }

      auto op = dynamic_cast<OperatorWithKernel *>(op_info.Creator()(
          "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));

      PADDLE_ENFORCE_NOT_NULL(
          op, "InferShapeBase is not registered to Operator %s", op_type);

      op_info.infer_shape_ = [op](InferShapeContext *ctx) {
        op->InferShape(ctx);
      };
    }
  });
}
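
// For contrast with the trick above, a minimal sketch of the "normal"
// registration path (class and argument names are illustrative only):
//   class MyOpInferShape : public framework::InferShapeBase {
//    public:
//     void operator()(framework::InferShapeContext *ctx) const override {
//       ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
//     }
//   };
// Registering an op with such a class sets OpInfo::infer_shape_ directly,
// so InitInferShapeFuncs would skip that op.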

void OpDesc::CheckAttrs() {
  PADDLE_ENFORCE(!Type().empty(),
                 "CheckAttrs() cannot be called before the op type is set.");
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
  VLOG(10) << "begin to check attribute of " << Type();
  checker->Check(&attrs_);
}

void OpDesc::InferShape(const BlockDesc &block) const {
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
    InitInferShapeFuncs();
    auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
    PADDLE_ENFORCE(static_cast<bool>(infer_shape),
                   "%s's infer_shape has not been registered", this->Type());
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
      std::copy(inames.begin(), inames.end(),
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
      std::copy(onames.begin(), onames.end(),
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
  } catch (platform::EnforceNotMet &exception) {
    framework::InsertCallStackInfo(Type(), attrs_, &exception);
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
}

void OpDesc::InferVarType(BlockDesc *block) const {
  // There are a few places where a variable's type can be set.
  // When a VarDesc is created, it is set to LOD_TENSOR by default.
  // When an output variable is created, it is likewise set to LOD_TENSOR by
  // default. We limit this to be the only place where an operator defines
  // its customized var type inference, so we don't do any "default" setting
  // here.
  auto &info = OpInfoMap::Instance().Get(this->Type());
  if (info.infer_var_type_) {
    InferVarTypeContext context(this, block);
    info.infer_var_type_(&context);
  }
}
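
// Hedged example: an operator's registered infer_var_type_ hook might, say,
// mark its output as SELECTED_ROWS when every input is SELECTED_ROWS, using
// the InferVarTypeContext constructed above; anything not customized keeps
// the LOD_TENSOR default described in the comment.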

CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op), block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
  const std::vector<std::string> &input_names = op_.Input(name);
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Input(%s) should have only one value, "
                    "but it has %d now",
                    name, length);
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Output(%s) should have only one value, "
                    "but it has %d now",
                    name, length);
  return block_.HasVarRecursive(output_names[0]);
}

bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
  const std::vector<std::string> &input_names = op_.Input(name);
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
  for (auto &output : output_names) {
    if (!block_.HasVarRecursive(output)) return false;
  }
  return true;
}

AttrReader CompileTimeInferShapeContext::Attrs() const {
  return AttrReader(op_.GetAttrMap());
}

const std::vector<std::string> &CompileTimeInferShapeContext::Inputs(
    const std::string &name) const {
  return op_.Input(name);
}

const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
    const std::string &name) const {
  return op_.Output(name);
}

std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
      res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s));
    }
  } catch (...) {
    VLOG(5) << "GetRepeatedDims of variable " << name << " error.";
    std::rethrow_exception(std::current_exception());
  }
  return res;
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
}

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize<>);
  var->SetShapes(dim_vec);
}

bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  return block_.FindVarRecursive(name)->GetType();
}

}  // namespace framework
}  // namespace paddle