op_desc.cc 34.8 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/op_desc.h"
16

17
#include <string>
18

R
Ruibiao Chen 已提交
19
#include "boost/blank.hpp"
20
#include "glog/logging.h"
Y
Yi Wang 已提交
21
#include "paddle/fluid/framework/block_desc.h"
22
#include "paddle/fluid/framework/op_call_stack.h"
Y
yuyang18 已提交
23
#include "paddle/fluid/framework/op_proto_maker.h"
Y
Yi Wang 已提交
24 25
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
M
minqiyang 已提交
26
#include "paddle/fluid/framework/var_type_inference.h"
Y
Yu Yang 已提交
27

F
fengjiayi 已提交
28 29 30
namespace paddle {
namespace framework {

31 32
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
33
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
34 35 36 37 38

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

39 40
  bool HasAttr(const std::string &name) const override;

41 42
  bool HasInputs(const std::string &name) const override;

43 44
  bool HasOutputs(const std::string &name,
                  bool allow_null = false) const override;
45 46 47

  AttrReader Attrs() const override;

H
hong 已提交
48
  std::vector<std::string> Inputs(const std::string &name) const override;
49

H
hong 已提交
50
  std::vector<std::string> Outputs(const std::string &name) const override;
51

52 53 54
  std::string GetInputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
55 56
    PADDLE_ENFORCE_LT(idx,
                      op_proto->inputs().size(),
57 58 59
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
60 61 62
                          op_.Type(),
                          idx,
                          op_proto->inputs().size()));
63 64 65 66 67 68 69
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
70 71
        idx,
        op_proto->outputs().size(),
72 73 74
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
75 76 77
            op_.Type(),
            idx,
            op_proto->outputs().size()));
78 79 80
    return op_proto->outputs()[idx].name();
  }

81 82 83
  void ShareDim(const std::string &in,
                const std::string &out,
                size_t i = 0,
84
                size_t j = 0) override {
85 86
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
87 88 89
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
90 91 92 93
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
94 95 96
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
97 98
                          Outputs(out).size(),
                          j));
99

H
hong 已提交
100 101
    std::string input_n = Inputs(in)[i];
    std::string output_n = Outputs(out)[j];
102

103 104
    PADDLE_ENFORCE_NE(input_n,
                      framework::kEmptyVarName,
105 106
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
107 108
    PADDLE_ENFORCE_NE(output_n,
                      framework::kEmptyVarName,
109 110
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
111 112 113 114

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

115
    PADDLE_ENFORCE_EQ(
116 117
        in_var->GetType(),
        out_var->GetType(),
118 119 120
        platform::errors::InvalidArgument(
            "The type of input %s and output %s do not match. The input type "
            "is %s, output type is %s.",
121 122 123
            input_n,
            output_n,
            DataTypeToString(in_var->GetType()),
124
            DataTypeToString(out_var->GetType())));
125 126 127 128

    SetDim(output_n, GetDim(input_n));
  }

H
hong 已提交
129 130 131 132 133 134
  void ShareAllLoD(const std::string &in,
                   const std::string &out) const override {
    auto &in_var_names = op_.Input(in);
    auto &out_var_names = op_.Output(out);

    PADDLE_ENFORCE_EQ(
135 136
        in_var_names.size(),
        out_var_names.size(),
H
hong 已提交
137
        platform::errors::PreconditionNotMet(
T
tianshuo78520a 已提交
138
            "Op [%s]:  Input var number should be equal with output var number",
H
hong 已提交
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156
            op_.Type()));

    for (size_t i = 0; i < in_var_names.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }

      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
        VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
        return;
      }
      out_var->SetLoDLevel(in_var->GetLoDLevel());
    }
  }

157 158 159
  void ShareLoD(const std::string &in,
                const std::string &out,
                size_t i = 0,
Q
Qiao Longfei 已提交
160
                size_t j = 0) const override {
161 162
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
163 164 165
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
166 167 168 169
                          Inputs(in).size(),
                          i));
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
170 171 172
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
173 174 175 176
                          Outputs(out).size(),
                          j));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
177 178
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
179 180
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
181 182
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
Q
Qiao Longfei 已提交
183 184
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
C
chengduo 已提交
185 186
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
187
      VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
X
fix  
Xin Pan 已提交
188 189
      return;
    }
190
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
191
  }
D
dzhwinter 已提交
192

193
  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
