op_desc.cc 40.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/op_desc.h"
16

17
#include <algorithm>
#include <iterator>
#include <string>
#include <utility>
#include <vector>
18

19
#include "glog/logging.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/framework/block_desc.h"
21
#include "paddle/fluid/framework/op_call_stack.h"
Y
yuyang18 已提交
22
#include "paddle/fluid/framework/op_proto_maker.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
M
minqiyang 已提交
25
#include "paddle/fluid/framework/var_type_inference.h"
R
Ruibiao Chen 已提交
26
#include "paddle/utils/blank.h"
Y
Yu Yang 已提交
27

F
fengjiayi 已提交
28 29 30
namespace paddle {
namespace framework {

31 32
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
33
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
34 35 36 37 38

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

39 40
  bool HasAttr(const std::string &name) const override;

41 42
  bool HasInputs(const std::string &name) const override;

43 44
  bool HasOutputs(const std::string &name,
                  bool allow_null = false) const override;
45 46 47

  AttrReader Attrs() const override;

H
hong 已提交
48
  std::vector<std::string> Inputs(const std::string &name) const override;
49

H
hong 已提交
50
  std::vector<std::string> Outputs(const std::string &name) const override;
51

52 53 54
  std::string GetInputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
55 56
    PADDLE_ENFORCE_LT(idx,
                      op_proto->inputs().size(),
57 58 59
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
60 61 62
                          op_.Type(),
                          idx,
                          op_proto->inputs().size()));
63 64 65 66 67 68 69
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
70 71
        idx,
        op_proto->outputs().size(),
72 73 74
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
75 76 77
            op_.Type(),
            idx,
            op_proto->outputs().size()));
78 79 80
    return op_proto->outputs()[idx].name();
  }

81 82 83
  void ShareDim(const std::string &in,
                const std::string &out,
                size_t i = 0,
84
                size_t j = 0) override {
85 86
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
87 88 89
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
90 91 92 93
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
94 95 96
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
97 98
                          Outputs(out).size(),
                          j));
99

H
hong 已提交
100 101
    std::string input_n = Inputs(in)[i];
    std::string output_n = Outputs(out)[j];
102

103 104
    PADDLE_ENFORCE_NE(input_n,
                      framework::kEmptyVarName,
105 106
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
107 108
    PADDLE_ENFORCE_NE(output_n,
                      framework::kEmptyVarName,
109 110
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
111 112 113 114

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

115
    PADDLE_ENFORCE_EQ(
116 117
        in_var->GetType(),
        out_var->GetType(),
118 119 120
        platform::errors::InvalidArgument(
            "The type of input %s and output %s do not match. The input type "
            "is %s, output type is %s.",
121 122 123
            input_n,
            output_n,
            DataTypeToString(in_var->GetType()),
124
            DataTypeToString(out_var->GetType())));
125 126 127 128

    SetDim(output_n, GetDim(input_n));
  }

H
hong 已提交
129 130 131 132 133 134
  void ShareAllLoD(const std::string &in,
                   const std::string &out) const override {
    auto &in_var_names = op_.Input(in);
    auto &out_var_names = op_.Output(out);

    PADDLE_ENFORCE_EQ(
135 136
        in_var_names.size(),
        out_var_names.size(),
H
hong 已提交
137
        platform::errors::PreconditionNotMet(
T
tianshuo78520a 已提交
138
            "Op [%s]:  Input var number should be equal with output var number",
H
hong 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
            op_.Type()));

    for (size_t i = 0; i < in_var_names.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }

      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
        VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
        return;
      }
      out_var->SetLoDLevel(in_var->GetLoDLevel());
    }
  }

157 158 159
  void ShareLoD(const std::string &in,
                const std::string &out,
                size_t i = 0,
Q
Qiao Longfei 已提交
160
                size_t j = 0) const override {
161 162
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
163 164 165
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
166 167 168 169
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
170 171 172
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
173 174 175 176
                          Outputs(out).size(),
                          j));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
177 178
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
179 180
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
181 182
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
Q
Qiao Longfei 已提交
183 184
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
C
chengduo 已提交
185 186
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
187
      VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
X
fix  
Xin Pan 已提交
188 189
      return;
    }
190
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
191
  }
D
dzhwinter 已提交
192

193
  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
194 195
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
196 197 198
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, input "
                          "variable %s of operator %s only has %d elements.",
199 200 201 202 203
                          in,
                          op_.Type(),
                          Inputs(in).size()));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
204 205
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] of operator %s is empty.",
206 207 208
                          in,
                          i,
                          op_.Type()));
C
chengduo 已提交
209
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
210
    PADDLE_ENFORCE_NOT_NULL(
211 212 213 214 215 216
        in_var,
        platform::errors::NotFound(
            "The input variable %s[%d] of operator %s is not found.",
            in,
            i,
            op_.Type()));
