op_desc.cc 42.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
F
fengjiayi 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/op_desc.h"
16

17
#include <string>
18

19
#include "glog/logging.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/framework/block_desc.h"
21
#include "paddle/fluid/framework/op_call_stack.h"
Y
yuyang18 已提交
22
#include "paddle/fluid/framework/op_proto_maker.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
M
minqiyang 已提交
25
#include "paddle/fluid/framework/var_type_inference.h"
26
#include "paddle/fluid/operators/ops_extra_info.h"
R
Ruibiao Chen 已提交
27
#include "paddle/utils/blank.h"
Y
Yu Yang 已提交
28

F
fengjiayi 已提交
29 30 31
namespace paddle {
namespace framework {

32 33
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
34
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
35 36 37 38 39

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

40 41
  bool HasAttr(const std::string &name) const override;

42 43
  bool HasInputs(const std::string &name) const override;

44 45
  bool HasOutputs(const std::string &name,
                  bool allow_null = false) const override;
46 47 48

  AttrReader Attrs() const override;

H
hong 已提交
49
  std::vector<std::string> Inputs(const std::string &name) const override;
50

H
hong 已提交
51
  std::vector<std::string> Outputs(const std::string &name) const override;
52

53 54 55
  std::string GetInputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
56 57
    PADDLE_ENFORCE_LT(idx,
                      op_proto->inputs().size(),
58 59 60
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
61 62 63
                          op_.Type(),
                          idx,
                          op_proto->inputs().size()));
64 65 66 67 68 69 70
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
71 72
        idx,
        op_proto->outputs().size(),
73 74 75
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
76 77 78
            op_.Type(),
            idx,
            op_proto->outputs().size()));
79 80 81
    return op_proto->outputs()[idx].name();
  }

82 83 84
  void ShareDim(const std::string &in,
                const std::string &out,
                size_t i = 0,
85
                size_t j = 0) override {
86 87
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
88 89 90
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
91 92 93 94
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
95 96 97
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
98 99
                          Outputs(out).size(),
                          j));
100

H
hong 已提交
101 102
    std::string input_n = Inputs(in)[i];
    std::string output_n = Outputs(out)[j];
103

104 105
    PADDLE_ENFORCE_NE(input_n,
                      framework::kEmptyVarName,
106 107
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
108 109
    PADDLE_ENFORCE_NE(output_n,
                      framework::kEmptyVarName,
110 111
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
112 113 114 115

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

116
    PADDLE_ENFORCE_EQ(
117 118
        in_var->GetType(),
        out_var->GetType(),
119 120 121
        platform::errors::InvalidArgument(
            "The type of input %s and output %s do not match. The input type "
            "is %s, output type is %s.",
122 123 124
            input_n,
            output_n,
            DataTypeToString(in_var->GetType()),
125
            DataTypeToString(out_var->GetType())));
126 127 128 129

    SetDim(output_n, GetDim(input_n));
  }

H
hong 已提交
130 131 132 133 134 135
  void ShareAllLoD(const std::string &in,
                   const std::string &out) const override {
    auto &in_var_names = op_.Input(in);
    auto &out_var_names = op_.Output(out);

    PADDLE_ENFORCE_EQ(
136 137
        in_var_names.size(),
        out_var_names.size(),
H
hong 已提交
138
        platform::errors::PreconditionNotMet(
T
tianshuo78520a 已提交
139
            "Op [%s]:  Input var number should be equal with output var number",
H
hong 已提交
140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157
            op_.Type()));

    for (size_t i = 0; i < in_var_names.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }

      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
        VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
        return;
      }
      out_var->SetLoDLevel(in_var->GetLoDLevel());
    }
  }

158 159 160
  void ShareLoD(const std::string &in,
                const std::string &out,
                size_t i = 0,
Q
Qiao Longfei 已提交
161
                size_t j = 0) const override {
162 163
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
164 165 166
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
167 168 169 170
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
171 172 173
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
174 175 176 177
                          Outputs(out).size(),
                          j));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
178 179
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
180 181
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
182 183
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
Q
Qiao Longfei 已提交
184 185
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
C
chengduo 已提交
186 187
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
188
      VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
X
fix  
Xin Pan 已提交
189 190
      return;
    }
191
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
192
  }
D
dzhwinter 已提交
193

194
  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
195 196
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
197 198 199
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, input "
                          "variable %s of operator %s only has %d elements.",
200 201 202 203 204
                          in,
                          op_.Type(),
                          Inputs(in).size()));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
205 206
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] of operator %s is empty.",
207 208 209
                          in,
                          i,
                          op_.Type()));
C
chengduo 已提交
210
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
211
    PADDLE_ENFORCE_NOT_NULL(
212 213 214 215 216 217
        in_var,
        platform::errors::NotFound(
            "The input variable %s[%d] of operator %s is not found.",
            in,
            i,
            op_.Type()));
218
    return in_var->GetLoDLevel();
C
chengduo 已提交
219 220
  }