194 195
    PADDLE_ENFORCE_LT(i,
                      Inputs(in).size(),
196 197 198
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, input "
                          "variable %s of operator %s only has %d elements.",
199 200 201 202 203
                          in,
                          op_.Type(),
                          Inputs(in).size()));
    PADDLE_ENFORCE_NE(Inputs(in)[i],
                      framework::kEmptyVarName,
204 205
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] of operator %s is empty.",
206 207 208
                          in,
                          i,
                          op_.Type()));
C
chengduo 已提交
209
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
210
    PADDLE_ENFORCE_NOT_NULL(
211 212 213 214 215 216
        in_var,
        platform::errors::NotFound(
            "The input variable %s[%d] of operator %s is not found.",
            in,
            i,
            op_.Type()));
217
    return in_var->GetLoDLevel();
C
chengduo 已提交
218 219
  }

220 221
  void SetLoDLevel(const std::string &out,
                   int32_t lod_level,
222
                   size_t j = 0) const override {
223 224
    PADDLE_ENFORCE_LT(j,
                      Outputs(out).size(),
225 226 227
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, output "
                          "variable %s of operator %s only has %d elements.",
228 229 230 231 232
                          out,
                          op_.Type(),
                          Outputs(out).size()));
    PADDLE_ENFORCE_NE(Outputs(out)[j],
                      framework::kEmptyVarName,
233 234
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] of operator %s is empty.",
235 236 237
                          out,
                          j,
                          op_.Type()));
238
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
239
    PADDLE_ENFORCE_NOT_NULL(
240 241 242 243 244 245
        out_var,
        platform::errors::NotFound(
            "The output variable %s[%d] of operator %s is not found.",
            out,
            j,
            op_.Type()));
246 247 248
    if (lod_level >= 0) {
      out_var->SetLoDLevel(lod_level);
    }
249 250
  }

C
Chen Weihang 已提交
251
  paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize>
