/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
16

17
#include <string>
18

19
#include "glog/logging.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/framework/block_desc.h"
21
#include "paddle/fluid/framework/op_call_stack.h"
Y
yuyang18 已提交
22
#include "paddle/fluid/framework/op_proto_maker.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
M
minqiyang 已提交
25
#include "paddle/fluid/framework/var_type_inference.h"
Y
Yu Yang 已提交
26

F
fengjiayi 已提交
27 28 29
namespace paddle {
namespace framework {

30 31
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
32
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
33 34 35 36 37

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

38 39
  bool HasAttr(const std::string &name) const override;

40 41
  bool HasInputs(const std::string &name) const override;

42 43
  bool HasOutputs(const std::string &name,
                  bool allow_null = false) const override;
44 45 46

  AttrReader Attrs() const override;

H
hong 已提交
47
  std::vector<std::string> Inputs(const std::string &name) const override;
48

H
hong 已提交
49
  std::vector<std::string> Outputs(const std::string &name) const override;
50

51 52 53
  std::string GetInputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
54 55
    PADDLE_ENFORCE_LT(idx,
                      op_proto->inputs().size(),
56 57 58
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
59 60 61
                          op_.Type(),
                          idx,
                          op_proto->inputs().size()));
62 63 64 65 66 67 68
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
69 70
        idx,
        op_proto->outputs().size(),
71 72 73
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
74 75 76
            op_.Type(),
            idx,
            op_proto->outputs().size()));
77 78 79
    return op_proto->outputs()[idx].name();
  }

80 81 82
  void ShareDim(const std::string &in,
                const std::string &out,
                size_t i = 0,
83
                size_t j = 0) override {
84 85
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
86 87 88
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
89 90 91 92
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
93 94 95
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
96 97
                          Outputs(out).size(),
                          j));
98

H
hong 已提交
99 100
    std::string input_n = Inputs(in)[i];
    std::string output_n = Outputs(out)[j];
101

102 103
    PADDLE_ENFORCE_NE(input_n,
                      framework::kEmptyVarName,
104 105
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
106 107
    PADDLE_ENFORCE_NE(output_n,
                      framework::kEmptyVarName,
108 109
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
110 111 112 113

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

114
    PADDLE_ENFORCE_EQ(
115 116
        in_var->GetType(),
        out_var->GetType(),
117 118 119
        platform::errors::InvalidArgument(
            "The type of input %s and output %s do not match. The input type "
            "is %s, output type is %s.",
120 121 122
            input_n,
            output_n,
            DataTypeToString(in_var->GetType()),
123
            DataTypeToString(out_var->GetType())));
124 125 126 127

    SetDim(output_n, GetDim(input_n));
  }

H
hong 已提交
128 129 130 131 132 133
  void ShareAllLoD(const std::string &in,
                   const std::string &out) const override {
    auto &in_var_names = op_.Input(in);
    auto &out_var_names = op_.Output(out);

    PADDLE_ENFORCE_EQ(
134 135
        in_var_names.size(),
        out_var_names.size(),
H
hong 已提交
136
        platform::errors::PreconditionNotMet(
T
tianshuo78520a 已提交
137
            "Op [%s]:  Input var number should be equal with output var number",
H
hong 已提交
138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155
            op_.Type()));

    for (size_t i = 0; i < in_var_names.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }

      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
        VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
        return;
      }
      out_var->SetLoDLevel(in_var->GetLoDLevel());
    }
  }

156 157 158
  void ShareLoD(const std::string &in,
                const std::string &out,
                size_t i = 0,
Q
Qiao Longfei 已提交
159
                size_t j = 0) const override {
160 161
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
162 163 164
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
165 166 167 168
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
169 170 171
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
172 173 174 175
                          Outputs(out).size(),
                          j));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
176 177
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
178 179
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
180 181
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
Q
Qiao Longfei 已提交
182 183
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
C
chengduo 已提交
184 185
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
186
      VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
X
fix  
Xin Pan 已提交
187 188
      return;
    }
189
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
190
  }
D
dzhwinter 已提交
191

192
  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
193 194
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
195 196 197
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, input "
                          "variable %s of operator %s only has %d elements.",