217
    return in_var->GetLoDLevel();
C
chengduo 已提交
218 219
  }

220 221
  void SetLoDLevel(const std::string &out,
                   int32_t lod_level,
222
                   size_t j = 0) const override {
223 224
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
225 226 227
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, output "
                          "variable %s of operator %s only has %d elements.",
228 229 230 231 232
                          out,
                          op_.Type(),
                          Outputs(out).size()));
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
233 234
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] of operator %s is empty.",
235 236 237
                          out,
                          j,
                          op_.Type()));
238
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
239
    PADDLE_ENFORCE_NOT_NULL(
240 241 242 243 244 245
        out_var,
        platform::errors::NotFound(
            "The output variable %s[%d] of operator %s is not found.",
            out,
            j,
            op_.Type()));
246 247 248
    if (lod_level >= 0) {
      out_var->SetLoDLevel(lod_level);
    }
249 250
  }

C
Chen Weihang 已提交
251
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
252
  GetInputVarPtrs(const std::string &name) const override {
253
    const std::vector<std::string> arg_names = Inputs(name);
C
Chen Weihang 已提交
254
    paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
255
    res.reserve(arg_names.size());
256 257 258
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
259 260 261 262 263 264
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

C
Chen Weihang 已提交
265
  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
266
  GetOutputVarPtrs(const std::string &name) const override {
267
    const std::vector<std::string> arg_names = Outputs(name);
C
Chen Weihang 已提交
268
    paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
269
    res.reserve(arg_names.size());
270 271 272
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
273 274 275 276 277 278
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

X
Xin Pan 已提交
279 280
  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
281 282
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
283 284 285
                      platform::errors::InvalidArgument(
                          "The input(%s) should hold only one element, but now "
                          "it holds %d elements.",
286 287
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
288 289 290 291 292 293 294 295
    return this->GetDim(arg_names[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    return GetDims(arg_names);
  }

296 297
  bool IsRuntime() const override;

298 299
  bool IsRunMKLDNNKernel() const override;

300 301 302 303
  proto::VarType::Type GetInputVarType(const std::string &name) const override {
    return GetVarType(Inputs(name).at(0));
  }

X
Xin Pan 已提交
304 305 306 307 308 309 310 311 312 313
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Inputs(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Outputs(name));
  }

X
Xin Pan 已提交
314
  void SetOutputDim(const std::string &name, const DDim &dim) override {
H
hong 已提交
315
    auto arg_names = Outputs(name);
316 317
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
318 319 320
                      platform::errors::InvalidArgument(
                          "The iutput(%s) should hold only one element, but "
                          "now it holds %d elements.",
321 322
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
323 324 325 326 327
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
H
hong 已提交
328
    auto names = Outputs(name);
X
Xin Pan 已提交
329 330 331
    SetDims(names, dims);
  }

332 333 334 335 336 337 338 339
  const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
    return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
  }

  const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
    return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
  }

340
 protected:
X
Xin Pan 已提交
341 342 343 344 345
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
346 347 348 349 350
        names.begin(),
        names.end(),
        retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType),
                  this,
X
Xin Pan 已提交
351 352 353 354 355
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;
Q
Qiao Longfei 已提交
356

X
Xin Pan 已提交
357 358
  DDim GetDim(const std::string &name) const {
    auto var = block_.FindVarRecursive(name);
359 360
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable %s is not found.", name));
X
Xin Pan 已提交
361 362 363
    DDim res;
    try {
      auto shape = var->GetShape();
364
      res = shape.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(shape);
X
Xin Pan 已提交
365 366 367 368 369 370 371 372 373 374 375
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      std::rethrow_exception(std::current_exception());
    }
    return res;
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
376 377 378
        names.begin(),
        names.end(),
        std::back_inserter(ret),
X
Xin Pan 已提交
379 380 381
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }
382

X
Xin Pan 已提交
383 384 385 386 387
  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
388 389
    PADDLE_ENFORCE_EQ(length,
                      dims.size(),
390 391 392
                      platform::errors::InvalidArgument(
                          "The input variables number(%d) and input dimensions "
                          "number(%d) do not match.",
393 394
                          length,
                          dims.size()));
X
Xin Pan 已提交
395 396 397 398 399 400 401
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }
402

F
fengjiayi 已提交
403 404 405 406
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
407

Y
Yu Yang 已提交
408 409
  const OpDesc &op_;
  const BlockDesc &block_;
410 411
};