221 222
  void SetLoDLevel(const std::string &out,
                   int32_t lod_level,
223
                   size_t j = 0) const override {
224 225
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
226 227 228
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, output "
                          "variable %s of operator %s only has %d elements.",
229 230 231 232 233
                          out,
                          op_.Type(),
                          Outputs(out).size()));
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
234 235
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] of operator %s is empty.",
236 237 238
                          out,
                          j,
                          op_.Type()));
239
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
240
    PADDLE_ENFORCE_NOT_NULL(
241 242 243 244 245 246
        out_var,
        platform::errors::NotFound(
            "The output variable %s[%d] of operator %s is not found.",
            out,
            j,
            op_.Type()));
247 248 249
    if (lod_level >= 0) {
      out_var->SetLoDLevel(lod_level);
    }
250 251
  }

C
Chen Weihang 已提交
252
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
253
  GetInputVarPtrs(const std::string &name) const override {
254
    const std::vector<std::string> arg_names = Inputs(name);
C
Chen Weihang 已提交
255
    paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
256
    res.reserve(arg_names.size());
257 258 259
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
260 261 262 263 264 265
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

C
Chen Weihang 已提交
266
  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
267
  GetOutputVarPtrs(const std::string &name) const override {
268
    const std::vector<std::string> arg_names = Outputs(name);
C
Chen Weihang 已提交
269
    paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
270
    res.reserve(arg_names.size());
271 272 273
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
274 275 276 277 278 279
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

X
Xin Pan 已提交
280 281
  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
282 283
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
284 285 286
                      platform::errors::InvalidArgument(
                          "The input(%s) should hold only one element, but now "
                          "it holds %d elements.",
287 288
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
289 290 291 292 293 294 295 296
    return this->GetDim(arg_names[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    return GetDims(arg_names);
  }

297 298
  bool IsRuntime() const override;

299 300
  bool IsRunMKLDNNKernel() const override;

301 302 303 304
  proto::VarType::Type GetInputVarType(const std::string &name) const override {
    return GetVarType(Inputs(name).at(0));
  }

X
Xin Pan 已提交
305 306 307 308 309 310 311 312 313 314
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Inputs(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Outputs(name));
  }

X
Xin Pan 已提交
315
  void SetOutputDim(const std::string &name, const DDim &dim) override {
H
hong 已提交
316
    auto arg_names = Outputs(name);
317 318
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
319 320 321
                      platform::errors::InvalidArgument(
                          "The iutput(%s) should hold only one element, but "
                          "now it holds %d elements.",
322 323
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
324 325 326 327 328
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
H
hong 已提交
329
    auto names = Outputs(name);
X
Xin Pan 已提交
330 331 332
    SetDims(names, dims);
  }

333 334 335 336 337 338 339 340
  const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
    return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
  }

  const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
    return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
  }

341
 protected:
X
Xin Pan 已提交
342 343 344 345 346
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
347 348 349 350 351
        names.begin(),
        names.end(),
        retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType),
                  this,
X
Xin Pan 已提交
352 353 354 355 356
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;
Q
Qiao Longfei 已提交
357

X
Xin Pan 已提交
358 359
  DDim GetDim(const std::string &name) const {
    auto var = block_.FindVarRecursive(name);
360 361
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable %s is not found.", name));
X
Xin Pan 已提交
362 363 364
    DDim res;
    try {
      auto shape = var->GetShape();
365
      res = shape.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(shape);
X
Xin Pan 已提交
366 367 368 369 370 371 372 373 374 375 376
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      std::rethrow_exception(std::current_exception());
    }
    return res;
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
377 378 379
        names.begin(),
        names.end(),
        std::back_inserter(ret),
X
Xin Pan 已提交
380 381 382
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }
383

X
Xin Pan 已提交
384 385 386 387 388
  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
389 390
    PADDLE_ENFORCE_EQ(length,
                      dims.size(),
391 392 393
                      platform::errors::InvalidArgument(
                          "The input variables number(%d) and input dimensions "
                          "number(%d) do not match.",
394 395
                          length,
                          dims.size()));
X
Xin Pan 已提交
396 397 398 399 400 401 402
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }
403

F
fengjiayi 已提交
404 405 406 407
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
408

Y
Yu Yang 已提交
409 410
  const OpDesc &op_;
  const BlockDesc &block_;
411 412
};

413 414 415 416 417 418 419
// Seeds `runtime_attrs` with the extra attributes registered for `op_type`.
// Keys already present in `runtime_attrs` keep their current value.
static void InitRuntimeAttributeMapByOpExtraInfo(const std::string &op_type,
                                                 AttributeMap *runtime_attrs) {
  const auto &registered_extras =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(op_type);
  for (const auto &extra : registered_extras) {
    runtime_attrs->insert(extra);
  }
}

420 421 422 423
// Builds a detached OpDesc (no owning block) from an operator type and its
// input/output/attribute maps. Runtime attributes are seeded from the extra
// attribute registry for `type`.
OpDesc::OpDesc(const std::string &type,
               const VariableNameMap &inputs,
               const VariableNameMap &outputs,
               const AttributeMap &attrs) {
  block_ = nullptr;
  need_update_ = true;

  desc_.set_type(type);
  inputs_ = inputs;
  outputs_ = outputs;
  attrs_ = attrs;

  InitRuntimeAttributeMapByOpExtraInfo(type, &runtime_attrs_);
}

