op_desc.cc 44.5 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
F
fengjiayi 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/op_desc.h"
16

17
#include <string>
18

19
#include "glog/logging.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/framework/block_desc.h"
21
#include "paddle/fluid/framework/op_call_stack.h"
Y
yuyang18 已提交
22
#include "paddle/fluid/framework/op_proto_maker.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
M
minqiyang 已提交
25
#include "paddle/fluid/framework/var_type_inference.h"
26
#include "paddle/fluid/operators/ops_extra_info.h"
27
#include "paddle/phi/common/complex.h"
R
Ruibiao Chen 已提交
28
#include "paddle/utils/blank.h"
Y
Yu Yang 已提交
29

F
fengjiayi 已提交
30 31 32
namespace paddle {
namespace framework {

33 34
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
35
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
36 37 38 39 40

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

41 42
  bool HasAttr(const std::string &name) const override;

43 44
  bool HasInputs(const std::string &name) const override;

45 46
  bool HasOutputs(const std::string &name,
                  bool allow_null = false) const override;
47 48 49

  AttrReader Attrs() const override;

H
hong 已提交
50
  std::vector<std::string> Inputs(const std::string &name) const override;
51

H
hong 已提交
52
  std::vector<std::string> Outputs(const std::string &name) const override;
53

54 55 56
  std::string GetInputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
57 58
    PADDLE_ENFORCE_LT(idx,
                      op_proto->inputs().size(),
59 60 61
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
62 63 64
                          op_.Type(),
                          idx,
                          op_proto->inputs().size()));
65 66 67 68 69 70 71
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
72 73
        idx,
        op_proto->outputs().size(),
74 75 76
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
77 78 79
            op_.Type(),
            idx,
            op_proto->outputs().size()));
80 81 82
    return op_proto->outputs()[idx].name();
  }

83 84 85
  void ShareDim(const std::string &in,
                const std::string &out,
                size_t i = 0,
86
                size_t j = 0) override {
87 88
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
89 90 91
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
92 93 94 95
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
96 97 98
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
99 100
                          Outputs(out).size(),
                          j));
101

H
hong 已提交
102 103
    std::string input_n = Inputs(in)[i];
    std::string output_n = Outputs(out)[j];
104

105 106
    PADDLE_ENFORCE_NE(input_n,
                      framework::kEmptyVarName,
107 108
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
109 110
    PADDLE_ENFORCE_NE(output_n,
                      framework::kEmptyVarName,
111 112
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
113 114 115 116

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

117
    PADDLE_ENFORCE_EQ(
118 119
        in_var->GetType(),
        out_var->GetType(),
120 121 122
        platform::errors::InvalidArgument(
            "The type of input %s and output %s do not match. The input type "
            "is %s, output type is %s.",
123 124 125
            input_n,
            output_n,
            DataTypeToString(in_var->GetType()),
126
            DataTypeToString(out_var->GetType())));
127 128 129 130

    SetDim(output_n, GetDim(input_n));
  }

H
hong 已提交
131 132 133 134 135 136
  void ShareAllLoD(const std::string &in,
                   const std::string &out) const override {
    auto &in_var_names = op_.Input(in);
    auto &out_var_names = op_.Output(out);

    PADDLE_ENFORCE_EQ(
137 138
        in_var_names.size(),
        out_var_names.size(),
H
hong 已提交
139
        platform::errors::PreconditionNotMet(
T
tianshuo78520a 已提交
140
            "Op [%s]:  Input var number should be equal with output var number",
H
hong 已提交
141 142 143 144 145 146 147 148 149 150 151
            op_.Type()));

    for (size_t i = 0; i < in_var_names.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }

      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
152 153
        VLOG(3) << "input " << in
                << " is not phi::DenseTensor or LoDTensorArray.";
H
hong 已提交
154 155 156 157 158 159
        return;
      }
      out_var->SetLoDLevel(in_var->GetLoDLevel());
    }
  }

160 161 162
  void ShareLoD(const std::string &in,
                const std::string &out,
                size_t i = 0,
Q
Qiao Longfei 已提交
163
                size_t j = 0) const override {
164 165
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
166 167 168
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
169 170 171 172
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
173 174 175
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
176 177 178 179
                          Outputs(out).size(),
                          j));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
180 181
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
182 183
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
184 185
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
Q
Qiao Longfei 已提交
186 187
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
C
chengduo 已提交
188 189
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
190 191
      VLOG(3) << "input " << in
              << " is not phi::DenseTensor or LoDTensorArray.";
X
fix  
Xin Pan 已提交
192 193
      return;
    }
194
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
195
  }
D
dzhwinter 已提交
196

197
  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
198 199
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
200 201 202
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, input "
                          "variable %s of operator %s only has %d elements.",
203 204 205 206 207
                          in,
                          op_.Type(),
                          Inputs(in).size()));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
208 209
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] of operator %s is empty.",
210 211 212
                          in,
                          i,
                          op_.Type()));
C
chengduo 已提交
213
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
214
    PADDLE_ENFORCE_NOT_NULL(
215 216 217 218 219 220
        in_var,
        platform::errors::NotFound(
            "The input variable %s[%d] of operator %s is not found.",
            in,
            i,
            op_.Type()));
221
    return in_var->GetLoDLevel();
C
chengduo 已提交
222 223
  }

224 225
  void SetLoDLevel(const std::string &out,
                   int32_t lod_level,
226
                   size_t j = 0) const override {
227 228
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
229 230 231
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, output "
                          "variable %s of operator %s only has %d elements.",
232 233 234 235 236
                          out,
                          op_.Type(),
                          Outputs(out).size()));
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
237 238
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] of operator %s is empty.",
239 240 241
                          out,
                          j,
                          op_.Type()));