412 413 414 415
// Builds an op of the given `type` directly from in-memory argument and
// attribute maps. The op starts detached (block_ is null) and is flagged as
// modified through need_update_.
OpDesc::OpDesc(const std::string &type,
               const VariableNameMap &inputs,
               const VariableNameMap &outputs,
               const AttributeMap &attrs) {
  block_ = nullptr;
  desc_.set_type(type);
  attrs_ = attrs;
  outputs_ = outputs;
  inputs_ = inputs;
  need_update_ = true;
}

424 425 426 427 428 429
// Copy constructor: duplicates `other` through CopyFrom and keeps the copy
// attached to the same block as `other`.
OpDesc::OpDesc(const OpDesc &other) {
  block_ = other.block_;
  CopyFrom(other);
  need_update_ = true;
}

X
Xin Pan 已提交
430 431 432 433
// Copies `other` but attaches the copy to `block`. Each copied attribute is
// then handed to UpdateVarAttr (presumably so VAR/VARS attributes refer to
// variables of `block` -- TODO confirm against UpdateVarAttr).
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  CopyFrom(other);
  block_ = block;
  need_update_ = true;
  for (auto attr_it = attrs_.begin(); attr_it != attrs_.end(); ++attr_it) {
    UpdateVarAttr(attr_it->first, attr_it->second);
  }
}

439
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
440 441 442 443
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
444
  original_id_ = op_desc.original_id_;
445 446 447
  if (op_desc.dist_attr_) {
    dist_attr_.reset(new OperatorDistAttr(*op_desc.dist_attr_));
  }
F
fengjiayi 已提交
448 449 450
  need_update_ = true;
}

F
fengjiayi 已提交
451
// Rebuilds an OpDesc from its serialized proto. Inputs, outputs and plain
// attributes are restored here; BLOCK/BLOCKS/VAR/VARS attributes are skipped
// because the sub-blocks / variables they refer to have not been added to
// the ProgramDesc yet.
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // restore inputs_ (arguments of a repeated parameter are appended in order)
  for (const proto::OpDesc::Var &in_slot : desc_.inputs()) {
    std::vector<std::string> &names = inputs_[in_slot.parameter()];
    names.reserve(in_slot.arguments_size());
    for (const std::string &arg : in_slot.arguments()) {
      names.push_back(arg);
    }
  }

  // restore outputs_
  for (const proto::OpDesc::Var &out_slot : desc_.outputs()) {
    std::vector<std::string> &names = outputs_[out_slot.parameter()];
    names.reserve(out_slot.arguments_size());
    for (const std::string &arg : out_slot.arguments()) {
      names.push_back(arg);
    }
  }

  // restore attrs_
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    switch (attr.type()) {
      case proto::AttrType::BLOCK:
      case proto::AttrType::BLOCKS:
      case proto::AttrType::VAR:
      case proto::AttrType::VARS:
        // Deferred: the referenced blocks/vars do not exist yet.
        break;
      default:
        attrs_[attr.name()] = GetAttrValue(attr);
        break;
    }
  }

  this->block_ = block;
}

492 493 494 495 496 497 498 499 500
// Copy assignment has to be written out by hand: the unique_ptr member
// (dist_attr_) suppresses the implicitly generated one. As with the copy
// constructor, the result stays attached to `other`'s block.
OpDesc &OpDesc::operator=(const OpDesc &other) {
  this->block_ = other.block_;
  this->CopyFrom(other);
  this->need_update_ = true;
  return *this;
}

Y
Yu Yang 已提交
501
// Returns the underlying proto message after calling Flush() (presumably
// syncing the in-memory maps into desc_ -- see Flush).
proto::OpDesc *OpDesc::Proto() {
  this->Flush();
  proto::OpDesc *serialized = &desc_;
  return serialized;
}

Y
Yu Yang 已提交
506
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
F
fengjiayi 已提交
507
  auto it = inputs_.find(name);