252
  GetInputVarPtrs(const std::string &name) const override {
253
    const std::vector<std::string> arg_names = Inputs(name);
C
Chen Weihang 已提交
254
    paddle::small_vector<InferShapeVarPtr, phi::kInputSmallVectorSize> res;
255
    res.reserve(arg_names.size());
256 257 258
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
259 260 261 262 263 264
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

C
Chen Weihang 已提交
265
  paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize>
266
  GetOutputVarPtrs(const std::string &name) const override {
267
    const std::vector<std::string> arg_names = Outputs(name);
C
Chen Weihang 已提交
268
    paddle::small_vector<InferShapeVarPtr, phi::kOutputSmallVectorSize> res;
269
    res.reserve(arg_names.size());
270 271 272
    std::transform(arg_names.begin(),
                   arg_names.end(),
                   std::back_inserter(res),
273 274 275 276 277 278
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

X
Xin Pan 已提交
279 280
  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
281 282
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
283 284 285
                      platform::errors::InvalidArgument(
                          "The input(%s) should hold only one element, but now "
                          "it holds %d elements.",
286 287
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
288 289 290 291 292 293 294 295
    return this->GetDim(arg_names[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    return GetDims(arg_names);
  }

296 297
  bool IsRuntime() const override;

298 299
  bool IsRunMKLDNNKernel() const override;

300 301 302 303
  proto::VarType::Type GetInputVarType(const std::string &name) const override {
    return GetVarType(Inputs(name).at(0));
  }

X
Xin Pan 已提交
304 305 306 307 308 309 310 311 312 313
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Inputs(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Outputs(name));
  }

X
Xin Pan 已提交
314
  void SetOutputDim(const std::string &name, const DDim &dim) override {
H
hong 已提交
315
    auto arg_names = Outputs(name);
316 317
    PADDLE_ENFORCE_EQ(arg_names.size(),
                      1UL,
318 319 320
                      platform::errors::InvalidArgument(
                          "The iutput(%s) should hold only one element, but "
                          "now it holds %d elements.",
321 322
                          name,
                          arg_names.size()));
X
Xin Pan 已提交
323 324 325 326 327
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
H
hong 已提交
328
    auto names = Outputs(name);
X
Xin Pan 已提交
329 330 331
    SetDims(names, dims);
  }

332 333 334 335 336 337 338 339
  const phi::ArgumentMappingFn *GetPhiArgumentMappingFn() const override {
    return phi::OpUtilsMap::Instance().GetArgumentMappingFn(op_.Type());
  }

  const phi::KernelSignature *GetPhiDefaultKernelSignature() const override {
    return &phi::DefaultKernelSignatureMap::Instance().Get(op_.Type());
  }

340
 protected:
X
Xin Pan 已提交
341 342 343 344 345
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
346 347 348 349 350
        names.begin(),
        names.end(),
        retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType),
                  this,
X
Xin Pan 已提交
351 352 353 354 355
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;
Q
Qiao Longfei 已提交
356

X
Xin Pan 已提交
357 358
  DDim GetDim(const std::string &name) const {
    auto var = block_.FindVarRecursive(name);
359 360
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable %s is not found.", name));
X
Xin Pan 已提交
361 362 363
    DDim res;
    try {
      auto shape = var->GetShape();
364
      res = shape.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(shape);
X
Xin Pan 已提交
365 366 367 368 369 370 371 372 373 374 375
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      std::rethrow_exception(std::current_exception());
    }
    return res;
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
376 377 378
        names.begin(),
        names.end(),
        std::back_inserter(ret),
X
Xin Pan 已提交
379 380 381
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }
382

X
Xin Pan 已提交
383 384 385 386 387
  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
388 389
    PADDLE_ENFORCE_EQ(length,
                      dims.size(),
390 391 392
                      platform::errors::InvalidArgument(
                          "The input variables number(%d) and input dimensions "
                          "number(%d) do not match.",
393 394
                          length,
                          dims.size()));
X
Xin Pan 已提交
395 396 397 398 399 400 401
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }
402

F
fengjiayi 已提交
403 404 405 406
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
407

Y
Yu Yang 已提交
408 409
  const OpDesc &op_;
  const BlockDesc &block_;
410 411
};

412 413 414 415
// Creates an operator description from its type, argument maps and
// attributes. Only the type is written into the protobuf message directly;
// the maps are kept in members and the descriptor is flagged as dirty
// (need_update_). The op is not attached to any block (block_ is null).
OpDesc::OpDesc(const std::string &type,
               const VariableNameMap &inputs,
               const VariableNameMap &outputs,
               const AttributeMap &attrs) {
  this->desc_.set_type(type);

  this->inputs_ = inputs;
  this->outputs_ = outputs;
  this->attrs_ = attrs;

  this->need_update_ = true;
  this->block_ = nullptr;
}

X
Xin Pan 已提交
424 425 426 427 428 429
// Copy-constructs from `other` via CopyFrom() and attaches the result to
// `block`. The copy is flagged as dirty.
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  this->CopyFrom(other);
  this->block_ = block;
  this->need_update_ = true;
}

430
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
431 432 433 434
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
435 436
  // The record of original_id_ is only for auto parallel.
  original_id_ = op_desc.original_id_;
F
fengjiayi 已提交
437 438 439
  need_update_ = true;
}

F
fengjiayi 已提交
440
// Reconstructs an OpDesc from its serialized protobuf form.
//
// The input/output argument maps and every attribute other than BLOCK/BLOCKS
// are restored from `desc`. The sub-blocks that BLOCK/BLOCKS attributes refer
// to have not been added to the ProgramDesc yet, so those attributes are
// skipped here. The op starts clean (need_update_ == false) and is owned by
// `block`.
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // Appends the arguments of every proto Var to name_map[var.parameter()].
  auto restore_arguments = [](const auto &proto_vars, auto &name_map) {
    for (const proto::OpDesc::Var &var : proto_vars) {
      std::vector<std::string> &args = name_map[var.parameter()];
      args.reserve(var.arguments_size());
      args.insert(args.end(), var.arguments().begin(), var.arguments().end());
    }
  };
  restore_arguments(desc_.inputs(), inputs_);
  restore_arguments(desc_.outputs(), outputs_);

  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    const bool refers_to_block = attr.type() == proto::AttrType::BLOCK ||
                                 attr.type() == proto::AttrType::BLOCKS;
    if (refers_to_block) {
      continue;
    }
    const std::string attr_name = attr.name();
    attrs_[attr_name] = GetAttrValue(attr);
  }

  this->block_ = block;
}

Y
Yu Yang 已提交
477
// Calls Flush() (presumably syncing the member maps into desc_ when
// need_update_ is set) and returns a pointer to the underlying message.
proto::OpDesc *OpDesc::Proto() {
  this->Flush();
  proto::OpDesc *message = &this->desc_;
  return message;
}