198 199 200 201 202
                          in,
                          op_.Type(),
                          Inputs(in).size()));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
203 204
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] of operator %s is empty.",
205 206 207
                          in,
                          i,
                          op_.Type()));
C
chengduo 已提交
208
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
209
    PADDLE_ENFORCE_NOT_NULL(
210 211 212 213 214 215
        in_var,
        platform::errors::NotFound(
            "The input variable %s[%d] of operator %s is not found.",
            in,
            i,
            op_.Type()));
216
    return in_var->GetLoDLevel();
C
chengduo 已提交
217 218
  }

219 220
  void SetLoDLevel(const std::string &out,
                   int32_t lod_level,
221
                   size_t j = 0) const override {
222 223
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
224 225 226
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, output "
                          "variable %s of operator %s only has %d elements.",
227 228 229 230 231
                          out,
                          op_.Type(),
                          Outputs(out).size()));
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
232 233
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] of operator %s is empty.",
234 235 236
                          out,
                          j,
                          op_.Type()));
237
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
238
    PADDLE_ENFORCE_NOT_NULL(
239 240 241 242 243 244
        out_var,
        platform::errors::NotFound(
            "The output variable %s[%d] of operator %s is not found.",
            out,
            j,
            op_.Type()));
245 246 247
    if (lod_level >= 0) {
      out_var->SetLoDLevel(lod_level);
    }
248 249
  }

C
Chen Weihang 已提交
250
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
251
  GetInputVarPtrs(const std::string &name) const override {
252
    const std::vector<std::string> arg_names = Inputs(name);
C
Chen Weihang 已提交
253
    paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
254
    res.reserve(arg_names.size());
255 256 257
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
258 259 260 261 262 263
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

C
Chen Weihang 已提交
264
  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
265
  GetOutputVarPtrs(const std::string &name) const override {
266
    const std::vector<std::string> arg_names = Outputs(name);
C
Chen Weihang 已提交
267
    paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
268
    res.reserve(arg_names.size());
269 270 271
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
272 273 274 275 276 277
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

X
Xin Pan 已提交
278 279
  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
280 281
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
282 283 284
                      platform::errors::InvalidArgument(
                          "The input(%s) should hold only one element, but now "
                          "it holds %d elements.",
285 286
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
287 288 289 290 291 292 293 294
    return this->GetDim(arg_names[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    return GetDims(arg_names);
  }

295 296
  bool IsRuntime() const override;

297 298
  bool IsRunMKLDNNKernel() const override;

299 300 301 302
  proto::VarType::Type GetInputVarType(const std::string &name) const override {
    return GetVarType(Inputs(name).at(0));
  }

X
Xin Pan 已提交
303 304 305 306 307 308 309 310 311 312
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Inputs(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Outputs(name));
  }

X
Xin Pan 已提交
313
  void SetOutputDim(const std::string &name, const DDim &dim) override {
H
hong 已提交
314
    auto arg_names = Outputs(name);
315 316
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
317 318 319
                      platform::errors::InvalidArgument(
                          "The iutput(%s) should hold only one element, but "
                          "now it holds %d elements.",
320 321
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
322 323 324 325 326
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
H
hong 已提交
327
    auto names = Outputs(name);
X
Xin Pan 已提交
328 329 330
    SetDims(names, dims);
  }

331 332 333 334 335 336 337 338
  const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
    return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
  }

  const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
    return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
  }

339
 protected:
X
Xin Pan 已提交
340 341 342 343 344
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
345 346 347 348 349
        names.begin(),
        names.end(),
        retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType),
                  this,
X
Xin Pan 已提交
350 351 352 353 354
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;
Q
Qiao Longfei 已提交
355

X
Xin Pan 已提交
356 357
  DDim GetDim(const std::string &name) const {
    auto var = block_.FindVarRecursive(name);
358 359
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable %s is not found.", name));
X
Xin Pan 已提交
360 361 362
    DDim res;
    try {
      auto shape = var->GetShape();
363
      res = shape.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(shape);
X
Xin Pan 已提交
364 365 366 367 368 369 370 371 372 373 374
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      std::rethrow_exception(std::current_exception());
    }
    return res;
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
375 376 377
        names.begin(),
        names.end(),
        std::back_inserter(ret),
X
Xin Pan 已提交
378 379 380
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }
381

X
Xin Pan 已提交
382 383 384 385 386
  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
387 388
    PADDLE_ENFORCE_EQ(length,
                      dims.size(),
389 390 391
                      platform::errors::InvalidArgument(
                          "The input variables number(%d) and input dimensions "
                          "number(%d) do not match.",
392 393
                          length,
                          dims.size()));
X
Xin Pan 已提交
394 395 396 397 398 399 400
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }
401