508
  PADDLE_ENFORCE_NE(
509 510 511 512
      it,
      inputs_.end(),
      platform::errors::NotFound(
          "Input %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
513 514 515
  return it->second;
}

516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538
// Like Input(name), but when `with_attr_var` is set an attribute named
// `name` that holds VarDesc(s) is treated as an input and its variable
// names are returned instead.
std::vector<std::string> OpDesc::Input(const std::string &name,
                                       bool with_attr_var) const {
  if (!with_attr_var) {
    return this->Input(name);
  }
  const auto attr_it = attrs_.find(name);
  if (attr_it == attrs_.end() || !HasAttrVar(attr_it->second)) {
    return this->Input(name);
  }
  return AttrVarNames(attr_it->second);
}

// Copy of the input map; with `with_attr_var`, each VarDesc-valued attribute
// is added (or overrides) as an extra input slot named after the attribute.
VariableNameMap OpDesc::Inputs(bool with_attr_var) const {
  VariableNameMap merged = inputs_;
  if (with_attr_var) {
    for (auto &var_attr : FilterAttrVar(attrs_)) {
      merged[var_attr.first] = AttrVarNames(var_attr.second);
    }
  }
  return merged;
}

std::vector<std::string> OpDesc::InputArgumentNames(bool with_attr_var) const {
F
Update  
fengjiayi 已提交
539
  std::vector<std::string> retv;
540
  for (auto &ipt : this->Inputs(with_attr_var)) {
F
Update  
fengjiayi 已提交
541 542 543 544 545
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
546 547
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
548 549 550 551
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
552
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
F
fengjiayi 已提交
553
  auto it = outputs_.find(name);
554
  PADDLE_ENFORCE_NE(
555 556 557 558
      it,
      outputs_.end(),
      platform::errors::NotFound(
          "Output %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
559 560 561
  return it->second;
}

562 563 564 565
// True when an output parameter called `name` exists (bound or not to any
// argument).
bool OpDesc::HasOutput(const std::string &name) const {
  return outputs_.count(name) != 0;
}

Y
Yu Yang 已提交
566
std::vector<std::string> OpDesc::OutputArgumentNames() const {
F
Update  
fengjiayi 已提交
567 568 569 570 571 572 573
  std::vector<std::string> retv;
  for (auto &ipt : this->outputs_) {
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
574 575
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
576 577 578 579
  need_update_ = true;
  this->outputs_[param_name] = args;
}

580 581 582 583 584
// Drops output parameter `name` (no-op if absent) and flags the op as
// modified.
void OpDesc::RemoveOutput(const std::string &name) {
  need_update_ = true;
  outputs_.erase(name);
}

585 586 587 588 589
// Drops input parameter `name` (no-op if absent) and flags the op as
// modified.
void OpDesc::RemoveInput(const std::string &name) {
  need_update_ = true;
  inputs_.erase(name);
}

590 591 592 593 594 595 596 597 598 599
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info = OpInfoMap::Instance();
  if (op_info.Has(desc_.type())) {
    auto op_info_ptr = op_info.Get(desc_.type());
    if (op_info_ptr.HasOpProtoAndChecker()) {
      const proto::OpProto &proto = op_info_ptr.Proto();
      for (int i = 0; i != proto.attrs_size(); ++i) {
        const proto::OpProto::Attr &attr = proto.attrs(i);
        if (attr.name() == name) {
          return true;
L
luotao1 已提交
600 601
        }
      }
L
luotao1 已提交
602 603 604 605 606
    }
  }
  return false;
}

607 608 609 610
// Maps the stored Attribute's variant alternative to its proto::AttrType.
// The "- 1" assumes alternative 0 of Attribute is the blank/empty state
// (GetNullableAttr returns a default Attribute()) -- keep in sync with the
// Attribute variant definition.
proto::AttrType OpDesc::GetAttrType(const std::string &name,
                                    bool with_attr_var) const {
  const Attribute value = this->GetAttr(name, with_attr_var);
  const auto alternative = value.index();
  return static_cast<proto::AttrType>(alternative - 1);
}

613
std::vector<std::string> OpDesc::AttrNames(bool with_attr_var) const {
F
fengjiayi 已提交
614 615 616
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
617
    if (!with_attr_var && HasAttrVar(attr.second)) continue;
F
fengjiayi 已提交
618 619 620 621 622
    retv.push_back(attr.first);
  }
  return retv;
}

623 624 625 626 627 628 629 630 631
// True when attribute `name` exists; VarDesc-valued attributes only count
// when `with_attr_var` is set.
bool OpDesc::HasAttr(const std::string &name, bool with_attr_var) const {
  const auto found = attrs_.find(name);
  if (found == attrs_.end()) {
    return false;
  }
  return with_attr_var || !HasAttrVar(found->second);
}

632 633 634 635 636
// Drops attribute `name` (no-op if absent) and flags the op as modified.
void OpDesc::RemoveAttr(const std::string &name) {
  need_update_ = true;
  attrs_.erase(name);
}

Y
Yu Yang 已提交
637
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
M
minqiyang 已提交
638 639 640
  // NOTICE(minqiyang): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
R
Ruibiao Chen 已提交
641
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
M
minqiyang 已提交
642
  if (attr_type == proto::AttrType::INTS &&
R
Ruibiao Chen 已提交
643
      PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) {
M
minqiyang 已提交
644
    // Find current attr via attr name and set the correct attribute value
M
minqiyang 已提交
645
    const proto::OpProto::Attr &attr = GetProtoAttr(name);
M
minqiyang 已提交
646 647
    switch (attr.type()) {
      case proto::AttrType::BOOLEANS: {
M
minqiyang 已提交
648 649
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
M
minqiyang 已提交
650 651 652 653
        this->attrs_[name] = std::vector<bool>();
        break;
      }
      case proto::AttrType::INTS: {
M
minqiyang 已提交
654 655
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
M
minqiyang 已提交
656 657 658
        this->attrs_[name] = std::vector<int>();
        break;
      }
659
      case proto::AttrType::LONGS: {
M
minqiyang 已提交
660 661
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from LONGS to LONGS";
662 663 664
        this->attrs_[name] = std::vector<int64_t>();
        break;
      }
M
minqiyang 已提交
665
      case proto::AttrType::FLOATS: {
M
minqiyang 已提交
666 667
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
M
minqiyang 已提交
668 669 670
        this->attrs_[name] = std::vector<float>();
        break;
      }
671 672 673 674 675 676
      case proto::AttrType::FLOAT64S: {
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOAT64S";
        this->attrs_[name] = std::vector<double>();
        break;
      }
M
minqiyang 已提交
677
      case proto::AttrType::STRINGS: {
M
minqiyang 已提交
678 679
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
M
minqiyang 已提交
680 681 682 683
        this->attrs_[name] = std::vector<std::string>();
        break;
      }
      case proto::AttrType::BLOCKS: {
M
minqiyang 已提交
684 685
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
M
minqiyang 已提交
686
        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
M
minqiyang 已提交
687 688
        return;
      }
M
minqiyang 已提交
689
      default:
690 691
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported attribute type (code %d).", attr.type()));
M
minqiyang 已提交
692
    }
M
minqiyang 已提交
693 694
    need_update_ = true;
    return;
M
minqiyang 已提交
695 696
  }

697 698 699
  // In order to set bool attr properly
  if (attr_type == proto::AttrType::INT && HasProtoAttr(name) &&
      GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
R
Ruibiao Chen 已提交
700
    this->attrs_[name] = static_cast<bool>(PADDLE_GET_CONST(int, v));
701 702 703 704
    need_update_ = true;
    return;
  }

F
fengjiayi 已提交
705 706 707 708
  this->attrs_[name] = v;
  need_update_ = true;
}

709 710 711 712 713 714 715 716 717 718
void OpDesc::SetVarAttr(const std::string &name, VarDesc *var) {
  this->attrs_[name] = var;
  need_update_ = true;
}

// Stores a list of VarDesc pointers as attribute `name` and flags the op as
// modified. `vars` is already a by-value copy, so it is moved into the map
// rather than copied a second time.
void OpDesc::SetVarsAttr(const std::string &name, std::vector<VarDesc *> vars) {
  this->attrs_[name] = std::move(vars);
  need_update_ = true;
}

A
Abhinav Arora 已提交
719 720
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
F
fengjiayi 已提交
721
  need_update_ = true;
F
fengjiayi 已提交
722 723
}

724 725 726 727 728 729
// Stores a list of sub-block pointers as attribute `name` and flags the op
// as modified. `blocks` is a by-value parameter, so it is moved into the map
// instead of being copied again.
void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = std::move(blocks);
  need_update_ = true;
}