433 434 435 436 437 438
// Copy constructor: copies the op's contents via CopyFrom and keeps pointing
// at the same block as `other`.
OpDesc::OpDesc(const OpDesc &other) {
  block_ = other.block_;
  CopyFrom(other);
  need_update_ = true;
}

X
Xin Pan 已提交
439 440 441 442
// Copies `other` into a new OpDesc attached to `block`, then runs
// UpdateVarAttr on every copied attribute.
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  CopyFrom(other);
  block_ = block;
  need_update_ = true;
  for (auto &name_and_attr : attrs_) {
    UpdateVarAttr(name_and_attr.first, name_and_attr.second);
  }
}

448
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
449 450 451 452
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
453 454
  runtime_attrs_ = op_desc.runtime_attrs_;
  // The record of original_id_ is only for auto parallel.
455
  original_id_ = op_desc.original_id_;
456 457 458
  if (op_desc.dist_attr_) {
    dist_attr_.reset(new OperatorDistAttr(*op_desc.dist_attr_));
  }
F
fengjiayi 已提交
459 460 461
  need_update_ = true;
}

F
fengjiayi 已提交
462
// Rebuilds an OpDesc from its protobuf form. Inputs and outputs are restored
// completely; attributes are restored except those of type
// BLOCK/BLOCKS/VAR/VARS, which are skipped here.
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // restore inputs_
  for (const proto::OpDesc::Var &slot : desc_.inputs()) {
    std::vector<std::string> &arg_list = inputs_[slot.parameter()];
    arg_list.reserve(slot.arguments_size());
    for (const std::string &arg : slot.arguments()) {
      arg_list.push_back(arg);
    }
  }

  // restore outputs_
  for (const proto::OpDesc::Var &slot : desc_.outputs()) {
    std::vector<std::string> &arg_list = outputs_[slot.parameter()];
    arg_list.reserve(slot.arguments_size());
    for (const std::string &arg : slot.arguments()) {
      arg_list.push_back(arg);
    }
  }

  // restore attrs_
  InitRuntimeAttributeMapByOpExtraInfo(desc.type(), &runtime_attrs_);
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    // The sub_block referred to by the BLOCK attr hasn't been added
    // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS/VAR/VARS attr
    // here.
    const auto kind = attr.type();
    if (kind == proto::AttrType::BLOCK || kind == proto::AttrType::BLOCKS ||
        kind == proto::AttrType::VAR || kind == proto::AttrType::VARS) {
      continue;
    }

    // A name registered as a runtime (extra) attribute is updated in place;
    // everything else goes into the regular attribute map.
    const std::string &attr_name = attr.name();
    auto runtime_it = runtime_attrs_.find(attr_name);
    if (runtime_it != runtime_attrs_.end()) {
      runtime_it->second = GetAttrValue(attr);
    } else {
      attrs_[attr_name] = GetAttrValue(attr);
    }
  }

  this->block_ = block;
}

509 510 511 512 513 514 515 516 517
// The copy-assignment operator is written out by hand because the class has
// a unique_ptr data member, which has no implicit copy assignment.
// Copies the contents of `other` and adopts its block pointer.
OpDesc &OpDesc::operator=(const OpDesc &other) {
  block_ = other.block_;
  CopyFrom(other);
  need_update_ = true;
  return *this;
}

Y
Yu Yang 已提交
518
// Returns a pointer to the underlying protobuf message after calling Flush().
proto::OpDesc *OpDesc::Proto() {
  this->Flush();
  proto::OpDesc *flushed_desc = &desc_;
  return flushed_desc;
}

Y
Yu Yang 已提交
523
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
F
fengjiayi 已提交
524
  auto it = inputs_.find(name);