F
fengjiayi 已提交
402 403 404 405
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
406

Y
Yu Yang 已提交
407 408
  const OpDesc &op_;
  const BlockDesc &block_;
409 410
};

411 412 413 414
// Constructs an OpDesc directly from type, input/output maps and attributes.
// The descriptor is marked dirty so the protobuf form is rebuilt on Flush();
// it is not attached to any block yet.
OpDesc::OpDesc(const std::string &type,
               const VariableNameMap &inputs,
               const VariableNameMap &outputs,
               const AttributeMap &attrs) {
  desc_.set_type(type);
  inputs_ = inputs;
  outputs_ = outputs;
  attrs_ = attrs;
  need_update_ = true;
  block_ = nullptr;
}

X
Xin Pan 已提交
423 424 425 426 427 428
// Copy-constructs an OpDesc and re-parents it onto `block`.  All state is
// copied via CopyFrom(); only the owning block pointer differs, and the
// descriptor is marked dirty so its protobuf form is rebuilt later.
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  CopyFrom(other);
  block_ = block;
  need_update_ = true;
}

429
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
430 431 432 433
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
434 435
  // The record of original_id_ is only for auto parallel.
  original_id_ = op_desc.original_id_;
F
fengjiayi 已提交
436 437 438
  need_update_ = true;
}

F
fengjiayi 已提交
439
// Restores an OpDesc from its serialized protobuf form and attaches it to
// `block`.  Inputs, outputs and attributes are materialized into the in-memory
// maps; need_update_ stays false because desc_ already matches them.
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // restore inputs_
  int input_size = desc_.inputs_size();
  for (int i = 0; i < input_size; ++i) {
    const proto::OpDesc::Var &var = desc_.inputs(i);
    std::vector<std::string> &args = inputs_[var.parameter()];
    int argu_size = var.arguments_size();
    args.reserve(argu_size);
    for (int j = 0; j < argu_size; ++j) {
      args.push_back(var.arguments(j));
    }
  }
  // restore outputs_
  int output_size = desc_.outputs_size();
  for (int i = 0; i < output_size; ++i) {
    const proto::OpDesc::Var &var = desc_.outputs(i);
    std::vector<std::string> &args = outputs_[var.parameter()];
    int argu_size = var.arguments_size();
    args.reserve(argu_size);
    for (int j = 0; j < argu_size; ++j) {
      args.push_back(var.arguments(j));
    }
  }
  // restore attrs_
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    std::string attr_name = attr.name();
    // The sub_block referred to by the BLOCK attr hasn't been added
    // to ProgramDesc class yet, we skip setting BLOCK/BLOCKS attr here.
    if (attr.type() != proto::AttrType::BLOCK &&
        attr.type() != proto::AttrType::BLOCKS) {
      attrs_[attr_name] = GetAttrValue(attr);
    }
  }
  this->block_ = block;
}

Y
Yu Yang 已提交
476
// Returns the protobuf representation, flushing any pending in-memory
// changes into desc_ first so the result is always up to date.
proto::OpDesc *OpDesc::Proto() {
  Flush();
  return &desc_;
}

Y
Yu Yang 已提交
481
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
F
fengjiayi 已提交
482
  auto it = inputs_.find(name);