Y
Yu Yang 已提交
730
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
731 732 733 734 735
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

736
// Returns attribute `name`. Raises NotFound when it is missing, or when it
// is VarDesc-valued and `with_attr_var` is false.
Attribute OpDesc::GetAttr(const std::string &name, bool with_attr_var) const {
  const auto found = attrs_.find(name);
  PADDLE_ENFORCE_NE(
      found,
      attrs_.end(),
      platform::errors::NotFound("Attribute %s is not found.", name));
  const Attribute &value = found->second;
  if (!with_attr_var) {
    PADDLE_ENFORCE_EQ(
        HasAttrVar(value),
        false,
        platform::errors::NotFound("Attribute %s is not found.", name));
  }
  return value;
}

M
minqiyang 已提交
751 752 753
// Returns the OpProto declaration of attribute `name` for this op's type;
// raises NotFound when the proto does not declare it.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  for (const proto::OpProto::Attr &declared : proto.attrs()) {
    if (declared.name() == name) {
      return declared;
    }
  }

  PADDLE_THROW(platform::errors::NotFound(
      "Attribute %s is not found in proto %s.", name, proto.type()));
}

Y
yuyang18 已提交
765
// Returns attribute `name`, or a default-constructed (empty) Attribute when
// it is not set. Never raises for a missing name.
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  const auto found = attrs_.find(name);
  if (found == attrs_.end()) {
    return Attribute();
  }
  return found->second;
}