Y
Yu Yang 已提交
482
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
F
fengjiayi 已提交
483
  auto it = inputs_.find(name);
484
  PADDLE_ENFORCE_NE(
485 486 487 488
      it,
      inputs_.end(),
      platform::errors::NotFound(
          "Input %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
489 490 491
  return it->second;
}

Y
Yu Yang 已提交
492
// Concatenates the argument names of every input slot, in the iteration
// order of inputs_.
std::vector<std::string> OpDesc::InputArgumentNames() const {
  std::vector<std::string> all_names;
  for (const auto &slot : this->inputs_) {
    const std::vector<std::string> &slot_args = slot.second;
    all_names.insert(all_names.end(), slot_args.begin(), slot_args.end());
  }
  return all_names;
}

Y
Yu Yang 已提交
500 501
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
502 503 504 505
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
506
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
F
fengjiayi 已提交
507
  auto it = outputs_.find(name);
508
  PADDLE_ENFORCE_NE(
509 510 511 512
      it,
      outputs_.end(),
      platform::errors::NotFound(
          "Output %s cannot be found in operator %s.", name, Type()));
F
fengjiayi 已提交
513 514 515
  return it->second;
}

516 517 518 519
bool OpDesc::HasOutput(const std::string &name) const {
  return outputs_.find(name) != outputs_.end();
}

Y
Yu Yang 已提交
520
// Concatenates the argument names of every output slot, in the iteration
// order of outputs_.
std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> all_names;
  for (const auto &slot : this->outputs_) {
    const std::vector<std::string> &slot_args = slot.second;
    all_names.insert(all_names.end(), slot_args.begin(), slot_args.end());
  }
  return all_names;
}

Y
Yu Yang 已提交
528 529
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
530 531 532 533
  need_update_ = true;
  this->outputs_[param_name] = args;
}

534 535 536 537 538
void OpDesc::RemoveOutput(const std::string &name) {
  outputs_.erase(name);
  need_update_ = true;
}

539 540 541 542 543
void OpDesc::RemoveInput(const std::string &name) {
  inputs_.erase(name);
  need_update_ = true;
}

544 545 546 547 548 549 550 551 552 553
// Returns true if the registered OpProto for this op's type declares an
// attribute called `name`. Returns false when the type is not registered in
// OpInfoMap or its entry has no proto/checker.
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info = OpInfoMap::Instance();
  if (!op_info.Has(desc_.type())) {
    return false;
  }
  // Bind the registry entry by reference rather than copying the whole
  // OpInfo just to query it.
  const auto &info = op_info.Get(desc_.type());
  if (!info.HasOpProtoAndChecker()) {
    return false;
  }
  const proto::OpProto &proto = info.Proto();
  for (int i = 0; i != proto.attrs_size(); ++i) {
    if (proto.attrs(i).name() == name) {
      return true;
    }
  }
  return false;
}

Y
Yu Yang 已提交
561
// Returns the proto::AttrType of stored attribute `name`; throws NotFound if
// the attribute is not set.
proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE_NE(
      it,
      attrs_.end(),
      platform::errors::NotFound("Attribute %s is not found.", name));
  // The Attribute variant's alternative 0 is not a real attribute type
  // (presumably boost::blank, which SetAttrDescVisitor also handles), so
  // alternative N corresponds to AttrType N - 1.
  const auto alternative = it->second.index();
  return static_cast<proto::AttrType>(alternative - 1);
}

Y
Yu Yang 已提交
570
std::vector<std::string> OpDesc::AttrNames() const {
F
fengjiayi 已提交
571 572 573 574 575 576 577 578
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
    retv.push_back(attr.first);
  }
  return retv;
}

579 580 581 582 583
void OpDesc::RemoveAttr(const std::string &name) {
  attrs_.erase(name);
  need_update_ = true;
}

Y
Yu Yang 已提交
584
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
M
minqiyang 已提交
585 586 587
  // NOTICE(minqiyang): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