525
  PADDLE_ENFORCE_NE(
526 527 528 529
      it,
      inputs_.end(),
      platform::errors::NotFound(
          "Input %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
530 531 532
  return it->second;
}

533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555
// Returns the argument names for `name`. When `with_attr_var` is set and an
// attribute called `name` holds variable(s), that attribute's variable names
// are returned instead of the input slot's.
std::vector<std::string> OpDesc::Input(const std::string &name,
                                       bool with_attr_var) const {
  if (!with_attr_var) {
    return this->Input(name);
  }
  // Attribute with VarDesc type will consider as Input
  const auto attr_it = attrs_.find(name);
  const bool is_var_attr =
      attr_it != attrs_.end() && HasAttrVar(attr_it->second);
  if (is_var_attr) {
    return AttrVarNames(attr_it->second);
  }
  return this->Input(name);
}

// Returns a copy of the input map. With `with_attr_var`, each attribute
// selected by FilterAttrVar is added under its attribute name (replacing an
// input slot of the same name).
VariableNameMap OpDesc::Inputs(bool with_attr_var) const {
  VariableNameMap merged = inputs_;
  if (with_attr_var) {
    for (auto &var_attr : FilterAttrVar(attrs_)) {
      merged[var_attr.first] = AttrVarNames(var_attr.second);
    }
  }
  return merged;
}

std::vector<std::string> OpDesc::InputArgumentNames(bool with_attr_var) const {
F
Update  
fengjiayi 已提交
556
  std::vector<std::string> retv;
557
  for (auto &ipt : this->Inputs(with_attr_var)) {
F
Update  
fengjiayi 已提交
558 559 560 561 562
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
563 564
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
565 566 567 568
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
569
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
F
fengjiayi 已提交
570
  auto it = outputs_.find(name);
571
  PADDLE_ENFORCE_NE(
572 573 574 575
      it,
      outputs_.end(),
      platform::errors::NotFound(
          "Output %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
576 577 578
  return it->second;
}

579 580 581 582
// Reports whether the op has an output slot called `name`.
bool OpDesc::HasOutput(const std::string &name) const {
  const auto slot = outputs_.find(name);
  const bool present = !(slot == outputs_.end());
  return present;
}

Y
Yu Yang 已提交
583
std::vector<std::string> OpDesc::OutputArgumentNames() const {
F
Update  
fengjiayi 已提交
584 585 586 587 588 589 590
  std::vector<std::string> retv;
  for (auto &ipt : this->outputs_) {
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
591 592
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
593 594 595 596
  need_update_ = true;
  this->outputs_[param_name] = args;
}

597 598 599 600 601
void OpDesc::RemoveOutput(const std::string &name) {
  outputs_.erase(name);
  need_update_ = true;
}

602 603 604 605 606
void OpDesc::RemoveInput(const std::string &name) {
  inputs_.erase(name);
  need_update_ = true;
}

607 608 609 610 611 612 613 614 615 616
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info = OpInfoMap::Instance();
  if (op_info.Has(desc_.type())) {
    auto op_info_ptr = op_info.Get(desc_.type());
    if (op_info_ptr.HasOpProtoAndChecker()) {
      const proto::OpProto &proto = op_info_ptr.Proto();
      for (int i = 0; i != proto.attrs_size(); ++i) {
        const proto::OpProto::Attr &attr = proto.attrs(i);
        if (attr.name() == name) {
          return true;
L
luotao1 已提交
617 618
        }
      }
L
luotao1 已提交
619 620 621 622 623
    }
  }
  return false;
}

624 625 626 627
// Returns the proto attribute type of attribute `name`, derived from the
// index of the alternative currently held by the Attribute (index - 1).
// Lookup failures are reported by GetAttr.
proto::AttrType OpDesc::GetAttrType(const std::string &name,
                                    bool with_attr_var) const {
  const Attribute held = this->GetAttr(name, with_attr_var);
  const auto type_code = held.index() - 1;
  return static_cast<proto::AttrType>(type_code);
}

630
std::vector<std::string> OpDesc::AttrNames(bool with_attr_var) const {
F
fengjiayi 已提交
631 632 633
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
634
    if (!with_attr_var && HasAttrVar(attr.second)) continue;
F
fengjiayi 已提交
635 636 637 638 639
    retv.push_back(attr.first);
  }
  return retv;
}

640 641
// Reports whether `name` exists in attrs_ or, failing that, runtime_attrs_.
// Unless `with_attr_var` is set, an attribute for which HasAttrVar is true
// is reported as absent.
bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
  auto found = attrs_.find(name);
  if (found == attrs_.end()) {
    found = runtime_attrs_.find(name);
    if (found == runtime_attrs_.end()) {
      return false;
    }
  }
  return with_attr_var || !HasAttrVar(found->second);
}

655 656
void OpDesc::RemoveAttr(const std::string &name) {
  attrs_.erase(name);
657
  runtime_attrs_.erase(name);
658 659 660
  need_update_ = true;
}

Y
Yu Yang 已提交
661
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
662 663
  AttributeMap *attrs_ptr = &(this->attrs_);

664 665
  bool is_runtime_attr = false;

666 667 668 669
  const auto &extra_attr_map =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsMap(Type());
  auto extra_attr_iter = extra_attr_map.find(name);
  if (extra_attr_iter != extra_attr_map.end()) {
670
    is_runtime_attr = true;
671 672
    attrs_ptr = &(this->runtime_attrs_);
  }
M
minqiyang 已提交
673 674 675
  // NOTICE(minqiyang): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
R
Ruibiao Chen 已提交
676
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
M
minqiyang 已提交
677
  if (attr_type == proto::AttrType::INTS &&
R
Ruibiao Chen 已提交
678
      PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) {
M
minqiyang 已提交
679
    // Find current attr via attr name and set the correct attribute value
680 681 682 683 684
    auto attr_type =
        is_runtime_attr
            ? static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1)
            : GetProtoAttr(name).type();
    switch (attr_type) {
M
minqiyang 已提交
685
      case proto::AttrType::BOOLEANS: {
M
minqiyang 已提交
686 687
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
688
        attrs_ptr->operator[](name) = std::vector<bool>();
M
minqiyang 已提交
689 690 691
        break;
      }
      case proto::AttrType::INTS: {
M
minqiyang 已提交
692 693
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
694
        attrs_ptr->operator[](name) = std::vector<int>();
M
minqiyang 已提交
695 696
        break;
      }
697
      case proto::AttrType::LONGS: {
M
minqiyang 已提交
698 699
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from LONGS to LONGS";
700
        attrs_ptr->operator[](name) = std::vector<int64_t>();
701 702
        break;
      }
M
minqiyang 已提交
703
      case proto::AttrType::FLOATS: {
M
minqiyang 已提交
704 705
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
706
        attrs_ptr->operator[](name) = std::vector<float>();
M
minqiyang 已提交
707 708
        break;
      }
709 710 711 712 713 714
      case proto::AttrType::FLOAT64S: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOAT64S";
        this->attrs_[name] = std::vector<double>();
        break;
      }
M
minqiyang 已提交
715
      case proto::AttrType::STRINGS: {
M
minqiyang 已提交
716 717
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
718
        attrs_ptr->operator[](name) = std::vector<std::string>();
M
minqiyang 已提交
719 720 721
        break;
      }
      case proto::AttrType::BLOCKS: {
M
minqiyang 已提交
722 723
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
724
        attrs_ptr->operator[](name) = std::vector<BlockDesc *>();
M
minqiyang 已提交
725 726
        return;
      }
M
minqiyang 已提交
727
      default:
728
        PADDLE_THROW(platform::errors::Unimplemented(
729
            "Unsupported attribute type (code %d).", attr_type));
M
minqiyang 已提交
730
    }
M
minqiyang 已提交
731 732
    need_update_ = true;
    return;
M
minqiyang 已提交
733 734
  }