G
gongweibao 已提交
774 775
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
776
  PADDLE_ENFORCE_NE(
777 778
      it,
      attrs_.end(),
779 780
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
R
Ruibiao Chen 已提交
781
  auto blocks = PADDLE_GET_CONST(std::vector<BlockDesc *>, it->second);
G
gongweibao 已提交
782 783 784 785 786 787 788 789 790 791

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
F
fengjiayi 已提交
792
  auto it = attrs_.find(name);
793
  PADDLE_ENFORCE_NE(
794 795
      it,
      attrs_.end(),
796 797
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
R
Ruibiao Chen 已提交
798
  return PADDLE_GET_CONST(BlockDesc *, it->second)->ID();
F
fengjiayi 已提交
799 800
}

Y
Yu Yang 已提交
801
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
802 803 804
  return attrs_;
}

Y
Yu Yang 已提交
805
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
Y
Yancey1989 已提交
806 807
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
F
fengjiayi 已提交
808 809 810
  need_update_ = true;
}

Y
Yu Yang 已提交
811 812
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
813
  for (auto &output : outputs_) {
814 815
    std::replace(
        output.second.begin(), output.second.end(), old_name, new_name);
Y
Yang Yang(Tony) 已提交
816
  }
Y
yuyang18 已提交
817 818 819

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
R
Ruibiao Chen 已提交
820
    auto &op_vars = PADDLE_GET(std::vector<std::string>, it->second);
Y
yuyang18 已提交
821 822 823
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
824 825 826
  need_update_ = true;
}

Y
Yu Yang 已提交
827 828
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
829 830 831
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
Y
Yancey1989 已提交
832 833 834

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
R
Ruibiao Chen 已提交
835
    auto &op_vars = PADDLE_GET(std::vector<std::string>, it->second);
Y
Yancey1989 已提交
836 837 838
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
839 840 841
  need_update_ = true;
}

842
struct SetAttrDescVisitor {
843 844
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  mutable proto::OpDesc::Attr *attr_;
Y
Yu Yang 已提交
845 846
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
847
  void operator()(double v) const { attr_->set_float64(v); }
Y
Yu Yang 已提交
848
  void operator()(const std::string &v) const { attr_->set_s(v); }
Q
QI JUN 已提交
849 850 851 852 853 854 855

  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }
Y
Yu Yang 已提交
856 857 858 859 860 861 862 863 864 865 866 867 868

  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
869 870 871 872 873 874 875 876 877 878 879 880 881

  void operator()(const std::vector<VarDesc *> &v) const {
    std::vector<std::string> var_names;
    for (auto var : v) {
      var_names.emplace_back(var->Name());
    }
    VectorToRepeated(var_names, attr_->mutable_vars_name());
  }

  void operator()(const VarDesc *desc) const {
    attr_->set_var_name(desc->Name());
  }

882 883 884
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
T
tangwei12 已提交
885
      blocks_idx.push_back(blk->ID());
886 887 888
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }
T
tangwei12 已提交
889 890 891

  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

892
  void operator()(int64_t v) const { attr_->set_l(v); }
T
tangwei12 已提交
893 894 895 896 897

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

898 899 900 901
  void operator()(const std::vector<double> &v) const {
    VectorToRepeated(v, attr_->mutable_float64s());
  }

R
Ruibiao Chen 已提交
902
  void operator()(paddle::blank) const {
903 904 905 906
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method of SetAttrDescVisitor object for "
        "`boosst::blank` type."));
  }
Y
Yu Yang 已提交
907 908
};