242
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
243
    PADDLE_ENFORCE_NOT_NULL(
244 245 246 247 248 249
        out_var,
        platform::errors::NotFound(
            "The output variable %s[%d] of operator %s is not found.",
            out,
            j,
            op_.Type()));
250 251 252
    if (lod_level >= 0) {
      out_var->SetLoDLevel(lod_level);
    }
253 254
  }

C
Chen Weihang 已提交
255
  // Looks up every variable name bound to input slot `name` through
  // block_.FindVarRecursive and returns the lookup results in the same order.
  // Lookup results are stored as-is; no null check is performed here.
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
  GetInputVarPtrs(const std::string &name) const override {
    const std::vector<std::string> slot_vars = Inputs(name);
    paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> ptrs;
    ptrs.reserve(slot_vars.size());
    for (const std::string &var_name : slot_vars) {
      ptrs.push_back(block_.FindVarRecursive(var_name));
    }
    return ptrs;
  }

C
Chen Weihang 已提交
269
  // Looks up every variable name bound to output slot `name` through
  // block_.FindVarRecursive and returns the lookup results in the same order.
  // Lookup results are stored as-is; no null check is performed here.
  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
  GetOutputVarPtrs(const std::string &name) const override {
    const std::vector<std::string> slot_vars = Outputs(name);
    paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> ptrs;
    ptrs.reserve(slot_vars.size());
    for (const std::string &var_name : slot_vars) {
      ptrs.push_back(block_.FindVarRecursive(var_name));
    }
    return ptrs;
  }

X
Xin Pan 已提交
283 284
  // Returns the shape of the single variable bound to input slot `name`.
  // Raises InvalidArgument unless the slot holds exactly one variable.
  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> slot_vars = Inputs(name);
    PADDLE_ENFORCE_EQ(slot_vars.size(),
                      1UL,
                      platform::errors::InvalidArgument(
                          "The input(%s) should hold only one element, but now "
                          "it holds %d elements.",
                          name,
                          slot_vars.size()));
    return this->GetDim(slot_vars.front());
  }

  // Returns the shapes of all variables bound to input slot `name`, in slot
  // order.
  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    return GetDims(Inputs(name));
  }

300 301
  bool IsRuntime() const override;

302 303
  bool IsRunMKLDNNKernel() const override;

304 305 306 307
  // Returns the variable type of the first variable bound to input slot
  // `name`. The bounds-checked at(0) throws when the slot is empty.
  proto::VarType::Type GetInputVarType(const std::string &name) const override {
    const std::vector<std::string> slot_vars = Inputs(name);
    return GetVarType(slot_vars.at(0));
  }

X
Xin Pan 已提交
308 309 310 311 312 313 314 315 316 317
  // Returns the variable type of every variable bound to input slot `name`,
  // in slot order.
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    const std::vector<std::string> slot_vars = Inputs(name);
    return GetVarTypes(slot_vars);
  }

  // Returns the variable type of every variable bound to output slot `name`,
  // in slot order.
  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    const std::vector<std::string> slot_vars = Outputs(name);
    return GetVarTypes(slot_vars);
  }

X
Xin Pan 已提交
318
  void SetOutputDim(const std::string &name, const DDim &dim) override {
H
hong 已提交
319
    auto arg_names = Outputs(name);
320 321
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
322 323 324
                      platform::errors::InvalidArgument(
                          "The iutput(%s) should hold only one element, but "
                          "now it holds %d elements.",
325 326
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
327 328 329 330 331
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
H
hong 已提交
332
    auto names = Outputs(name);
X
Xin Pan 已提交
333 334 335
    SetDims(names, dims);
  }

336 337 338 339 340 341 342 343
  // Returns whatever phi::OpUtilsMap has registered as the argument mapping
  // function for this operator's type.
  const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
    const auto &utils_map = phi::OpUtilsMap::Instance();
    return utils_map.GetArgumentMappingFn(op_.Type());
  }

  // Returns the address of the entry that phi::DefaultKernelSignatureMap
  // holds for this operator's type.
  const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
    const auto &signature =
        phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
    return &signature;
  }

344
 protected:
X
Xin Pan 已提交
345 346 347 348 349
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
350 351 352 353 354
        names.begin(),
        names.end(),
        retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType),
                  this,
X
Xin Pan 已提交
355 356 357 358 359
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;
Q
Qiao Longfei 已提交
360

X
Xin Pan 已提交
361 362
  // Returns the shape of variable `name` as a DDim.
  // Raises NotFound when the variable cannot be resolved. Any exception thrown
  // while reading or converting the shape is logged at VLOG(5) and re-thrown
  // unchanged.
  DDim GetDim(const std::string &name) const {
    auto *var_desc = block_.FindVarRecursive(name);
    PADDLE_ENFORCE_NOT_NULL(
        var_desc,
        platform::errors::NotFound("Variable %s is not found.", name));
    try {
      const auto shape = var_desc->GetShape();
      return phi::make_ddim(shape);
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      throw;
    }
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
380 381 382
        names.begin(),
        names.end(),
        std::back_inserter(ret),
X
Xin Pan 已提交
383 384 385
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }
386

X
Xin Pan 已提交
387 388 389 390 391
  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
392 393
    PADDLE_ENFORCE_EQ(length,
                      dims.size(),
394 395 396
                      platform::errors::InvalidArgument(
                          "The input variables number(%d) and input dimensions "
                          "number(%d) do not match.",
397 398
                          length,
                          dims.size()));
X
Xin Pan 已提交
399 400 401 402 403 404 405
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }
406

F
fengjiayi 已提交
407 408 409 410
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
411

Y
Yu Yang 已提交
412 413
  const OpDesc &op_;
  const BlockDesc &block_;