735
  // In order to set bool attr properly
736 737 738 739 740 741 742 743 744 745 746 747 748 749
  if (attr_type == proto::AttrType::INT) {
    if (HasProtoAttr(name) &&
        GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
      attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
      need_update_ = true;
      return;
    }
    if (extra_attr_iter != extra_attr_map.end() &&
        static_cast<proto::AttrType>(extra_attr_iter->second.index() - 1) ==
            proto::AttrType::BOOLEAN) {
      attrs_ptr->operator[](name) = static_cast<bool>(PADDLE_GET_CONST(int, v));
      need_update_ = true;
      return;
    }
750 751
  }

752
  attrs_ptr->operator[](name) = v;
F
fengjiayi 已提交
753 754 755
  need_update_ = true;
}

756 757 758 759 760 761 762 763 764 765
void OpDesc::SetVarAttr(const std::string &name, VarDesc *var) {
  this->attrs_[name] = var;
  need_update_ = true;
}

void OpDesc::SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars) {
  this->attrs_[name] = vars;
  need_update_ = true;
}

A
Abhinav Arora 已提交
766 767
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
F
fengjiayi 已提交
768
  need_update_ = true;
F
fengjiayi 已提交
769 770
}

771 772 773 774 775 776
void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = blocks;
  need_update_ = true;
}

Y
Yu Yang 已提交
777
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
778 779 780 781 782
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

783 784 785 786 787 788
void OpDesc::SetRuntimeAttrMap(
    const std::unordered_map<std::string, Attribute> &attr_map) {
  runtime_attrs_ = attr_map;
  need_update_ = true;
}

789
// Returns a copy of the attribute called `name`.
//
// The regular attribute map is searched first and the runtime attribute map
// second. Throws NotFound when neither map holds `name`. When `with_attr_var`
// is false, also throws NotFound if HasAttrVar() reports that the stored value
// is Variable-backed.
Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
  auto it = attrs_.find(name);
  if (it == attrs_.end()) {
    it = runtime_attrs_.find(name);
    // An iterator may only be compared with end() of the container it came
    // from, so the fallback lookup is validated against runtime_attrs_.end()
    // (comparing it with attrs_.end() is undefined behavior).
    PADDLE_ENFORCE_NE(
        it,
        runtime_attrs_.end(),
        platform::errors::NotFound("Attribute %s is not found.", name));
  }
  if (!with_attr_var) {
    PADDLE_ENFORCE_EQ(
        HasAttrVar(it->second),
        false,
        platform::errors::NotFound(
            "Attribute %s with constant value is not found, but found it with "
            "Variable(s) type, which maybe not supported in some scenarios "
            "currently, such as TensorRT et.al",
            name));
  }
  return it->second;
}

M
minqiyang 已提交
811 812 813
// Looks up the attribute declaration called `name` in the OpProto registered
// for this op's type. Throws NotFound if the proto declares no such attribute.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  const int attr_count = proto.attrs_size();
  int idx = 0;
  while (idx != attr_count) {
    const proto::OpProto::Attr &candidate = proto.attrs(idx);
    if (candidate.name() == name) {
      return candidate;
    }
    ++idx;
  }

  PADDLE_THROW(platform::errors::NotFound(
      "Attribute %s is not found in proto %s.", name, proto.type()));
}

Y
yuyang18 已提交
825
// Returns a copy of the regular attribute `name`, or a default-constructed
// Attribute when it is absent. Only attrs_ is searched here.
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  const auto found = attrs_.find(name);
  if (found == attrs_.end()) {
    return Attribute();
  }
  return found->second;
}