Y
Yu Yang 已提交
909
void OpDesc::Flush() {
L
Leo Chen 已提交
910 911
  VLOG(4) << "Flush "
          << " " << Type() << " " << need_update_;
F
fengjiayi 已提交
912
  if (need_update_) {
913
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
914
    for (auto &ipt : inputs_) {
915
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
916 917 918 919
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

920
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
921
    for (auto &opt : outputs_) {
922
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
923 924 925 926
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

927
    this->desc_.mutable_attrs()->Clear();
L
Leo Chen 已提交
928 929 930 931 932 933 934 935
    std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
                                                                attrs_.end()};
    std::sort(
        sorted_attrs.begin(),
        sorted_attrs.end(),
        [](std::pair<std::string, Attribute> a,
           std::pair<std::string, Attribute> b) { return a.first < b.first; });
    for (auto &attr : sorted_attrs) {
936
      auto *attr_desc = desc_.add_attrs();
F
fengjiayi 已提交
937 938
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
R
Ruibiao Chen 已提交
939
          static_cast<proto::AttrType>(attr.second.index() - 1));
Y
Yu Yang 已提交
940
      SetAttrDescVisitor visitor(attr_desc);
R
Ruibiao Chen 已提交
941
      paddle::visit(visitor, attr.second);
F
fengjiayi 已提交
942 943 944 945 946
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
947

Y
Yu Yang 已提交
948
void OpDesc::CheckAttrs() {
949 950
  PADDLE_ENFORCE_EQ(Type().empty(),
                    false,
951 952
                    platform::errors::PreconditionNotMet(
                        "CheckAttrs() can not be called before type is set."));
Y
Yu Yang 已提交
953 954 955 956 957 958
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
959
  VLOG(10) << "begin to check attribute of " << Type();
T
tangwei12 已提交
960
  checker->Check(&attrs_);
F
fengjiayi 已提交
961 962
}

H
hong 已提交
963
void OpDesc::InferShape(const BlockDesc &block) {
964 965
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
H
hong 已提交
966
    auto &op_info = OpInfoMap::Instance().Get(this->Type());
967
    this->CheckAttrs();
H
hong 已提交
968
    auto &infer_shape = op_info.infer_shape_;
969
    PADDLE_ENFORCE_EQ(
970 971
        static_cast<bool>(infer_shape),
        true,
972 973
        platform::errors::NotFound(
            "Operator %s's infer_shape is not registered.", this->Type()));
974 975 976 977 978
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
979 980
      std::copy(inames.begin(),
                inames.end(),
981 982 983
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
984 985
      std::copy(onames.begin(),
                onames.end(),
986 987 988 989 990
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
991
  } catch (platform::EnforceNotMet &exception) {
992
    framework::AppendErrorOpHint(Type(), &exception);
993 994 995 996
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
Y
Yu Yang 已提交
997 998
}

Y
Yu Yang 已提交
999
// Runs the var-type inference function registered for this op's type (if
// any) against `block`. When none is registered, this is a no-op.
void OpDesc::InferVarType(BlockDesc *block) const {
  // There are a few places where a var type can be set:
  //   - when a VarDesc is created, it defaults to LOD_TENSOR;
  //   - when an output variable is created, it also defaults to LOD_TENSOR.
  // This is the only place where an operator applies its customized var type
  // inference, so no "default" setting is done here.
  auto &info = OpInfoMap::Instance().Get(this->Type());
  if (!info.infer_var_type_) {
    return;
  }
  InferVarTypeContext context(this, block);
  info.infer_var_type_(&context);
}

1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025
// Returns this op's distributed attribute, constructing it from this op on
// first access. The object stays owned by dist_attr_; callers get a
// non-owning pointer.
OperatorDistAttr *OpDesc::MutableDistAttr() {
  if (!dist_attr_) {
    dist_attr_.reset(new OperatorDistAttr(*this));
  }
  return dist_attr_.get();
}

void OpDesc::SetDistAttr(const OperatorDistAttr &dist_attr) {
  MutableDistAttr();
  *dist_attr_ = dist_attr;
}

1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070
void OpDesc::UpdateVarAttr(const std::string &name, const Attribute &attr) {
  auto attr_type = static_cast<proto::AttrType>(attr.index() - 1);
  auto type = GetAttrType(name, true);
  if (type == proto::AttrType::VAR) {
    PADDLE_ENFORCE_EQ(
        attr_type,
        type,
        platform::errors::InvalidArgument(
            "Required attr.type == proto::AttrType::VAR, but received %s",
            attr_type));
    auto *var_desc = PADDLE_GET_CONST(VarDesc *, attr);
    VLOG(3) << "Update AttrVar " << name << " with " << var_desc->Name();
    attrs_[name] = FindVarRecursive(var_desc->Name());
  } else if (type == proto::AttrType::VARS) {
    PADDLE_ENFORCE_EQ(
        attr_type,
        type,
        platform::errors::InvalidArgument(
            "Required attr.type == proto::AttrType::VARS, but received %s",
            attr_type));
    auto vars_desc = PADDLE_GET_CONST(std::vector<VarDesc *>, attr);
    std::vector<VarDesc *> new_val;
    for (auto &var_desc : vars_desc) {
      VLOG(3) << "Update AttrVars " << name << " with " << var_desc->Name();
      new_val.emplace_back(FindVarRecursive(var_desc->Name()));
    }
    attrs_[name] = std::move(new_val);
  }
}

// Looks up variable `name` starting at this op's block and walking outward
// through ParentBlock(). The walk stops at a null parent or a block with a
// negative ID. Throws NotFound if no visited block declares `name`.
VarDesc *OpDesc::FindVarRecursive(const std::string &name) {
  auto *cur_block = block_;
  while (cur_block != nullptr && cur_block->ID() >= 0) {
    // Query the block currently being visited (not block_) so that parent
    // blocks are actually searched.
    auto *var = cur_block->FindVar(name);
    if (var != nullptr) {
      return var;
    }
    cur_block = cur_block->ParentBlock();
  }
  // block_ may be null (the loop guards against it), so do not dereference it
  // unconditionally while building the error message.
  PADDLE_THROW(platform::errors::NotFound(
      "Not found Var(%s) from Block(%d) back into global Block.",
      name,
      block_ == nullptr ? -1 : block_->ID()));
}

1071
// Creates a compile-time shape-inference context for `op`, resolving
// variables against `block`. NOTE(review): op_/block_ are presumably
// reference members, in which case `op` and `block` must outlive this
// context; confirm in the class declaration.
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op),
      block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
1076 1077
  auto inputs = op_.Inputs(/*with_attr_var=*/true);
  if (inputs.find(name) == inputs.end()) {
1078 1079
    return false;
  }
1080 1081
  const std::vector<std::string> &input_names =
      op_.Input(name, /*with_attr_var=*/true);
1082 1083 1084 1085
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
1086
  PADDLE_ENFORCE_EQ(
1087 1088
      length,
      1UL,
1089 1090
      platform::errors::InvalidArgument("Input(%s) should have only one value, "
                                        "but it has %d values now.",
1091 1092
                                        name,
                                        length));
1093 1094 1095 1096
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
1097 1098 1099
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
1100 1101 1102 1103 1104
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
1105 1106
  PADDLE_ENFORCE_EQ(length,
                    1UL,
1107 1108 1109
                    platform::errors::InvalidArgument(
                        "Output(%s) should have only one value, "
                        "but it has %d values now.",
1110 1111
                        name,
                        length));
1112 1113 1114
  return block_.HasVarRecursive(output_names[0]);
}

1115
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
1116
  return op_.HasAttr(name, /*with_attr_var=*/false);
1117 1118
}

1119
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
1120 1121
  auto inputs = op_.Inputs(/*with_attr_var=*/true);
  if (inputs.find(name) == inputs.end()) {
1122 1123
    return false;
  }
1124 1125
  const std::vector<std::string> &input_names =
      op_.Input(name, /*with_attr_var=*/true);
1126 1127 1128 1129 1130 1131 1132 1133 1134
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

1135 1136
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
                                              bool allow_null) const {
1137 1138 1139
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
1140 1141 1142 1143
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
1144 1145 1146 1147 1148 1149 1150 1151 1152 1153
  if (allow_null) {
    for (auto &output : output_names) {
      if (block_.HasVarRecursive(output)) return true;
    }
    return false;
  } else {
    for (auto &output : output_names) {
      if (!block_.HasVarRecursive(output)) return false;
    }
    return true;
1154 1155 1156 1157 1158 1159 1160
  }
}

// Wraps the op's full attribute map in an AttrReader for typed access.
AttrReader CompileTimeInferShapeContext::Attrs() const {
  return AttrReader(this->op_.GetAttrMap());
}

H
hong 已提交
1161
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
1162
    const std::string &name) const {
1163
  return op_.Input(name, /*with_attr_var=*/true);
1164 1165
}