414 415
};

416 417 418 419 420 421 422
// Inserts every entry of the extra-attribute map registered for `op_type`
// into `runtime_attrs`. Keys already present in `runtime_attrs` keep their
// current value (insert does not overwrite).
static void InitRuntimeAttributeMapByOpExtraInfo(const std::string &op_type,
                                                 AttributeMap *runtime_attrs) {
  const auto &extra_attrs =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
  for (const auto &entry : extra_attrs) {
    runtime_attrs->insert(entry);
  }
}

423 424 425 426
// Builds an OpDesc from its parts. The new op is not attached to any block
// and is marked as needing an update; runtime_attrs_ is filled from the extra
// attribute info registered for `type`.
OpDesc::OpDesc(const std::string &type,
               const VariableNameMap &inputs,
               const VariableNameMap &outputs,
               const AttributeMap &attrs) {
  block_ = nullptr;
  need_update_ = true;

  desc_.set_type(type);
  inputs_ = inputs;
  outputs_ = outputs;
  attrs_ = attrs;

  InitRuntimeAttributeMapByOpExtraInfo(type, &runtime_attrs_);
}

436 437 438 439 440 441
// Copy constructor: copies the contents of `other` via CopyFrom and shares
// its owning block pointer.
OpDesc::OpDesc(const OpDesc &other) {
  this->CopyFrom(other);
  this->block_ = other.block_;
  this->need_update_ = true;
}

X
Xin Pan 已提交
442 443 444 445
// Copies `other` into a new OpDesc attached to `block`, then passes every
// copied attribute through UpdateVarAttr.
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  this->CopyFrom(other);
  this->block_ = block;
  this->need_update_ = true;
  for (auto &name_and_attr : attrs_) {
    UpdateVarAttr(name_and_attr.first, name_and_attr.second);
  }
}

451
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
452 453 454 455
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
456 457
  runtime_attrs_ = op_desc.runtime_attrs_;
  // The record of original_id_ is only for auto parallel.
458
  original_id_ = op_desc.original_id_;
459
  if (op_desc.dist_attr_) {
460
    dist_attr_ = std::make_unique<OperatorDistAttr>(*op_desc.dist_attr_);
461
  }
F
fengjiayi 已提交
462 463 464
  need_update_ = true;
}

F
fengjiayi 已提交
465
// Rebuilds an OpDesc from its protobuf form and attaches it to `block`.
// Inputs and outputs are restored into inputs_/outputs_. Attributes of type
// BLOCK/BLOCKS/VAR/VARS are skipped; every other attribute goes to
// runtime_attrs_ when its name is already present there, otherwise to attrs_.
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // restore inputs_
  for (const proto::OpDesc::Var &slot : desc_.inputs()) {
    std::vector<std::string> &slot_args = inputs_[slot.parameter()];
    slot_args.reserve(slot.arguments_size());
    for (const std::string &arg : slot.arguments()) {
      slot_args.push_back(arg);
    }
  }
  // restore outputs_
  for (const proto::OpDesc::Var &slot : desc_.outputs()) {
    std::vector<std::string> &slot_args = outputs_[slot.parameter()];
    slot_args.reserve(slot.arguments_size());
    for (const std::string &arg : slot.arguments()) {
      slot_args.push_back(arg);
    }
  }
  // restore attrs_
  InitRuntimeAttributeMapByOpExtraInfo(desc.type(), &runtime_attrs_);
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    // The sub_block referred to by the BLOCK attr hasn't been added
    // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS/VAR/VARS attr
    // here.
    const auto attr_type = attr.type();
    const bool is_deferred = attr_type == proto::AttrType::BLOCK ||
                             attr_type == proto::AttrType::BLOCKS ||
                             attr_type == proto::AttrType::VAR ||
                             attr_type == proto::AttrType::VARS;
    if (is_deferred) {
      continue;
    }

    const std::string &attr_name = attr.name();
    auto runtime_it = runtime_attrs_.find(attr_name);
    if (runtime_it != runtime_attrs_.end()) {
      runtime_it->second = GetAttrValue(attr);
    } else {
      attrs_[attr_name] = GetAttrValue(attr);
    }
  }
  this->block_ = block;
}

512 513 514 515 516 517 518 519 520
// The assignment operator is implemented explicitly because the unique_ptr
// data member has no implicit copy assignment. Copies the contents of `other`
// via CopyFrom and shares its owning block pointer.
OpDesc &OpDesc::operator=(const OpDesc &other) {
  this->CopyFrom(other);
  this->block_ = other.block_;
  this->need_update_ = true;
  return *this;
}

Y
Yu Yang 已提交
521
// Calls Flush() and then returns a pointer to the underlying protobuf
// message.
proto::OpDesc *OpDesc::Proto() {
  this->Flush();
  return &this->desc_;
}

Y
Yu Yang 已提交
526
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
F
fengjiayi 已提交
527
  auto it = inputs_.find(name);