G
gongweibao 已提交
834 835
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
836
  PADDLE_ENFORCE_NE(
837 838
      it,
      attrs_.end(),
839 840
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
R
Ruibiao Chen 已提交
841
  auto blocks = PADDLE_GET_CONST(std::vector<BlockDesc *>, it->second);
G
gongweibao 已提交
842 843 844 845 846 847 848 849 850 851

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
F
fengjiayi 已提交
852
  auto it = attrs_.find(name);
853
  PADDLE_ENFORCE_NE(
854 855
      it,
      attrs_.end(),
856 857
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
R
Ruibiao Chen 已提交
858
  return PADDLE_GET_CONST(BlockDesc *, it->second)->ID();
F
fengjiayi 已提交
859 860
}

Y
Yu Yang 已提交
861
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
862 863 864
  return attrs_;
}

865 866
// Read-only access to the runtime attribute map.
const AttributeMap &OpDesc::GetRuntimeAttrMap() const {
  return this->runtime_attrs_;
}

Y
Yu Yang 已提交
867
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
Y
Yancey1989 已提交
868 869
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
F
fengjiayi 已提交
870 871 872
  need_update_ = true;
}

Y
Yu Yang 已提交
873 874
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
875
  for (auto &output : outputs_) {
876 877
    std::replace(
        output.second.begin(), output.second.end(), old_name, new_name);
Y
Yang Yang(Tony) 已提交
878
  }
Y
yuyang18 已提交
879 880 881

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
R
Ruibiao Chen 已提交
882
    auto &op_vars = PADDLE_GET(std::vector<std::string>, it->second);
Y
yuyang18 已提交
883 884 885
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
886 887 888
  need_update_ = true;
}

Y
Yu Yang 已提交
889 890
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
891 892 893
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
Y
Yancey1989 已提交
894 895 896

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
R
Ruibiao Chen 已提交
897
    auto &op_vars = PADDLE_GET(std::vector<std::string>, it->second);
Y
Yancey1989 已提交
898 899 900
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
901 902 903
  need_update_ = true;
}

904
struct SetAttrDescVisitor {
905 906
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  mutable proto::OpDesc::Attr *attr_;
Y
Yu Yang 已提交
907 908
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
909
  void operator()(double v) const { attr_->set_float64(v); }
Y
Yu Yang 已提交
910
  void operator()(const std::string &v) const { attr_->set_s(v); }
Q
QI JUN 已提交
911 912 913 914 915 916 917

  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }
Y
Yu Yang 已提交
918 919 920 921 922 923 924 925 926 927 928 929 930

  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
931 932 933 934 935 936 937 938 939 940 941 942 943

  void operator()(const std::vector<VarDesc *> &v) const {
    std::vector<std::string> var_names;
    for (auto var : v) {
      var_names.emplace_back(var->Name());
    }
    VectorToRepeated(var_names, attr_->mutable_vars_name());
  }

  void operator()(const VarDesc *desc) const {
    attr_->set_var_name(desc->Name());
  }

944 945 946
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
T
tangwei12 已提交
947
      blocks_idx.push_back(blk->ID());
948 949 950
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }
T
tangwei12 已提交
951 952 953

  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

954
  void operator()(int64_t v) const { attr_->set_l(v); }
T
tangwei12 已提交
955 956 957 958 959

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

960 961 962 963
  void operator()(const std::vector<double> &v) const {
    VectorToRepeated(v, attr_->mutable_float64s());
  }

R
Ruibiao Chen 已提交
964
  void operator()(paddle::blank) const {
965 966 967 968
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method of SetAttrDescVisitor object for "
        "`boosst::blank` type."));
  }
Y
Yu Yang 已提交
969 970
};