R
Ruibiao Chen 已提交
588
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
M
minqiyang 已提交
589
  if (attr_type == proto::AttrType::INTS &&
590
      BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
M
minqiyang 已提交
591
    // Find current attr via attr name and set the correct attribute value
M
minqiyang 已提交
592
    const proto::OpProto::Attr &attr = GetProtoAttr(name);
M
minqiyang 已提交
593 594
    switch (attr.type()) {
      case proto::AttrType::BOOLEANS: {
M
minqiyang 已提交
595 596
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
M
minqiyang 已提交
597 598 599 600
        this->attrs_[name] = std::vector<bool>();
        break;
      }
      case proto::AttrType::INTS: {
M
minqiyang 已提交
601 602
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
M
minqiyang 已提交
603 604 605
        this->attrs_[name] = std::vector<int>();
        break;
      }
606
      case proto::AttrType::LONGS: {
M
minqiyang 已提交
607 608
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from LONGS to LONGS";
609 610 611
        this->attrs_[name] = std::vector<int64_t>();
        break;
      }
M
minqiyang 已提交
612
      case proto::AttrType::FLOATS: {
M
minqiyang 已提交
613 614
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
M
minqiyang 已提交
615 616 617 618
        this->attrs_[name] = std::vector<float>();
        break;
      }
      case proto::AttrType::STRINGS: {
M
minqiyang 已提交
619 620
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
M
minqiyang 已提交
621 622 623 624
        this->attrs_[name] = std::vector<std::string>();
        break;
      }
      case proto::AttrType::BLOCKS: {
M
minqiyang 已提交
625 626
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
M
minqiyang 已提交
627
        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
M
minqiyang 已提交
628 629
        return;
      }
M
minqiyang 已提交
630
      default:
631 632
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported attribute type (code %d).", attr.type()));
M
minqiyang 已提交
633
    }
M
minqiyang 已提交
634 635
    need_update_ = true;
    return;
M
minqiyang 已提交
636 637
  }

638 639 640
  // In order to set bool attr properly
  if (attr_type == proto::AttrType::INT && HasProtoAttr(name) &&
      GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
641
    this->attrs_[name] = static_cast<bool>(BOOST_GET_CONST(int, v));
642 643 644 645
    need_update_ = true;
    return;
  }

F
fengjiayi 已提交
646 647 648 649
  this->attrs_[name] = v;
  need_update_ = true;
}

A
Abhinav Arora 已提交
650 651
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
F
fengjiayi 已提交
652
  need_update_ = true;
F
fengjiayi 已提交
653 654
}

655 656 657 658 659 660
void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = blocks;
  need_update_ = true;
}

Y
Yu Yang 已提交
661
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
662 663 664 665 666
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

Y
Yu Yang 已提交
667
// Returns a copy of attribute `name`; throws NotFound if it is not set.
// See GetNullableAttr for a non-throwing variant.
Attribute OpDesc::GetAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE_NE(
      it,
      attrs_.end(),
      platform::errors::NotFound("Attribute %s is not found.", name));
  const Attribute &value = it->second;
  return value;
}

M
minqiyang 已提交
676 677 678
// Looks up the declaration of attribute `name` in the registered OpProto of
// this op's type. Throws NotFound if the proto does not declare it.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  for (const proto::OpProto::Attr &declared : proto.attrs()) {
    if (declared.name() == name) {
      return declared;
    }
  }

  PADDLE_THROW(platform::errors::NotFound(
      "Attribute %s is not found in proto %s.", name, proto.type()));
}

Y
yuyang18 已提交
690
// Returns attribute `name`, or a default-constructed Attribute when it is
// not set (unlike GetAttr, this never throws NotFound).
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  if (it == attrs_.end()) {
    return Attribute();
  }
  return it->second;
}

G
gongweibao 已提交
699 700
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
701
  PADDLE_ENFORCE_NE(
702 703
      it,
      attrs_.end(),
704 705
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
706
  auto blocks = BOOST_GET_CONST(std::vector<BlockDesc *>, it->second);
G
gongweibao 已提交
707 708 709 710 711 712 713 714 715 716

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
F
fengjiayi 已提交
717
  auto it = attrs_.find(name);
718
  PADDLE_ENFORCE_NE(
719 720
      it,
      attrs_.end(),
721 722
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
723
  return BOOST_GET_CONST(BlockDesc *, it->second)->ID();
F
fengjiayi 已提交
724 725
}

Y
Yu Yang 已提交
726
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
727 728 729
  return attrs_;
}

Y
Yu Yang 已提交
730
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
Y
Yancey1989 已提交
731 732
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
F
fengjiayi 已提交
733 734 735
  need_update_ = true;
}

Y
Yu Yang 已提交
736 737
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
738
  for (auto &output : outputs_) {
739 740
    std::replace(
        output.second.begin(), output.second.end(), old_name, new_name);
Y
Yang Yang(Tony) 已提交
741
  }