528
  PADDLE_ENFORCE_NE(
529 530 531 532
      it,
      inputs_.end(),
      platform::errors::NotFound(
          "Input %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
533 534 535
  return it->second;
}

536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558
// Returns the variable names bound to `name`. When with_attr_var is true and
// an attribute called `name` holds variable(s), that attribute's variable
// names are returned; otherwise the regular input slot is looked up.
std::vector<std::string> OpDesc::Input(const std::string &name,
                                       bool with_attr_var) const {
  // Attribute with VarDesc type will consider as Input
  if (with_attr_var) {
    const auto attr_it = attrs_.find(name);
    const bool holds_var =
        attr_it != attrs_.end() && HasAttrVar(attr_it->second);
    if (holds_var) {
      return AttrVarNames(attr_it->second);
    }
  }
  return this->Input(name);
}

// Returns a copy of the input map. When with_attr_var is true, every entry
// produced by FilterAttrVar(attrs_) is added under its attribute name,
// replacing an input slot of the same name if one exists.
VariableNameMap OpDesc::Inputs(bool with_attr_var) const {
  VariableNameMap merged = inputs_;
  if (with_attr_var) {
    for (auto &var_attr : FilterAttrVar(attrs_)) {
      merged[var_attr.first] = AttrVarNames(var_attr.second);
    }
  }
  return merged;
}

std::vector<std::string> OpDesc::InputArgumentNames(bool with_attr_var) const {
F
Update  
fengjiayi 已提交
559
  std::vector<std::string> retv;
560
  for (auto &ipt : this->Inputs(with_attr_var)) {
F
Update  
fengjiayi 已提交
561 562 563 564 565
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
566 567
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
568 569 570 571
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
572
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
F
fengjiayi 已提交
573
  auto it = outputs_.find(name);
574
  PADDLE_ENFORCE_NE(
575 576 577 578
      it,
      outputs_.end(),
      platform::errors::NotFound(
          "Output %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
579 580 581
  return it->second;
}

582 583 584 585
// Returns true when the operator has an output slot called `name`.
bool OpDesc::HasOutput(const std::string &name) const {
  return outputs_.count(name) != 0;
}

586 587 588 589 590
bool OpDesc::HasInput(const std::string &name, bool with_attr_var) const {
  if (with_attr_var) {
    auto it = attrs_.find(name);
    if (it != attrs_.end() && HasAttrVar(it->second)) return true;
  }
591 592 593
  return inputs_.find(name) != inputs_.end();
}

Y
Yu Yang 已提交
594
// Concatenates the variable names of every output slot into a single list.
std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> all_names;
  for (auto &slot : this->outputs_) {
    const auto &slot_names = slot.second;
    all_names.insert(all_names.end(), slot_names.begin(), slot_names.end());
  }
  return all_names;
}

Y
Yu Yang 已提交
602 603
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
604 605 606 607
  need_update_ = true;
  this->outputs_[param_name] = args;
}

608 609 610 611 612
void OpDesc::RemoveOutput(const std::string &name) {
  outputs_.erase(name);
  need_update_ = true;
}

613 614 615 616 617
void OpDesc::RemoveInput(const std::string &name) {
  inputs_.erase(name);
  need_update_ = true;
}

618 619 620 621 622 623 624 625 626 627
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info = OpInfoMap::Instance();
  if (op_info.Has(desc_.type())) {
    auto op_info_ptr = op_info.Get(desc_.type());
    if (op_info_ptr.HasOpProtoAndChecker()) {
      const proto::OpProto &proto = op_info_ptr.Proto();
      for (int i = 0; i != proto.attrs_size(); ++i) {
        const proto::OpProto::Attr &attr = proto.attrs(i);
        if (attr.name() == name) {
          return true;
L
luotao1 已提交
628 629
        }
      }
L
luotao1 已提交
630 631 632 633 634
    }
  }
  return false;
}

635 636 637 638
// Returns the attribute type of attribute `name`, computed as the index of
// the alternative held by the Attribute variant minus one. Errors raised by
// GetAttr propagate.
proto::AttrType OpDesc::GetAttrType(const std::string &name,
                                    bool with_attr_var) const {
  const Attribute value = this->GetAttr(name, with_attr_var);
  const auto type_code = value.index() - 1;
  return static_cast<proto::AttrType>(type_code);
}

641
std::vector<std::string> OpDesc::AttrNames(bool with_attr_var) const {
F
fengjiayi 已提交
642 643 644
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
645
    if (!with_attr_var && HasAttrVar(attr.second)) continue;
F
fengjiayi 已提交
646 647 648 649 650
    retv.push_back(attr.first);
  }
  return retv;
}

651 652
// Returns true when `name` is present in attrs_ or, failing that, in
// runtime_attrs_. When with_attr_var is false, an attribute that holds
// variable(s) is reported as absent.
bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
  auto found = attrs_.find(name);
  if (found == attrs_.end()) {
    found = runtime_attrs_.find(name);
    if (found == runtime_attrs_.end()) {
      return false;
    }
  }
  return with_attr_var || !HasAttrVar(found->second);
}

666 667
void OpDesc::RemoveAttr(const std::string &name) {
  attrs_.erase(name);
668
  runtime_attrs_.erase(name);
669 670 671
  need_update_ = true;
}

Y
Yu Yang 已提交
672
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
673 674
  AttributeMap *attrs_ptr = &(this->attrs_);

675 676
  bool is_runtime_attr = false;

677 678 679 680
  const auto &extra_attr_map =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type());
  auto extra_attr_iter = extra_attr_map.find(name);
  if (extra_attr_iter != extra_attr_map.end()) {
681
    is_runtime_attr = true;
682
    attrs_ptr = &(this->runtime_attrs_);
683 684 685 686 687
    // When an attribute is found in both attrs and runtime_attrs, it must
    // be a runtime attribute, so it's value in attrs should be removed.
    if (this->attrs_.find(name) != this->attrs_.end()) {
      this->attrs_.erase(name);
    }
688
  }
M
minqiyang 已提交
689 690 691
  // NOTICE(minqiyang): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
R
Ruibiao Chen 已提交
692
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
M
minqiyang 已提交
693
  if (attr_type == proto::AttrType::INTS &&
694
      PADDLE_GET_CONST(std::vector<int>, v).empty()) {
M
minqiyang 已提交
695
    // Find current attr via attr name and set the correct attribute value
696 697 698 699 700 701
    if (is_runtime_attr) {
      attr_type =
          static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1);
    } else if (HasProtoAttr(name)) {
      attr_type = GetProtoAttr(name).type();
    }