483
  PADDLE_ENFORCE_NE(
484 485 486 487
      it,
      inputs_.end(),
      platform::errors::NotFound(
          "Input %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
488 489 490
  return it->second;
}

Y
Yu Yang 已提交
491
std::vector<std::string> OpDesc::InputArgumentNames() const {
F
Update  
fengjiayi 已提交
492 493 494 495 496 497 498
  std::vector<std::string> retv;
  for (auto &ipt : this->inputs_) {
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
499 500
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
501 502 503 504
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
505
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
F
fengjiayi 已提交
506
  auto it = outputs_.find(name);
507
  PADDLE_ENFORCE_NE(
508 509 510 511
      it,
      outputs_.end(),
      platform::errors::NotFound(
          "Output %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
512 513 514
  return it->second;
}

515 516 517 518
// Reports whether the operator declares an output parameter named `name`.
bool OpDesc::HasOutput(const std::string &name) const {
  return outputs_.count(name) > 0;
}

Y
Yu Yang 已提交
519
std::vector<std::string> OpDesc::OutputArgumentNames() const {
F
Update  
fengjiayi 已提交
520 521 522 523 524 525 526
  std::vector<std::string> retv;
  for (auto &ipt : this->outputs_) {
    retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
  }
  return retv;
}

Y
Yu Yang 已提交
527 528
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
529 530 531 532
  need_update_ = true;
  this->outputs_[param_name] = args;
}

533 534 535 536 537
// Drops the output parameter `name` (no-op if absent) and marks the
// descriptor dirty so the protobuf form is rebuilt.
void OpDesc::RemoveOutput(const std::string &name) {
  need_update_ = true;
  outputs_.erase(name);
}

538 539 540 541 542
// Drops the input parameter `name` (no-op if absent) and marks the
// descriptor dirty so the protobuf form is rebuilt.
void OpDesc::RemoveInput(const std::string &name) {
  need_update_ = true;
  inputs_.erase(name);
}

543 544 545 546 547 548 549 550 551 552
// Checks whether the registered proto of this op type declares an attribute
// named `name`.  Returns false when the op type is unregistered or has no
// proto/checker at all.
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info = OpInfoMap::Instance();
  if (op_info.Has(desc_.type())) {
    auto op_info_ptr = op_info.Get(desc_.type());
    if (op_info_ptr.HasOpProtoAndChecker()) {
      const proto::OpProto &proto = op_info_ptr.Proto();
      for (int i = 0; i != proto.attrs_size(); ++i) {
        const proto::OpProto::Attr &attr = proto.attrs(i);
        if (attr.name() == name) {
          return true;
        }
      }
    }
  }
  return false;
}

Y
Yu Yang 已提交
560
proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
F
fengjiayi 已提交
561
  auto it = attrs_.find(name);
562
  PADDLE_ENFORCE_NE(
563 564
      it,
      attrs_.end(),
565
      platform::errors::NotFound("Attribute %s is not found.", name));
566
  return static_cast<proto::AttrType>(it->second.which() - 1);
F
fengjiayi 已提交
567 568
}

Y
Yu Yang 已提交
569
std::vector<std::string> OpDesc::AttrNames() const {
F
fengjiayi 已提交
570 571 572 573 574 575 576 577
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
    retv.push_back(attr.first);
  }
  return retv;
}

578 579 580 581 582
// Removes attribute `name` (no-op if absent) and marks the descriptor dirty.
void OpDesc::RemoveAttr(const std::string &name) {
  need_update_ = true;
  attrs_.erase(name);
}

Y
Yu Yang 已提交
583
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
M
minqiyang 已提交
584 585 586 587 588
  // NOTICE(minqiyang): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
  if (attr_type == proto::AttrType::INTS &&
589
      BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
M
minqiyang 已提交
590
    // Find current attr via attr name and set the correct attribute value
M
minqiyang 已提交
591
    const proto::OpProto::Attr &attr = GetProtoAttr(name);
M
minqiyang 已提交
592 593
    switch (attr.type()) {
      case proto::AttrType::BOOLEANS: {
M
minqiyang 已提交
594 595
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
M
minqiyang 已提交
596 597 598 599
        this->attrs_[name] = std::vector<bool>();
        break;
      }
      case proto::AttrType::INTS: {
M
minqiyang 已提交
600 601
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
M
minqiyang 已提交
602 603 604
        this->attrs_[name] = std::vector<int>();
        break;
      }
605
      case proto::AttrType::LONGS: {
M
minqiyang 已提交
606 607
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from LONGS to LONGS";
608 609 610
        this->attrs_[name] = std::vector<int64_t>();
        break;
      }
M
minqiyang 已提交
611
      case proto::AttrType::FLOATS: {
M
minqiyang 已提交
612 613
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
M
minqiyang 已提交
614 615 616 617
        this->attrs_[name] = std::vector<float>();
        break;
      }
      case proto::AttrType::STRINGS: {
M
minqiyang 已提交
618 619
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
M
minqiyang 已提交
620 621 622 623
        this->attrs_[name] = std::vector<std::string>();
        break;
      }
      case proto::AttrType::BLOCKS: {
M
minqiyang 已提交
624 625
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
M
minqiyang 已提交
626
        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
M
minqiyang 已提交
627 628
        return;
      }
M
minqiyang 已提交
629
      default:
630 631
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported attribute type (code %d).", attr.type()));
M
minqiyang 已提交
632
    }
M
minqiyang 已提交
633 634
    need_update_ = true;
    return;
M
minqiyang 已提交
635 636
  }