Y
Yu Yang 已提交
971
void OpDesc::Flush() {
L
Leo Chen 已提交
972 973
  VLOG(4) << "Flush "
          << " " << Type() << " " << need_update_;
F
fengjiayi 已提交
974
  if (need_update_) {
975
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
976
    for (auto &ipt : inputs_) {
977
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
978 979 980 981
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

982
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
983
    for (auto &opt : outputs_) {
984
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
985 986 987 988
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

989
    this->desc_.mutable_attrs()->Clear();
990 991 992 993 994 995 996 997 998
    auto set_attr_desc = [this](const std::string &attr_name,
                                const Attribute &attr) -> void {
      auto *attr_desc = desc_.add_attrs();
      attr_desc->set_name(attr_name);
      attr_desc->set_type(static_cast<proto::AttrType>(attr.index() - 1));
      SetAttrDescVisitor visitor(attr_desc);
      paddle::visit(visitor, attr);
    };

L
Leo Chen 已提交
999 1000 1001 1002 1003 1004 1005
    std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
                                                                attrs_.end()};
    std::sort(
        sorted_attrs.begin(),
        sorted_attrs.end(),
        [](std::pair<std::string, Attribute> a,
           std::pair<std::string, Attribute> b) { return a.first < b.first; });
1006

L
Leo Chen 已提交
1007
    for (auto &attr : sorted_attrs) {
1008 1009 1010 1011
      set_attr_desc(attr.first, attr.second);
    }
    for (auto &attr : runtime_attrs_) {
      set_attr_desc(attr.first, attr.second);
F
fengjiayi 已提交
1012 1013 1014 1015 1016
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
1017

Y
Yu Yang 已提交
1018
void OpDesc::CheckAttrs() {
1019 1020
  PADDLE_ENFORCE_EQ(Type().empty(),
                    false,
1021 1022
                    platform::errors::PreconditionNotMet(
                        "CheckAttrs() can not be called before type is set."));
Y
Yu Yang 已提交
1023 1024 1025 1026 1027 1028
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
1029
  VLOG(10) << "begin to check attribute of " << Type();
T
tangwei12 已提交
1030
  checker->Check(&attrs_);
1031 1032 1033 1034 1035 1036 1037
  const auto &extra_attr_checkers =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(Type());
  if (!extra_attr_checkers.empty()) {
    for (const auto &extra_checker : extra_attr_checkers) {
      extra_checker(&runtime_attrs_, false);
    }
  }
F
fengjiayi 已提交
1038 1039
}

H
hong 已提交
1040
void OpDesc::InferShape(const BlockDesc &block) {
1041 1042
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
H
hong 已提交
1043
    auto &op_info = OpInfoMap::Instance().Get(this->Type());
1044
    this->CheckAttrs();
H
hong 已提交
1045
    auto &infer_shape = op_info.infer_shape_;
1046
    PADDLE_ENFORCE_EQ(
1047 1048
        static_cast<bool>(infer_shape),
        true,
1049 1050
        platform::errors::NotFound(
            "Operator %s's infer_shape is not registered.", this->Type()));
1051 1052 1053 1054 1055
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
1056 1057
      std::copy(inames.begin(),
                inames.end(),
1058 1059 1060
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
1061 1062
      std::copy(onames.begin(),
                onames.end(),
1063 1064 1065 1066 1067
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
1068
  } catch (platform::EnforceNotMet &exception) {
1069
    framework::AppendErrorOpHint(Type(), &exception);
1070 1071 1072 1073
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
Y
Yu Yang 已提交
1074 1075
}

Y
Yu Yang 已提交
1076
// Invokes the var-type inference functor registered for this op type, if any,
// with an InferVarTypeContext built from this op and `block`.
void OpDesc::InferVarType(BlockDesc *block) const {
  // Var types can be set in several places: a VarDesc defaults to LOD_TENSOR
  // when it is created, and so does an output variable. This function is meant
  // to be the only place where an operator applies its own customized var type
  // inference, so no "default" is assigned here.
  auto &info = OpInfoMap::Instance().Get(this->Type());
  if (!info.infer_var_type_) {
    return;
  }
  InferVarTypeContext context(this, block);
  info.infer_var_type_(&context);
}

1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102
// Returns the op's distributed attribute, constructing it from this op on
// first use.
OperatorDistAttr *OpDesc::MutableDistAttr() {
  if (!dist_attr_) {
    dist_attr_.reset(new OperatorDistAttr(*this));
  }
  return dist_attr_.get();
}

void OpDesc::SetDistAttr(const OperatorDistAttr &dist_attr) {
  MutableDistAttr();
  *dist_attr_ = dist_attr;
}

1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147
void OpDesc::UpdateVarAttr(const std::string &name, const Attribute &attr) {
  auto attr_type = static_cast<proto::AttrType>(attr.index() - 1);
  auto type = GetAttrType(name, true);
  if (type == proto::AttrType::VAR) {
    PADDLE_ENFORCE_EQ(
        attr_type,
        type,
        platform::errors::InvalidArgument(
            "Required attr.type == proto::AttrType::VAR, but received %s",
            attr_type));
    auto *var_desc = PADDLE_GET_CONST(VarDesc *, attr);
    VLOG(3) << "Update AttrVar " << name << " with " << var_desc->Name();
    attrs_[name] = FindVarRecursive(var_desc->Name());
  } else if (type == proto::AttrType::VARS) {
    PADDLE_ENFORCE_EQ(
        attr_type,
        type,
        platform::errors::InvalidArgument(
            "Required attr.type == proto::AttrType::VARS, but received %s",
            attr_type));
    auto vars_desc = PADDLE_GET_CONST(std::vector<VarDesc *>, attr);
    std::vector<VarDesc *> new_val;
    for (auto &var_desc : vars_desc) {
      VLOG(3) << "Update AttrVars " << name << " with " << var_desc->Name();
      new_val.emplace_back(FindVarRecursive(var_desc->Name()));
    }
    attrs_[name] = std::move(new_val);
  }
}

// Searches for the variable `name` starting in this op's block and walking up
// through ParentBlock() links. The walk stops at a null block or at a block
// whose ID is negative. Returns the first match; throws NotFound when no
// visited block declares the variable.
VarDesc *OpDesc::FindVarRecursive(const std::string &name) {
  auto *cur_block = block_;
  while (cur_block != nullptr && cur_block->ID() >= 0) {
    // Query the block currently being visited. Looking in block_ on every
    // iteration would never consult any ancestor.
    auto *var = cur_block->FindVar(name);
    if (var != nullptr) {
      return var;
    }
    cur_block = cur_block->ParentBlock();
  }
  // block_ may be null here (the loop condition allows it), so do not
  // dereference it unconditionally while building the error message.
  PADDLE_THROW(platform::errors::NotFound(
      "Not found Var(%s) from Block(%d) back into global Block.",
      name,
      block_ != nullptr ? static_cast<int>(block_->ID()) : -1));
}

1148
// Binds the context to the op being inferred and to the block used to resolve
// its variables. Both are initialized directly from the arguments.
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op),
      block_(block) {
}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
1153 1154
  auto inputs = op_.Inputs(/*with_attr_var=*/true);
  if (inputs.find(name) == inputs.end()) {
1155 1156
    return false;
  }
1157 1158
  const std::vector<std::string> &input_names =
      op_.Input(name, /*with_attr_var=*/true);
1159 1160 1161 1162
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
1163
  PADDLE_ENFORCE_EQ(
1164 1165
      length,
      1UL,
1166 1167
      platform::errors::InvalidArgument("Input(%s) should have only one value, "
                                        "but it has %d values now.",
1168 1169
                                        name,
                                        length));
1170 1171 1172 1173
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
1174 1175 1176
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
1177 1178 1179 1180 1181
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
1182 1183
  PADDLE_ENFORCE_EQ(length,
                    1UL,
1184 1185 1186
                    platform::errors::InvalidArgument(
                        "Output(%s) should have only one value, "
                        "but it has %d values now.",
1187 1188
                        name,
                        length));
1189 1190 1191
  return block_.HasVarRecursive(output_names[0]);
}

1192
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
1193
  return op_.HasAttr(name, /*with_attr_var=*/false);
1194 1195
}

1196
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
1197 1198
  auto inputs = op_.Inputs(/*with_attr_var=*/true);
  if (inputs.find(name) == inputs.end()) {
1199 1200
    return false;
  }
1201 1202
  const std::vector<std::string> &input_names =
      op_.Input(name, /*with_attr_var=*/true);
1203 1204 1205 1206 1207 1208 1209 1210 1211
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

1212 1213
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
                                              bool allow_null) const {
1214 1215 1216
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
1217 1218 1219 1220
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
1221 1222 1223 1224 1225 1226 1227 1228 1229 1230
  if (allow_null) {
    for (auto &output : output_names) {
      if (block_.HasVarRecursive(output)) return true;
    }
    return false;
  } else {
    for (auto &output : output_names) {
      if (!block_.HasVarRecursive(output)) return false;
    }
    return true;
1231 1232 1233 1234
  }
}

// Builds an AttrReader over the op's regular and runtime attribute maps.
AttrReader CompileTimeInferShapeContext::Attrs() const {
  const auto &regular_attrs = op_.GetAttrMap();
  const auto &runtime_attrs = op_.GetRuntimeAttrMap();
  return AttrReader(regular_attrs, runtime_attrs);
}

H
hong 已提交
1238
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
1239
    const std::string &name) const {
1240
  return op_.Input(name, /*with_attr_var=*/true);
1241 1242
}

H
hong 已提交
1243
std::vector<std::string> CompileTimeInferShapeContext::Outputs(
1244 1245 1246 1247
    const std::string &name) const {
  return op_.Output(name);
}

F
fengjiayi 已提交
1248
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
F
fengjiayi 已提交
1249 1250
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
1251 1252
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1253 1254 1255 1256
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
1257
      res.push_back(s.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(s));
F
fengjiayi 已提交
1258 1259
    }
  } catch (...) {
M
minqiyang 已提交
1260
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
F
fengjiayi 已提交
1261 1262 1263
    std::rethrow_exception(std::current_exception());
  }
  return res;