Y
yuyang18 已提交
742 743 744

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
745
    auto &op_vars = BOOST_GET(std::vector<std::string>, it->second);
Y
yuyang18 已提交
746 747 748
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
749 750 751
  need_update_ = true;
}

Y
Yu Yang 已提交
752 753
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
754 755 756
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
Y
Yancey1989 已提交
757 758 759

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
760
    auto &op_vars = BOOST_GET(std::vector<std::string>, it->second);
Y
Yancey1989 已提交
761 762 763
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
764 765 766
  need_update_ = true;
}

Y
Yu Yang 已提交
767
// Visitor that writes one Attribute value into a proto::OpDesc::Attr.
// Each overload sets the proto field matching the C++ alternative currently
// held by the Attribute variant. The visitor does not own attr_.
struct SetAttrDescVisitor : public boost::static_visitor<void> {
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  mutable proto::OpDesc::Attr *attr_;

  // Scalar alternatives.
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
  void operator()(const std::string &v) const { attr_->set_s(v); }

  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  // The constrained template accepts exactly `bool`: other arithmetic types
  // cannot implicitly convert into this overload.
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }

  // Repeated alternatives are copied into the matching repeated proto field.
  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }

  // Sub-blocks are serialized by their block ID, not by content.
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
      blocks_idx.push_back(blk->ID());
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }

  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

  void operator()(int64_t v) const { attr_->set_l(v); }

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

  void operator()(const std::vector<double> &v) const {
    VectorToRepeated(v, attr_->mutable_float64s());
  }

  // An empty (blank) attribute has no proto representation.
  void operator()(boost::blank) const {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method of SetAttrDescVisitor object for "
        "`boost::blank` type."));
  }
};