702
    switch (attr_type) {
M
minqiyang 已提交
703
      case proto::AttrType::BOOLEANS: {
M
minqiyang 已提交
704 705
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
706
        attrs_ptr->operator[](name) = std::vector<bool>();
M
minqiyang 已提交
707 708 709
        break;
      }
      case proto::AttrType::INTS: {
M
minqiyang 已提交
710 711
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
712
        attrs_ptr->operator[](name) = std::vector<int>();
M
minqiyang 已提交
713 714
        break;
      }
715
      case proto::AttrType::LONGS: {
M
minqiyang 已提交
716 717
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from LONGS to LONGS";
718
        attrs_ptr->operator[](name) = std::vector<int64_t>();
719 720
        break;
      }
M
minqiyang 已提交
721
      case proto::AttrType::FLOATS: {
M
minqiyang 已提交
722 723
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
724
        attrs_ptr->operator[](name) = std::vector<float>();
M
minqiyang 已提交
725 726
        break;
      }
727 728 729 730 731 732
      case proto::AttrType::FLOAT64S: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOAT64S";
        this->attrs_[name] = std::vector<double>();
        break;
      }
M
minqiyang 已提交
733
      case proto::AttrType::STRINGS: {
M
minqiyang 已提交
734 735
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
736
        attrs_ptr->operator[](name) = std::vector<std::string>();
M
minqiyang 已提交
737 738 739
        break;
      }
      case proto::AttrType::BLOCKS: {
M
minqiyang 已提交
740 741
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
742
        attrs_ptr->operator[](name) = std::vector<BlockDesc *>();
M
minqiyang 已提交
743 744
        return;
      }
M
minqiyang 已提交
745
      default:
746
        PADDLE_THROW(platform::errors::Unimplemented(
747
            "Unsupported attribute type (code %d).", attr_type));
M
minqiyang 已提交
748
    }
M
minqiyang 已提交
749 750
    need_update_ = true;
    return;
M
minqiyang 已提交
751 752
  }

753
  // In order to set bool attr properly
754 755 756 757 758 759 760 761 762 763 764 765 766 767
  if (attr_type == proto::AttrType::INT) {
    if (HasProtoAttr(name) &&
        GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
      attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
      need_update_ = true;
      return;
    }
    if (extra_attr_iter != extra_attr_map.end() &&
        static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1) ==
            proto::AttrType::BOOLEAN) {
      attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
      need_update_ = true;
      return;
    }
768 769
  }

770
  attrs_ptr->operator[](name) = v;
771
  VLOG(10) << "op_type: " << Type() << ", attr name: " << name
L
Leo Chen 已提交
772
           << " , type index: " << attrs_ptr->operator[](name).index();
F
fengjiayi 已提交
773 774 775
  need_update_ = true;
}

776 777 778 779 780 781 782 783 784 785
// Stores the variable pointer `var` as attribute `name` in attrs_ and marks
// this op as needing an update.
void OpDesc::SetVarAttr(const std::string &name, VarDesc *var) {
  auto &slot = this->attrs_[name];
  slot = var;
  need_update_ = true;
}

// Stores the list of variable pointers `vars` as attribute `name` in attrs_
// and marks this op as needing an update.
void OpDesc::SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars) {
  auto &slot = this->attrs_[name];
  slot = vars;
  need_update_ = true;
}

A
Abhinav Arora 已提交
786 787
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
F
fengjiayi 已提交
788
  need_update_ = true;
F
fengjiayi 已提交
789 790
}

791 792 793 794 795 796
// Stores a list of block pointers under attribute `name` and flags the op
// so the next Flush() rewrites the proto.
// `blocks` is taken by value, so it is moved into the map rather than copied
// a second time.
void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = std::move(blocks);
  need_update_ = true;
}

Y
Yu Yang 已提交
797
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
798 799 800 801 802
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

803 804 805 806 807 808
void OpDesc::SetRuntimeAttrMap(
    const std::unordered_map<std::string, Attribute> &attr_map) {
  runtime_attrs_ = attr_map;
  need_update_ = true;
}

809
// Returns a copy of attribute `name`.
//
// The regular attribute map is searched first, then the runtime attribute
// map; NotFound is raised when neither contains `name`. When
// `with_attr_var` is false, an attribute for which HasAttrVar() is true is
// rejected with NotFound as well.
Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
  auto it = attrs_.find(name);
  if (it == attrs_.end()) {
    it = runtime_attrs_.find(name);
    PADDLE_ENFORCE_NE(
        it,
        runtime_attrs_.end(),
        platform::errors::NotFound("Attribute %s is not found.", name));
  }
  if (!with_attr_var) {
    PADDLE_ENFORCE_EQ(
        HasAttrVar(it->second),
        false,
        platform::errors::NotFound(
            "Attribute %s with constant value is not found, but found it with "
            "Variable(s) type, which maybe not supported in some scenarios "
            "currently, such as TensorRT et.al",
            name));
  }
  return it->second;
}

M
minqiyang 已提交
831 832 833
// Returns the attribute declaration called `name` from the OpProto that is
// registered for this op's type. Raises NotFound when the proto declares no
// attribute with that name.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  for (const proto::OpProto::Attr &attr : proto.attrs()) {
    if (attr.name() == name) {
      return attr;
    }
  }

  PADDLE_THROW(platform::errors::NotFound(
      "Attribute %s is not found in proto %s.", name, proto.type()));
}

Y
yuyang18 已提交
845
// Returns a copy of attribute `name` from the regular attribute map, or a
// default-constructed Attribute when the map has no such entry. The runtime
// attribute map is not consulted.
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  const auto it = attrs_.find(name);
  if (it == attrs_.end()) {
    return Attribute();
  }
  return it->second;
}