637 638 639
  // In order to set bool attr properly
  if (attr_type == proto::AttrType::INT && HasProtoAttr(name) &&
      GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
640
    this->attrs_[name] = static_cast<bool>(BOOST_GET_CONST(int, v));
641 642 643 644
    need_update_ = true;
    return;
  }

F
fengjiayi 已提交
645 646 647 648
  this->attrs_[name] = v;
  need_update_ = true;
}

A
Abhinav Arora 已提交
649 650
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
F
fengjiayi 已提交
651
  need_update_ = true;
F
fengjiayi 已提交
652 653
}

654 655 656 657 658 659
void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = blocks;
  need_update_ = true;
}

Y
Yu Yang 已提交
660
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
661 662 663 664 665
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

Y
Yu Yang 已提交
666
// Returns the value of attribute `name`; throws NotFound if it is absent.
Attribute OpDesc::GetAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE_NE(
      it,
      attrs_.end(),
      platform::errors::NotFound("Attribute %s is not found.", name));
  return it->second;
}

M
minqiyang 已提交
675 676 677
// Looks up the attribute declaration named `name` in this op type's
// registered proto; throws NotFound if the proto has no such attribute.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  for (int i = 0; i != proto.attrs_size(); ++i) {
    const proto::OpProto::Attr &attr = proto.attrs(i);
    if (attr.name() == name) {
      return attr;
    }
  }

  PADDLE_THROW(platform::errors::NotFound(
      "Attribute %s is not found in proto %s.", name, proto.type()));
}

Y
yuyang18 已提交
689
// Like GetAttr, but returns a default-constructed (blank) Attribute instead
// of throwing when `name` is absent.
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  if (it != attrs_.end()) {
    return it->second;
  } else {
    return Attribute();
  }
}

G
gongweibao 已提交
698 699
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
700
  PADDLE_ENFORCE_NE(
701 702
      it,
      attrs_.end(),
703 704
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
705
  auto blocks = BOOST_GET_CONST(std::vector<BlockDesc *>, it->second);
G
gongweibao 已提交
706 707 708 709 710 711 712 713 714 715

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
F
fengjiayi 已提交
716
  auto it = attrs_.find(name);
717
  PADDLE_ENFORCE_NE(
718 719
      it,
      attrs_.end(),
720 721
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
722
  return BOOST_GET_CONST(BlockDesc *, it->second)->ID();
F
fengjiayi 已提交
723 724
}

Y
Yu Yang 已提交
725
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
726 727 728
  return attrs_;
}

Y
Yu Yang 已提交
729
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
Y
Yancey1989 已提交
730 731
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
F
fengjiayi 已提交
732 733 734
  need_update_ = true;
}

Y
Yu Yang 已提交
735 736
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
737
  for (auto &output : outputs_) {
738 739
    std::replace(
        output.second.begin(), output.second.end(), old_name, new_name);
Y
Yang Yang(Tony) 已提交
740
  }
Y
yuyang18 已提交
741 742 743

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
744
    auto &op_vars = BOOST_GET(std::vector<std::string>, it->second);
Y
yuyang18 已提交
745 746 747
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
748 749 750
  need_update_ = true;
}

Y
Yu Yang 已提交
751 752
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
753 754 755
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
Y
Yancey1989 已提交
756 757 758

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
759
    auto &op_vars = BOOST_GET(std::vector<std::string>, it->second);