1264 1265 1266 1267
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
F
fengjiayi 已提交
1268
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
1269
}
F
fengjiayi 已提交
1270 1271 1272 1273

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
1274 1275
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1276
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
1277
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), phi::vectorize<>);
F
fengjiayi 已提交
1278
  var->SetShapes(dim_vec);
1279
}
F
fengjiayi 已提交
1280

1281 1282
// This context always reports that it is not a runtime context.
bool CompileTimeInferShapeContext::IsRuntime() const {
  return false;
}

1283 1284
// This context always reports that no MKLDNN kernel is being run.
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const {
  return false;
}

1285
// Returns the type of variable `name`.
// Throws NotFound when the variable cannot be resolved from the block, in line
// with SetRepeatedDims/GetRepeatedDims, instead of dereferencing a null
// pointer.
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  return var->GetType();
}
1289

1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305
std::vector<std::string> AttrVarNames(const Attribute &attr) {
  std::vector<std::string> vars_name;
  if (IsAttrVar(attr)) {
    vars_name.emplace_back(PADDLE_GET_CONST(VarDesc *, attr)->Name());
  } else if (IsAttrVars(attr)) {
    for (auto &iter : PADDLE_GET_CONST(std::vector<VarDesc *>, attr)) {
      vars_name.emplace_back(iter->Name());
    }
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported Attribute value type `%s` for AttrVarNames",
        platform::demangle(attr.type().name())));
  }
  return vars_name;
}

F
fengjiayi 已提交
1306 1307
}  // namespace framework
}  // namespace paddle