G
gongweibao 已提交
854 855
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
856
  PADDLE_ENFORCE_NE(
857 858
      it,
      attrs_.end(),
859 860
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
R
Ruibiao Chen 已提交
861
  auto blocks = PADDLE_GET_CONST(std::vector<BlockDesc *>, it->second);
G
gongweibao 已提交
862 863 864 865 866 867 868 869 870 871

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
F
fengjiayi 已提交
872
  auto it = attrs_.find(name);
873
  PADDLE_ENFORCE_NE(
874 875
      it,
      attrs_.end(),
876 877
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
R
Ruibiao Chen 已提交
878
  return PADDLE_GET_CONST(BlockDesc *, it->second)->ID();
F
fengjiayi 已提交
879 880
}

Y
Yu Yang 已提交
881
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
882 883 884
  return attrs_;
}

885 886
// Read-only access to the runtime attribute map.
const AttributeMap &OpDesc::GetRuntimeAttrMap() const {
  return this->runtime_attrs_;
}

Y
Yu Yang 已提交
887
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
Y
Yancey1989 已提交
888 889
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
F
fengjiayi 已提交
890 891 892
  need_update_ = true;
}

Y
Yu Yang 已提交
893 894
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
895
  for (auto &output : outputs_) {
896 897
    std::replace(
        output.second.begin(), output.second.end(), old_name, new_name);
Y
Yang Yang(Tony) 已提交
898
  }
Y
yuyang18 已提交
899 900 901

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
R
Ruibiao Chen 已提交
902
    auto &op_vars = PADDLE_GET(std::vector<std::string>, it->second);
Y
yuyang18 已提交
903 904 905
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

906 907 908 909
  if (dist_attr_) {
    dist_attr_->rename_output(old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
910 911 912
  need_update_ = true;
}

Y
Yu Yang 已提交
913 914
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
915 916 917
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
Y
Yancey1989 已提交
918 919 920

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
R
Ruibiao Chen 已提交
921
    auto &op_vars = PADDLE_GET(std::vector<std::string>, it->second);
Y
Yancey1989 已提交
922 923 924
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

925 926 927 928
  if (dist_attr_) {
    dist_attr_->rename_input(old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
929 930 931
  need_update_ = true;
}

932
struct SetAttrDescVisitor {
933 934
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  mutable proto::OpDesc::Attr *attr_;
Y
Yu Yang 已提交
935 936
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
937
  void operator()(double v) const { attr_->set_float64(v); }
Y
Yu Yang 已提交
938
  void operator()(const std::string &v) const { attr_->set_s(v); }
939 940
  void operator()(const paddle::experimental::Scalar &v) const {
    auto *s = new proto::Scalar;
941
    *s = MakeScalarProto(v);
942 943
    attr_->set_allocated_scalar(s);
  }
Q
QI JUN 已提交
944 945 946 947 948 949 950

  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }
Y
Yu Yang 已提交
951 952 953 954 955 956 957 958 959 960 961 962 963

  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
964 965 966 967 968 969 970 971 972 973 974 975 976

  void operator()(const std::vector<VarDesc *> &v) const {
    std::vector<std::string> var_names;
    for (auto var : v) {
      var_names.emplace_back(var->Name());
    }
    VectorToRepeated(var_names, attr_->mutable_vars_name());
  }

  void operator()(const VarDesc *desc) const {
    attr_->set_var_name(desc->Name());
  }

977 978 979
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
T
tangwei12 已提交
980
      blocks_idx.push_back(blk->ID());
981 982 983
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }
T
tangwei12 已提交
984 985 986

  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

987
  void operator()(int64_t v) const { attr_->set_l(v); }
T
tangwei12 已提交
988 989 990 991 992

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

993 994 995 996
  void operator()(const std::vector<double> &v) const {
    VectorToRepeated(v, attr_->mutable_float64s());
  }

997 998 999 1000
  void operator()(const std::vector<paddle::experimental::Scalar> &v) const {
    std::vector<proto::Scalar> scalars;
    scalars.reserve(v.size());
    for (const auto &item : v) {
1001
      scalars.emplace_back(MakeScalarProto(item));
1002 1003 1004 1005
    }
    VectorToRepeated(scalars, attr_->mutable_scalars());
  }

R
Ruibiao Chen 已提交
1006
  void operator()(paddle::blank) const {
1007 1008
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method of SetAttrDescVisitor object for "
C
chalsliu 已提交
1009
        "`boost::blank` type."));
1010
  }
Y
Yu Yang 已提交
1011 1012
};