Y
Yancey1989 已提交
760 761 762
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
763 764 765
  need_update_ = true;
}

Y
Yu Yang 已提交
766
// Visitor that serializes one framework::Attribute value into the matching
// field of a proto::OpDesc::Attr message. One operator() overload exists per
// payload type held by the Attribute variant; it is applied via
// boost::apply_visitor (see OpDesc::Flush).
struct SetAttrDescVisitor : public boost::static_visitor<void> {
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  // `mutable` lets the const operator() overloads write through the pointer.
  mutable proto::OpDesc::Attr *attr_;
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
  void operator()(const std::string &v) const { attr_->set_s(v); }

  // Restrict the bool overload with enable_if so pointer arguments are not
  // silently converted to bool and serialized as a flag.
  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }

  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
  // Blocks are serialized by index, not by value.
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
      blocks_idx.push_back(blk->ID());
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }

  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

  void operator()(int64_t v) const { attr_->set_l(v); }

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

  void operator()(const std::vector<double> &v) const {
    VectorToRepeated(v, attr_->mutable_float64s());
  }

  // boost::blank marks an unset attribute; serializing it is a logic error.
  // Fix: error message previously misspelled the type as "boosst::blank".
  void operator()(boost::blank) const {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method of SetAttrDescVisitor object for "
        "`boost::blank` type."));
  }
};