H
hong 已提交
1166
std::vector<std::string> CompileTimeInferShapeContext::Outputs(
1167 1168 1169 1170
    const std::string &name) const {
  return op_.Output(name);
}

F
fengjiayi 已提交
1171
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
F
fengjiayi 已提交
1172 1173
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
1174 1175
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1176 1177 1178 1179
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
1180
      res.push_back(s.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(s));
F
fengjiayi 已提交
1181 1182
    }
  } catch (...) {
M
minqiyang 已提交
1183
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
F
fengjiayi 已提交
1184 1185 1186
    std::rethrow_exception(std::current_exception());
  }
  return res;
1187 1188 1189 1190
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
F
fengjiayi 已提交
1191
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
1192
}
F
fengjiayi 已提交
1193 1194 1195 1196

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
1197 1198
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1199
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
1200
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), phi::vectorize<>);
F
fengjiayi 已提交
1201
  var->SetShapes(dim_vec);
1202
}
F
fengjiayi 已提交
1203

1204 1205
// This context only ever performs compile-time inference.
bool CompileTimeInferShapeContext::IsRuntime() const {
  return false;
}

1206 1207
// Always false for this compile-time context.
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const {
  return false;
}

1208
// Returns the declared type of variable `name` (found through
// block_.FindVarRecursive). Throws NotFound when the variable is not visible,
// instead of dereferencing a null pointer.
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  return var->GetType();
}
1212

1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228
std::vector<std::string> AttrVarNames(const Attribute &attr) {
  std::vector<std::string> vars_name;
  if (IsAttrVar(attr)) {
    vars_name.emplace_back(PADDLE_GET_CONST(VarDesc *, attr)->Name());
  } else if (IsAttrVars(attr)) {
    for (auto &iter : PADDLE_GET_CONST(std::vector<VarDesc *>, attr)) {
      vars_name.emplace_back(iter->Name());
    }
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported Attribute value type `%s` for AttrVarNames",
        platform::demangle(attr.type().name())));
  }
  return vars_name;
}

F
fengjiayi 已提交
1229 1230
}  // namespace framework
}  // namespace paddle