Y
Yu Yang 已提交
1013
void OpDesc::Flush() {
1014
  VLOG(8) << "Flush "
L
Leo Chen 已提交
1015
          << " " << Type() << " " << need_update_;
F
fengjiayi 已提交
1016
  if (need_update_) {
1017
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
1018
    for (auto &ipt : inputs_) {
1019
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
1020 1021 1022 1023
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

1024
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
1025
    for (auto &opt : outputs_) {
1026
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
1027 1028 1029 1030
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

1031
    this->desc_.mutable_attrs()->Clear();
1032 1033 1034 1035 1036 1037 1038 1039 1040
    auto set_attr_desc = [this](const std::string &attr_name,
                                const Attribute &attr) -> void {
      auto *attr_desc = desc_.add_attrs();
      attr_desc->set_name(attr_name);
      attr_desc->set_type(static_cast<proto::AttrType>(attr.index() - 1));
      SetAttrDescVisitor visitor(attr_desc);
      paddle::visit(visitor, attr);
    };

L
Leo Chen 已提交
1041 1042
    std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
                                                                attrs_.end()};
1043 1044 1045 1046

    std::vector<std::pair<std::string, Attribute>> sorted_runtime_attrs{
        runtime_attrs_.begin(), runtime_attrs_.end()};

L
Leo Chen 已提交
1047 1048 1049 1050 1051
    std::sort(
        sorted_attrs.begin(),
        sorted_attrs.end(),
        [](std::pair<std::string, Attribute> a,
           std::pair<std::string, Attribute> b) { return a.first < b.first; });
1052 1053 1054 1055 1056
    std::sort(
        sorted_runtime_attrs.begin(),
        sorted_runtime_attrs.end(),
        [](std::pair<std::string, Attribute> a,
           std::pair<std::string, Attribute> b) { return a.first < b.first; });
1057

Z
zyfncg 已提交
1058
    for (auto &attr : sorted_runtime_attrs) {
1059 1060
      set_attr_desc(attr.first, attr.second);
    }
Z
zyfncg 已提交
1061
    for (auto &attr : sorted_attrs) {
1062
      set_attr_desc(attr.first, attr.second);
F
fengjiayi 已提交
1063 1064 1065 1066 1067
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
1068

Y
Yu Yang 已提交
1069
void OpDesc::CheckAttrs() {
1070 1071
  PADDLE_ENFORCE_EQ(Type().empty(),
                    false,
1072 1073
                    platform::errors::PreconditionNotMet(
                        "CheckAttrs() can not be called before type is set."));
Y
Yu Yang 已提交
1074 1075 1076 1077 1078 1079
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
1080
  VLOG(10) << "begin to check attribute of " << Type();
T
tangwei12 已提交
1081
  checker->Check(&attrs_);
1082 1083 1084 1085 1086 1087 1088
  const auto &extra_attr_checkers =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(Type());
  if (!extra_attr_checkers.empty()) {
    for (const auto &extra_checker : extra_attr_checkers) {
      extra_checker(&runtime_attrs_, false);
    }
  }
F
fengjiayi 已提交
1089 1090
}

H
hong 已提交
1091
void OpDesc::InferShape(const BlockDesc &block) {
1092 1093
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
H
hong 已提交
1094
    auto &op_info = OpInfoMap::Instance().Get(this->Type());
1095
    this->CheckAttrs();
H
hong 已提交
1096
    auto &infer_shape = op_info.infer_shape_;
1097
    PADDLE_ENFORCE_EQ(
1098 1099
        static_cast<bool>(infer_shape),
        true,
1100 1101
        platform::errors::NotFound(
            "Operator %s's infer_shape is not registered.", this->Type()));
1102 1103 1104 1105 1106
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
1107 1108
      std::copy(inames.begin(),
                inames.end(),
1109 1110 1111
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
1112 1113
      std::copy(onames.begin(),
                onames.end(),
1114 1115 1116 1117 1118
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
1119
  } catch (platform::EnforceNotMet &exception) {
1120
    framework::AppendErrorOpHint(Type(), &exception);
1121 1122 1123 1124
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
Y
Yu Yang 已提交
1125 1126
}

Y
Yu Yang 已提交
1127
// Invokes the op's registered var-type inference function, if any, on a
// context built from this op and `block`.
void OpDesc::InferVarType(BlockDesc *block) const {
  // There are a few places that var type can be set.
  // When VarDesc is created, default set to LOD_TENSOR.
  // When output variable is created, default is set to LOD_TENSOR.
  // We limit here to be the only place that operator defines its customized
  // var type inference. Hence, we don't do any "default" setting here.
  auto &info = OpInfoMap::Instance().Get(this->Type());
  if (info.infer_var_type_) {
    InferVarTypeContext context(this, block);
    info.infer_var_type_(&context);
  }
}

1140 1141 1142 1143
// Returns the distributed attribute, or nullptr when none has been created.
const OperatorDistAttr *OpDesc::DistAttr() const {
  if (!dist_attr_) {
    return nullptr;
  }
  return dist_attr_.get();
}

1144 1145 1146 1147
// Returns the distributed attribute for modification, creating it from this
// op on first use.
OperatorDistAttr *OpDesc::MutableDistAttr() {
  if (!dist_attr_) {
    dist_attr_ = std::make_unique<OperatorDistAttr>(*this);
  }
  return dist_attr_.get();
}

void OpDesc::SetDistAttr(const OperatorDistAttr &dist_attr) {
  MutableDistAttr();
  *dist_attr_ = dist_attr;
}

1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202
void OpDesc::UpdateVarAttr(const std::string &name, const Attribute &attr) {
  auto attr_type = static_cast<proto::AttrType>(attr.index() - 1);
  auto type = GetAttrType(name, true);
  if (type == proto::AttrType::VAR) {
    PADDLE_ENFORCE_EQ(
        attr_type,
        type,
        platform::errors::InvalidArgument(
            "Required attr.type == proto::AttrType::VAR, but received %s",
            attr_type));
    auto *var_desc = PADDLE_GET_CONST(VarDesc *, attr);
    VLOG(3) << "Update AttrVar " << name << " with " << var_desc->Name();
    attrs_[name] = FindVarRecursive(var_desc->Name());
  } else if (type == proto::AttrType::VARS) {
    PADDLE_ENFORCE_EQ(
        attr_type,
        type,
        platform::errors::InvalidArgument(
            "Required attr.type == proto::AttrType::VARS, but received %s",
            attr_type));
    auto vars_desc = PADDLE_GET_CONST(std::vector<VarDesc *>, attr);
    std::vector<VarDesc *> new_val;
    for (auto &var_desc : vars_desc) {
      VLOG(3) << "Update AttrVars " << name << " with " << var_desc->Name();
      new_val.emplace_back(FindVarRecursive(var_desc->Name()));
    }
    attrs_[name] = std::move(new_val);
  }
}

// Searches for variable `name` starting at this op's block and walking up
// through ParentBlock() until a block is null or has a negative ID.
// Returns the first match; raises NotFound when no visited block has it.
VarDesc *OpDesc::FindVarRecursive(const std::string &name) {
  auto *cur_block = block_;
  while (cur_block != nullptr && cur_block->ID() >= 0) {
    // Query the block currently being visited. Looking in `block_` on every
    // iteration would never reach variables declared in ancestor blocks.
    auto *var = cur_block->FindVar(name);
    if (var != nullptr) {
      return var;
    }
    cur_block = cur_block->ParentBlock();
  }
  // block_ may be null here (the loop is skipped in that case), so do not
  // dereference it unconditionally while building the message.
  PADDLE_THROW(platform::errors::NotFound(
      "Not found Var(%s) from Block(%d) back into global Block.",
      name,
      block_ != nullptr ? block_->ID() : -1));
}

1203
// Binds the context to the op being inferred and the block it belongs to.
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op), block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
1208 1209
  auto inputs = op_.Inputs(/*with_attr_var=*/true);
  if (inputs.find(name) == inputs.end()) {
1210 1211
    return false;
  }
1212 1213
  const std::vector<std::string> &input_names =
      op_.Input(name, /*with_attr_var=*/true);
1214 1215 1216 1217
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
1218
  PADDLE_ENFORCE_EQ(
1219 1220
      length,
      1UL,
1221 1222
      platform::errors::InvalidArgument("Input(%s) should have only one value, "
                                        "but it has %d values now.",
1223 1224
                                        name,
                                        length));
1225 1226 1227 1228
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
1229 1230 1231
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
1232 1233 1234 1235 1236
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
1237 1238
  PADDLE_ENFORCE_EQ(length,
                    1UL,
1239 1240 1241
                    platform::errors::InvalidArgument(
                        "Output(%s) should have only one value, "
                        "but it has %d values now.",
1242 1243
                        name,
                        length));
1244 1245 1246
  return block_.HasVarRecursive(output_names[0]);
}

1247
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
1248
  return op_.HasAttr(name, /*with_attr_var=*/false);
1249 1250
}

1251
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
1252 1253
  auto inputs = op_.Inputs(/*with_attr_var=*/true);
  if (inputs.find(name) == inputs.end()) {
1254 1255
    return false;
  }
1256 1257
  const std::vector<std::string> &input_names =
      op_.Input(name, /*with_attr_var=*/true);
1258 1259 1260 1261 1262 1263 1264 1265 1266
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

1267 1268
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
                                              bool allow_null) const {
1269 1270 1271
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
1272 1273 1274 1275
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
Y
YuanRisheng 已提交
1276
  if (!allow_null) {
1277 1278 1279
    for (auto &output : output_names) {
      if (!block_.HasVarRecursive(output)) return false;
    }
1280
  }
Y
YuanRisheng 已提交
1281
  return true;
1282 1283 1284
}

// Builds an AttrReader over the op's regular and runtime attribute maps.
AttrReader CompileTimeInferShapeContext::Attrs() const {
  return AttrReader(op_.GetAttrMap(), op_.GetRuntimeAttrMap());
}

H
hong 已提交
1288
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
1289
    const std::string &name) const {
1290
  return op_.Input(name, /*with_attr_var=*/true);
1291 1292
}

H
hong 已提交
1293
std::vector<std::string> CompileTimeInferShapeContext::Outputs(
1294 1295 1296 1297
    const std::string &name) const {
  return op_.Output(name);
}

F
fengjiayi 已提交
1298
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
F
fengjiayi 已提交
1299 1300
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
1301 1302
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1303 1304 1305 1306
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
1307
      res.push_back(phi::make_ddim(s));
F
fengjiayi 已提交
1308 1309
    }
  } catch (...) {
M
minqiyang 已提交
1310
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
F
fengjiayi 已提交
1311 1312 1313
    std::rethrow_exception(std::current_exception());
  }
  return res;