Y
Yu Yang 已提交
819
void OpDesc::Flush() {
F
fengjiayi 已提交
820
  if (need_update_) {
821
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
822
    for (auto &ipt : inputs_) {
823
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
824 825 826 827
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

828
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
829
    for (auto &opt : outputs_) {
830
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
831 832 833 834
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

835
    this->desc_.mutable_attrs()->Clear();
F
fengjiayi 已提交
836
    for (auto &attr : attrs_) {
837
      auto *attr_desc = desc_.add_attrs();
F
fengjiayi 已提交
838 839
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
840
          static_cast<proto::AttrType>(attr.second.which() - 1));
Y
Yu Yang 已提交
841 842
      SetAttrDescVisitor visitor(attr_desc);
      boost::apply_visitor(visitor, attr.second);
F
fengjiayi 已提交
843 844 845 846 847
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
848

Y
Yu Yang 已提交
849
// Validates attrs_ against the operator's registered AttrChecker (a
// checker's Check() is also what fills in attribute defaults — see
// InferShape). Requires the op type to be set first.
void OpDesc::CheckAttrs() {
  PADDLE_ENFORCE_EQ(Type().empty(),
                    false,
                    platform::errors::PreconditionNotMet(
                        "CheckAttrs() can not be called before type is set."));
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
  VLOG(10) << "begin to check attribute of " << Type();
  checker->Check(&attrs_);
}

H
hong 已提交
864
// Runs compile-time shape inference for this op against `block`:
// 1) checks/defaults attributes via the registered checker (if any),
// 2) invokes the registered infer_shape function with a
//    CompileTimeInferShapeContext.
// An EnforceNotMet raised during inference is annotated with this op's
// call-stack hint before being rethrown.
void OpDesc::InferShape(const BlockDesc &block) {
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
    auto &op_info = OpInfoMap::Instance().Get(this->Type());
    auto *checker = op_info.Checker();
    if (checker != nullptr) {
      // set default value here
      VLOG(10) << "begin to check attribute of " << Type();
      checker->Check(&attrs_);
    }
    auto &infer_shape = op_info.infer_shape_;
    PADDLE_ENFORCE_EQ(
        static_cast<bool>(infer_shape),
        true,
        platform::errors::NotFound(
            "Operator %s's infer_shape is not registered.", this->Type()));
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      // Debug aid: log the op's input/output argument names.
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
      std::copy(inames.begin(),
                inames.end(),
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
      std::copy(onames.begin(),
                onames.end(),
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
  } catch (platform::EnforceNotMet &exception) {
    // Attach this op's type/location hint so the error points at the op.
    framework::AppendErrorOpHint(Type(), &exception);
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
}

Y
Yu Yang 已提交
905
// Runs the op's registered var-type inference (if any) on `block`,
// letting the operator customize the types of its output variables.
void OpDesc::InferVarType(BlockDesc *block) const {
  // There are a few places that var type can be set.
  // When VarDesc is created, default set to LOD_TENSOR.
  // When output variable is created, default is default set to LOD_TENSOR.
  // We limit here to be the only place that operator defines its customized
  // var type inference. Hence, we don't do any "default" setting here.
  auto &info = OpInfoMap::Instance().Get(this->Type());
  if (info.infer_var_type_) {
    InferVarTypeContext context(this, block);
    info.infer_var_type_(&context);
  }
}

918
// Binds the context to the op being inferred and the block whose variables
// (looked up recursively through parent blocks) provide shapes and types.
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op), block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
923 924 925
  if (op_.Inputs().find(name) == op_.Inputs().end()) {
    return false;
  }
926 927 928 929 930
  const std::vector<std::string> &input_names = op_.Input(name);
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
931
  PADDLE_ENFORCE_EQ(
932 933
      length,
      1UL,
934 935
      platform::errors::InvalidArgument("Input(%s) should have only one value, "
                                        "but it has %d values now.",
936 937
                                        name,
                                        length));
938 939 940 941
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
942 943 944
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
945 946 947 948 949
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
950 951
  PADDLE_ENFORCE_EQ(length,
                    1UL,
952 953 954
                    platform::errors::InvalidArgument(
                        "Output(%s) should have only one value, "
                        "but it has %d values now.",
955 956
                        name,
                        length));
957 958 959
  return block_.HasVarRecursive(output_names[0]);
}

960 961 962 963
// True iff the op carries an attribute named `name`.
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
  return op_.HasAttr(name);
}

964
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
965 966 967
  if (op_.Inputs().find(name) == op_.Inputs().end()) {
    return false;
  }
968 969 970 971 972 973 974 975 976 977
  const std::vector<std::string> &input_names = op_.Input(name);
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

978 979
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
                                              bool allow_null) const {
980 981 982
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
983 984 985 986
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
987 988 989 990 991 992 993 994 995 996
  if (allow_null) {
    for (auto &output : output_names) {
      if (block_.HasVarRecursive(output)) return true;
    }
    return false;
  } else {
    for (auto &output : output_names) {
      if (!block_.HasVarRecursive(output)) return false;
    }
    return true;
997 998 999 1000 1001 1002 1003
  }
}

// Exposes the op's attribute map through a read-only AttrReader.
AttrReader CompileTimeInferShapeContext::Attrs() const {
  return AttrReader(op_.GetAttrMap());
}

H
hong 已提交
1004
// Returns the argument names bound to input slot `name`.
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
    const std::string &name) const {
  return op_.Input(name);
}

H
hong 已提交
1009
// Returns the argument names bound to output slot `name`.
std::vector<std::string> CompileTimeInferShapeContext::Outputs(
    const std::string &name) const {
  return op_.Output(name);
}

F
fengjiayi 已提交
1014
// Returns all shapes stored on the multi-shape variable `name`, looked up
// recursively through parent blocks. An empty stored shape is mapped to
// DDim {0}. Throws NotFound if the variable does not exist; exceptions from
// GetShapes() are logged and rethrown.
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
      res.push_back(s.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(s));
    }
  } catch (...) {
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
    std::rethrow_exception(std::current_exception());
  }
  return res;
}
}

// Writes `dim` as the shape of variable `name` (resolved recursively
// through parent blocks).
void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
}
F
fengjiayi 已提交
1036 1037 1038 1039

// Writes a list of shapes onto the multi-shape variable `name`.
// Throws NotFound if the variable does not exist.
void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
  // Convert each DDim to a plain int64 vector before storing.
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), phi::vectorize<>);
  var->SetShapes(dim_vec);
}
F
fengjiayi 已提交
1046

1047 1048
// This context performs compile-time (static) inference only.
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

1049 1050
// Kernel selection is a runtime concern; never an MKLDNN kernel here.
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const { return false; }

1051
// Returns the declared type of variable `name` (resolved recursively
// through parent blocks).
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  return block_.FindVarRecursive(name)->GetType();
}
1055

F
fengjiayi 已提交
1056 1057
}  // namespace framework
}  // namespace paddle