Y
Yu Yang 已提交
820
void OpDesc::Flush() {
F
fengjiayi 已提交
821
  if (need_update_) {
822
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
823
    for (auto &ipt : inputs_) {
824
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
825 826 827 828
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

829
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
830
    for (auto &opt : outputs_) {
831
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
832 833 834 835
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

836
    this->desc_.mutable_attrs()->Clear();
F
fengjiayi 已提交
837
    for (auto &attr : attrs_) {
838
      auto *attr_desc = desc_.add_attrs();
F
fengjiayi 已提交
839 840
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
R
Ruibiao Chen 已提交
841
          static_cast<proto::AttrType>(attr.second.index() - 1));
Y
Yu Yang 已提交
842
      SetAttrDescVisitor visitor(attr_desc);
R
Ruibiao Chen 已提交
843
      paddle::visit(visitor, attr.second);
F
fengjiayi 已提交
844 845 846 847 848
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
849

Y
Yu Yang 已提交
850
void OpDesc::CheckAttrs() {
851 852
  PADDLE_ENFORCE_EQ(Type().empty(),
                    false,
853 854
                    platform::errors::PreconditionNotMet(
                        "CheckAttrs() can not be called before type is set."));
Y
Yu Yang 已提交
855 856 857 858 859 860
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
861
  VLOG(10) << "begin to check attribute of " << Type();
T
tangwei12 已提交
862
  checker->Check(&attrs_);
F
fengjiayi 已提交
863 864
}

H
hong 已提交
865
void OpDesc::InferShape(const BlockDesc &block) {
866 867
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
H
hong 已提交
868 869 870 871 872 873 874 875
    auto &op_info = OpInfoMap::Instance().Get(this->Type());
    auto *checker = op_info.Checker();
    if (checker != nullptr) {
      // set dafault value here
      VLOG(10) << "begin to check attribute of " << Type();
      checker->Check(&attrs_);
    }
    auto &infer_shape = op_info.infer_shape_;
876
    PADDLE_ENFORCE_EQ(
877 878
        static_cast<bool>(infer_shape),
        true,
879 880
        platform::errors::NotFound(
            "Operator %s's infer_shape is not registered.", this->Type()));
881 882 883 884 885
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
886 887
      std::copy(inames.begin(),
                inames.end(),
888 889 890
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
891 892
      std::copy(onames.begin(),
                onames.end(),
893 894 895 896 897
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
898
  } catch (platform::EnforceNotMet &exception) {
899
    framework::AppendErrorOpHint(Type(), &exception);
900 901 902 903
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
Y
Yu Yang 已提交
904 905
}

Y
Yu Yang 已提交
906
void OpDesc::InferVarType(BlockDesc *block) const {
  // Var types can be set in a few places: a VarDesc defaults to LOD_TENSOR
  // when it is created, and an output variable also defaults to LOD_TENSOR
  // when it is created. This function is meant to be the only place where an
  // operator applies its customized var-type inference, so it performs no
  // "default" typing of its own.
  const auto &info = OpInfoMap::Instance().Get(this->Type());
  if (!info.infer_var_type_) {
    return;  // this op registered no var-type inference
  }
  InferVarTypeContext ctx(this, block);
  info.infer_var_type_(&ctx);
}

919
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
Y
Yu Yang 已提交
920
    const OpDesc &op, const BlockDesc &block)
921 922 923
    : op_(op), block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
924 925 926
  if (op_.Inputs().find(name) == op_.Inputs().end()) {
    return false;
  }
927 928 929 930 931
  const std::vector<std::string> &input_names = op_.Input(name);
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
932
  PADDLE_ENFORCE_EQ(
933 934
      length,
      1UL,
935 936
      platform::errors::InvalidArgument("Input(%s) should have only one value, "
                                        "but it has %d values now.",
937 938
                                        name,
                                        length));
939 940 941 942
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
943 944 945
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
946 947 948 949 950
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
951 952
  PADDLE_ENFORCE_EQ(length,
                    1UL,
953 954 955
                    platform::errors::InvalidArgument(
                        "Output(%s) should have only one value, "
                        "but it has %d values now.",
956 957
                        name,
                        length));
958 959 960
  return block_.HasVarRecursive(output_names[0]);
}

961 962 963 964
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
  // Attribute presence is answered entirely by the wrapped OpDesc.
  const bool found = op_.HasAttr(name);
  return found;
}

965
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
966 967 968
  if (op_.Inputs().find(name) == op_.Inputs().end()) {
    return false;
  }
969 970 971 972 973 974 975 976 977 978
  const std::vector<std::string> &input_names = op_.Input(name);
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

979 980
bool CompileTimeInferShapeContext::HasOutputs(const std::string &name,
                                              bool allow_null) const {
981 982 983
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
984 985 986 987
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
988 989 990 991 992 993 994 995 996 997
  if (allow_null) {
    for (auto &output : output_names) {
      if (block_.HasVarRecursive(output)) return true;
    }
    return false;
  } else {
    for (auto &output : output_names) {
      if (!block_.HasVarRecursive(output)) return false;
    }
    return true;
998 999 1000 1001 1002 1003 1004
  }
}

AttrReader CompileTimeInferShapeContext::Attrs() const {
  // Wrap the op's attribute map in a reader.
  const auto &attr_map = op_.GetAttrMap();
  return AttrReader(attr_map);
}

H
hong 已提交
1005
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
1006 1007 1008 1009
    const std::string &name) const {
  return op_.Input(name);
}

H
hong 已提交
1010
std::vector<std::string> CompileTimeInferShapeContext::Outputs(
1011 1012 1013 1014
    const std::string &name) const {
  return op_.Output(name);
}

F
fengjiayi 已提交
1015
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
F
fengjiayi 已提交
1016 1017
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
1018 1019
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1020 1021 1022 1023
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
1024
      res.push_back(s.empty() ? phi::make_ddim({0UL}) : phi::make_ddim(s));
F
fengjiayi 已提交
1025 1026
    }
  } catch (...) {
M
minqiyang 已提交
1027
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
F
fengjiayi 已提交
1028 1029 1030
    std::rethrow_exception(std::current_exception());
  }
  return res;
1031 1032 1033 1034
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
F
fengjiayi 已提交
1035
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
1036
}
F
fengjiayi 已提交
1037 1038 1039 1040

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
1041 1042
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
1043
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
1044
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), phi::vectorize<>);
F
fengjiayi 已提交
1045
  var->SetShapes(dim_vec);
1046
}
F
fengjiayi 已提交
1047

1048 1049
bool CompileTimeInferShapeContext::IsRuntime() const {
  // This is the compile-time context, never the runtime one.
  return false;
}

1050 1051
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const {
  // Always reports false in the compile-time context.
  return false;
}

1052
// Returns the type of variable `name` (looked up with FindVarRecursive).
// Throws NotFound instead of dereferencing a null pointer when the variable
// does not exist, matching GetRepeatedDims/SetRepeatedDims.
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  return var->GetType();
}
1056

F
fengjiayi 已提交
1057 1058
}  // namespace framework
}  // namespace paddle