1314 1315 1316 1317
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
F
fengjiayi 已提交
1318
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
1319
}
F
fengjiayi 已提交
1320 1321 1322 1323

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
1324 1325
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1326
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
1327
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), phi::vectorize<>);
F
fengjiayi 已提交
1328
  var->SetShapes(dim_vec);
1329
}
F
fengjiayi 已提交
1330

1331 1332
// This context always reports that it is not a runtime context.
bool CompileTimeInferShapeContext::IsRuntime() const {
  return false;
}

1333 1334
// This context always reports that no MKLDNN kernel is being run.
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const {
  return false;
}

1335
// Returns the type of variable `name`.
// Raises NotFound when the variable is not visible from the block, instead
// of dereferencing a null pointer.
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  return var->GetType();
}
1339

1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355
// Collects the names of the variable(s) referenced by `attr`.
// A single-variable attribute yields one name and a variable-list attribute
// yields one name per element, in order. Any other attribute kind raises
// Unimplemented.
std::vector<std::string> AttrVarNames(const Attribute &attr) {
  std::vector<std::string> names;
  if (IsAttrVar(attr)) {
    names.emplace_back(PADDLE_GET_CONST(VarDesc *, attr)->Name());
    return names;
  }
  if (IsAttrVars(attr)) {
    for (auto &entry : PADDLE_GET_CONST(std::vector<VarDesc *>, attr)) {
      names.emplace_back(entry->Name());
    }
    return names;
  }
  PADDLE_THROW(platform::errors::Unimplemented(
      "Unsupported Attribute value type `%s` for AttrVarNames",
      platform::demangle(attr.type().name())));
}

F
fengjiayi 已提交
1356 1357
}  // namespace framework
}  // namespace paddle