op_desc.cc 32.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
F
fengjiayi 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/op_desc.h"
16

17
#include <string>
18

19
#include "glog/logging.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/framework/block_desc.h"
21
#include "paddle/fluid/framework/op_call_stack.h"
Y
yuyang18 已提交
22
#include "paddle/fluid/framework/op_proto_maker.h"
Y
Yi Wang 已提交
23 24
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
M
minqiyang 已提交
25
#include "paddle/fluid/framework/var_type_inference.h"
Y
Yu Yang 已提交
26

F
fengjiayi 已提交
27 28 29
namespace paddle {
namespace framework {

30 31
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
32
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
33 34 35 36 37

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

38 39
  bool HasAttr(const std::string &name) const override;

40 41 42 43 44 45
  bool HasInputs(const std::string &name) const override;

  bool HasOutputs(const std::string &name) const override;

  AttrReader Attrs() const override;

H
hong 已提交
46
  std::vector<std::string> Inputs(const std::string &name) const override;
47

H
hong 已提交
48
  std::vector<std::string> Outputs(const std::string &name) const override;
49

50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
  std::string GetInputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
                          op_.Type(), idx, op_proto->inputs().size()));
    return op_proto->inputs()[idx].name();
  }

  std::string GetOutputNameByIdx(size_t idx) const override {
    auto &op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
        idx, op_proto->outputs().size(),
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
            op_.Type(), idx, op_proto->outputs().size()));
    return op_proto->outputs()[idx].name();
  }

73 74
  void ShareDim(const std::string &in, const std::string &out, size_t i = 0,
                size_t j = 0) override {
75 76 77 78 79 80 81 82 83 84 85
    PADDLE_ENFORCE_LT(i, Inputs(in).size(),
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
                          Inputs(in).size(), i));
    PADDLE_ENFORCE_LT(j, Outputs(out).size(),
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
                          Outputs(out).size(), j));

H
hong 已提交
86 87
    std::string input_n = Inputs(in)[i];
    std::string output_n = Outputs(out)[j];
88

89 90 91 92 93 94
    PADDLE_ENFORCE_NE(input_n, framework::kEmptyVarName,
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
    PADDLE_ENFORCE_NE(output_n, framework::kEmptyVarName,
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
95 96 97 98

    auto *in_var = block_.FindVarRecursive(input_n);
    auto *out_var = block_.FindVarRecursive(output_n);

99 100 101 102 103 104 105
    PADDLE_ENFORCE_EQ(
        in_var->GetType(), out_var->GetType(),
        platform::errors::InvalidArgument(
            "The type of input %s and output %s do not match. The input type "
            "is %s, output type is %s.",
            input_n, output_n, DataTypeToString(in_var->GetType()),
            DataTypeToString(out_var->GetType())));
106 107 108 109

    SetDim(output_n, GetDim(input_n));
  }

H
hong 已提交
110 111 112 113 114 115 116 117
  void ShareAllLoD(const std::string &in,
                   const std::string &out) const override {
    auto &in_var_names = op_.Input(in);
    auto &out_var_names = op_.Output(out);

    PADDLE_ENFORCE_EQ(
        in_var_names.size(), out_var_names.size(),
        platform::errors::PreconditionNotMet(
T
tianshuo78520a 已提交
118
            "Op [%s]:  Input var number should be equal with output var number",
H
hong 已提交
119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
            op_.Type()));

    for (size_t i = 0; i < in_var_names.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }

      auto *in_var = block_.FindVarRecursive(in_var_names[i]);
      auto *out_var = block_.FindVarRecursive(out_var_names[i]);
      if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
          in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
        VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
        return;
      }
      out_var->SetLoDLevel(in_var->GetLoDLevel());
    }
  }

Q
Qiao Longfei 已提交
137 138
  void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                size_t j = 0) const override {
139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
    PADDLE_ENFORCE_LT(i, Inputs(in).size(),
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
                          Inputs(in).size(), i));
    PADDLE_ENFORCE_LT(j, Outputs(out).size(),
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, expected "
                          "index less than %d, but received index is %d.",
                          Outputs(out).size(), j));
    PADDLE_ENFORCE_NE(Inputs(in)[i], framework::kEmptyVarName,
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] is empty.", in, i));
    PADDLE_ENFORCE_NE(Outputs(out)[j], framework::kEmptyVarName,
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] is empty.", out, j));
Q
Qiao Longfei 已提交
155 156
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
C
chengduo 已提交
157 158
    if (in_var->GetType() != proto::VarType::LOD_TENSOR &&
        in_var->GetType() != proto::VarType::LOD_TENSOR_ARRAY) {
159
      VLOG(3) << "input " << in << " is not LoDTensor or LoDTensorArray.";
X
fix  
Xin Pan 已提交
160 161
      return;
    }
162
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
163
  }
D
dzhwinter 已提交
164

165 166
  int32_t GetLoDLevel(const std::string &in, size_t i = 0) const override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size(),
167 168 169 170
                      platform::errors::InvalidArgument(
                          "The input variable index is out of range, input "
                          "variable %s of operator %s only has %d elements.",
                          in, op_.Type(), Inputs(in).size()));
171
    PADDLE_ENFORCE_NE(Inputs(in)[i], framework::kEmptyVarName,
172 173 174
                      platform::errors::InvalidArgument(
                          "The input variable %s[%d] of operator %s is empty.",
                          in, i, op_.Type()));
C
chengduo 已提交
175
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
176
    PADDLE_ENFORCE_NOT_NULL(
177 178 179
        in_var, platform::errors::NotFound(
                    "The input variable %s[%d] of operator %s is not found.",
                    in, i, op_.Type()));
180
    return in_var->GetLoDLevel();
C
chengduo 已提交
181 182
  }

183 184 185
  void SetLoDLevel(const std::string &out, int32_t lod_level,
                   size_t j = 0) const override {
    PADDLE_ENFORCE_LT(j, Outputs(out).size(),
186 187 188 189
                      platform::errors::InvalidArgument(
                          "The output variable index is out of range, output "
                          "variable %s of operator %s only has %d elements.",
                          out, op_.Type(), Outputs(out).size()));
190
    PADDLE_ENFORCE_NE(Outputs(out)[j], framework::kEmptyVarName,
191 192 193
                      platform::errors::InvalidArgument(
                          "The output variable %s[%d] of operator %s is empty.",
                          out, j, op_.Type()));
194
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
195
    PADDLE_ENFORCE_NOT_NULL(
196 197 198
        out_var, platform::errors::NotFound(
                     "The output variable %s[%d] of operator %s is not found.",
                     out, j, op_.Type()));
199 200 201
    if (lod_level >= 0) {
      out_var->SetLoDLevel(lod_level);
    }
202 203
  }

204
  std::vector<InferShapeVarPtr> GetInputVarPtrs(
205
      const std::string &name) const override {
206 207 208 209 210 211 212 213 214 215 216
    const std::vector<std::string> arg_names = Inputs(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(arg_names.size());
    std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
217
      const std::string &name) const override {
218 219 220 221 222 223 224 225 226 227
    const std::vector<std::string> arg_names = Outputs(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(arg_names.size());
    std::transform(arg_names.begin(), arg_names.end(), std::back_inserter(res),
                   [this](const std::string &name) {
                     return block_.FindVarRecursive(name);
                   });
    return res;
  }

X
Xin Pan 已提交
228 229 230
  DDim GetInputDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
231 232 233 234
                      platform::errors::InvalidArgument(
                          "The input(%s) should hold only one element, but now "
                          "it holds %d elements.",
                          name, arg_names.size()));
X
Xin Pan 已提交
235 236 237 238 239 240 241 242
    return this->GetDim(arg_names[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string &name) const override {
    const std::vector<std::string> &arg_names = Inputs(name);
    return GetDims(arg_names);
  }

243 244
  bool IsRuntime() const override;

245 246
  bool IsRunMKLDNNKernel() const override;

X
Xin Pan 已提交
247 248 249 250 251 252 253 254 255 256
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Inputs(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string &name) const override {
    return GetVarTypes(Outputs(name));
  }

X
Xin Pan 已提交
257
  void SetOutputDim(const std::string &name, const DDim &dim) override {
H
hong 已提交
258
    auto arg_names = Outputs(name);
X
Xin Pan 已提交
259
    PADDLE_ENFORCE_EQ(arg_names.size(), 1UL,
260 261 262 263
                      platform::errors::InvalidArgument(
                          "The iutput(%s) should hold only one element, but "
                          "now it holds %d elements.",
                          name, arg_names.size()));
X
Xin Pan 已提交
264 265 266 267 268
    SetDim(arg_names[0], dim);
  }

  void SetOutputsDim(const std::string &name,
                     const std::vector<DDim> &dims) override {
H
hong 已提交
269
    auto names = Outputs(name);
X
Xin Pan 已提交
270 271 272
    SetDims(names, dims);
  }

273
 protected:
X
Xin Pan 已提交
274 275 276 277 278 279 280 281 282 283 284 285
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<std::string> &names) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(names.size());
    std::transform(
        names.begin(), names.end(), retv.begin(),
        std::bind(std::mem_fn(&CompileTimeInferShapeContext::GetVarType), this,
                  std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(const std::string &name) const;
Q
Qiao Longfei 已提交
286

X
Xin Pan 已提交
287 288
  DDim GetDim(const std::string &name) const {
    auto var = block_.FindVarRecursive(name);
289 290
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::NotFound("Variable %s is not found.", name));
X
Xin Pan 已提交
291 292 293
    DDim res;
    try {
      auto shape = var->GetShape();
294
      res = shape.empty() ? pten::make_ddim({0UL}) : pten::make_ddim(shape);
X
Xin Pan 已提交
295 296 297 298 299 300 301 302 303 304 305 306 307 308 309
    } catch (...) {
      VLOG(5) << "GetDim of variable " << name << " error";
      std::rethrow_exception(std::current_exception());
    }
    return res;
  }

  std::vector<DDim> GetDims(const std::vector<std::string> &names) const {
    std::vector<DDim> ret;
    ret.reserve(names.size());
    std::transform(
        names.begin(), names.end(), std::back_inserter(ret),
        [this](const std::string &name) { return this->GetDim(name); });
    return ret;
  }
310

X
Xin Pan 已提交
311 312 313 314 315
  void SetDim(const std::string &name, const DDim &dim);

  void SetDims(const std::vector<std::string> &names,
               const std::vector<DDim> &dims) {
    size_t length = names.size();
316 317 318 319 320
    PADDLE_ENFORCE_EQ(length, dims.size(),
                      platform::errors::InvalidArgument(
                          "The input variables number(%d) and input dimensions "
                          "number(%d) do not match.",
                          length, dims.size()));
X
Xin Pan 已提交
321 322 323 324 325 326 327
    for (size_t i = 0; i < length; ++i) {
      if (names[i] == framework::kEmptyVarName) {
        continue;
      }
      SetDim(names[i], dims[i]);
    }
  }
328

F
fengjiayi 已提交
329 330 331 332
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
333

Y
Yu Yang 已提交
334 335
  const OpDesc &op_;
  const BlockDesc &block_;
336 337
};

Y
Yu Yang 已提交
338 339
// Creates an operator description from its type name and its input, output
// and attribute maps. The description starts without an owning block.
OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs) {
  this->block_ = nullptr;
  this->desc_.set_type(type);
  this->attrs_ = attrs;
  this->outputs_ = outputs;
  this->inputs_ = inputs;
  // desc_ only carries the type so far; mark it as out of date.
  this->need_update_ = true;
}

X
Xin Pan 已提交
348 349 350 351 352 353
// Creates a copy of `other` (via CopyFrom) that belongs to `block`.
OpDesc::OpDesc(const OpDesc &other, BlockDesc *block) {
  this->CopyFrom(other);
  this->block_ = block;
  this->need_update_ = true;
}

354
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
355 356 357 358
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
359 360
  // The record of original_id_ is only for auto parallel.
  original_id_ = op_desc.original_id_;
F
fengjiayi 已提交
361 362 363
  need_update_ = true;
}

F
fengjiayi 已提交
364
// Creates an operator description from a serialized proto::OpDesc and
// attaches it to `block`. The in-memory maps are rebuilt from the proto, so
// no update of desc_ is pending afterwards.
OpDesc::OpDesc(const proto::OpDesc &desc, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // Rebuild inputs_: parameter name -> argument names.
  for (const proto::OpDesc::Var &in_slot : desc_.inputs()) {
    std::vector<std::string> &arg_list = inputs_[in_slot.parameter()];
    arg_list.reserve(in_slot.arguments_size());
    for (const std::string &arg : in_slot.arguments()) {
      arg_list.push_back(arg);
    }
  }

  // Rebuild outputs_ the same way.
  for (const proto::OpDesc::Var &out_slot : desc_.outputs()) {
    std::vector<std::string> &arg_list = outputs_[out_slot.parameter()];
    arg_list.reserve(out_slot.arguments_size());
    for (const std::string &arg : out_slot.arguments()) {
      arg_list.push_back(arg);
    }
  }

  // Rebuild attrs_. The sub_block referred to by a BLOCK attr hasn't been
  // added to the ProgramDesc class yet, so BLOCK/BLOCKS attrs are skipped.
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    const bool refers_to_block = attr.type() == proto::AttrType::BLOCK ||
                                 attr.type() == proto::AttrType::BLOCKS;
    if (refers_to_block) {
      continue;
    }
    const std::string key = attr.name();
    attrs_[key] = GetAttrValue(attr);
  }

  this->block_ = block;
}

Y
Yu Yang 已提交
401
// Returns the underlying proto after bringing it up to date with Flush().
proto::OpDesc *OpDesc::Proto() {
  this->Flush();
  proto::OpDesc *synced = &desc_;
  return synced;
}

Y
Yu Yang 已提交
406
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
F
fengjiayi 已提交
407
  auto it = inputs_.find(name);
408 409 410 411
  PADDLE_ENFORCE_NE(
      it, inputs_.end(),
      platform::errors::NotFound("Input %s cannot be found in operator %s.",
                                 name, Type()));
F
fengjiayi 已提交
412 413 414
  return it->second;
}

Y
Yu Yang 已提交
415
// Returns the argument names of all input parameters, concatenated in the
// iteration order of inputs_.
std::vector<std::string> OpDesc::InputArgumentNames() const {
  std::vector<std::string> all_args;
  for (auto &slot : this->inputs_) {
    const std::vector<std::string> &args = slot.second;
    all_args.insert(all_args.end(), args.begin(), args.end());
  }
  return all_args;
}

Y
Yu Yang 已提交
423 424
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
425 426 427 428
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
429
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
F
fengjiayi 已提交
430
  auto it = outputs_.find(name);
431 432 433 434
  PADDLE_ENFORCE_NE(
      it, outputs_.end(),
      platform::errors::NotFound("Output %s cannot be found in operator %s.",
                                 name, Type()));
F
fengjiayi 已提交
435 436 437
  return it->second;
}

438 439 440 441
// Tells whether the operator has an output parameter called `name`.
bool OpDesc::HasOutput(const std::string &name) const {
  const bool missing = (outputs_.find(name) == outputs_.end());
  return !missing;
}

Y
Yu Yang 已提交
442
// Returns the argument names of all output parameters, concatenated in the
// iteration order of outputs_.
std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> all_args;
  for (auto &slot : this->outputs_) {
    const std::vector<std::string> &args = slot.second;
    all_args.insert(all_args.end(), args.begin(), args.end());
  }
  return all_args;
}

Y
Yu Yang 已提交
450 451
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
452 453 454 455
  need_update_ = true;
  this->outputs_[param_name] = args;
}

456 457 458 459 460
void OpDesc::RemoveOutput(const std::string &name) {
  outputs_.erase(name);
  need_update_ = true;
}

461 462 463 464 465
void OpDesc::RemoveInput(const std::string &name) {
  inputs_.erase(name);
  need_update_ = true;
}

466 467 468 469 470 471 472 473 474 475
// Tells whether the registered proto of this operator type declares an
// attribute called `name`. Returns false when the type is not registered or
// has no proto/checker.
bool OpDesc::HasProtoAttr(const std::string &name) const {
  auto &op_info_map = OpInfoMap::Instance();
  if (!op_info_map.Has(desc_.type())) {
    return false;
  }
  // Bind by reference: copying the whole OpInfo on every query is wasteful.
  const auto &op_info = op_info_map.Get(desc_.type());
  if (!op_info.HasOpProtoAndChecker()) {
    return false;
  }
  const proto::OpProto &proto = op_info.Proto();
  for (int i = 0; i != proto.attrs_size(); ++i) {
    if (proto.attrs(i).name() == name) {
      return true;
    }
  }
  return false;
}

Y
Yu Yang 已提交
483
// Returns the proto type of attribute `name`, derived from the index of the
// alternative currently stored in the Attribute variant (which() - 1).
// Raises NotFound when the attribute is not set.
proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
  auto found = attrs_.find(name);
  PADDLE_ENFORCE_NE(found, attrs_.end(), platform::errors::NotFound(
                                             "Attribute %s is not found.", name));
  const Attribute &value = found->second;
  return static_cast<proto::AttrType>(value.which() - 1);
}

Y
Yu Yang 已提交
490
// Returns the names of all attributes currently set, in the iteration order
// of attrs_.
std::vector<std::string> OpDesc::AttrNames() const {
  std::vector<std::string> names;
  names.reserve(attrs_.size());
  for (auto &entry : attrs_) {
    const std::string &attr_name = entry.first;
    names.push_back(attr_name);
  }
  return names;
}

499 500 501 502 503
void OpDesc::RemoveAttr(const std::string &name) {
  attrs_.erase(name);
  need_update_ = true;
}

Y
Yu Yang 已提交
504
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
M
minqiyang 已提交
505 506 507 508 509
  // NOTICE(minqiyang): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
  if (attr_type == proto::AttrType::INTS &&
510
      BOOST_GET_CONST(std::vector<int>, v).size() == 0u) {
M
minqiyang 已提交
511
    // Find current attr via attr name and set the correct attribute value
M
minqiyang 已提交
512
    const proto::OpProto::Attr &attr = GetProtoAttr(name);
M
minqiyang 已提交
513 514
    switch (attr.type()) {
      case proto::AttrType::BOOLEANS: {
M
minqiyang 已提交
515 516
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BOOLEANS";
M
minqiyang 已提交
517 518 519 520
        this->attrs_[name] = std::vector<bool>();
        break;
      }
      case proto::AttrType::INTS: {
M
minqiyang 已提交
521 522
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to INTS";
M
minqiyang 已提交
523 524 525
        this->attrs_[name] = std::vector<int>();
        break;
      }
526
      case proto::AttrType::LONGS: {
M
minqiyang 已提交
527 528
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from LONGS to LONGS";
529 530 531
        this->attrs_[name] = std::vector<int64_t>();
        break;
      }
M
minqiyang 已提交
532
      case proto::AttrType::FLOATS: {
M
minqiyang 已提交
533 534
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to FLOATS";
M
minqiyang 已提交
535 536 537 538
        this->attrs_[name] = std::vector<float>();
        break;
      }
      case proto::AttrType::STRINGS: {
M
minqiyang 已提交
539 540
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to STRINGS";
M
minqiyang 已提交
541 542 543 544
        this->attrs_[name] = std::vector<std::string>();
        break;
      }
      case proto::AttrType::BLOCKS: {
M
minqiyang 已提交
545 546
        VLOG(11) << "SetAttr: " << Type() << ", " << name
                 << " from INTS to BLOCKS";
M
minqiyang 已提交
547
        this->SetBlocksAttr(name, std::vector<BlockDesc *>());
M
minqiyang 已提交
548 549
        return;
      }
M
minqiyang 已提交
550
      default:
551 552
        PADDLE_THROW(platform::errors::Unimplemented(
            "Unsupported attribute type (code %d).", attr.type()));
M
minqiyang 已提交
553
    }
M
minqiyang 已提交
554 555
    need_update_ = true;
    return;
M
minqiyang 已提交
556 557
  }

558 559 560
  // In order to set bool attr properly
  if (attr_type == proto::AttrType::INT && HasProtoAttr(name) &&
      GetProtoAttr(name).type() == proto::AttrType::BOOLEAN) {
561
    this->attrs_[name] = static_cast<bool>(BOOST_GET_CONST(int, v));
562 563 564 565
    need_update_ = true;
    return;
  }

F
fengjiayi 已提交
566 567 568 569
  this->attrs_[name] = v;
  need_update_ = true;
}

A
Abhinav Arora 已提交
570 571
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc *block) {
  this->attrs_[name] = block;
F
fengjiayi 已提交
572
  need_update_ = true;
F
fengjiayi 已提交
573 574
}

575 576 577 578 579 580
void OpDesc::SetBlocksAttr(const std::string &name,
                           std::vector<BlockDesc *> blocks) {
  this->attrs_[name] = blocks;
  need_update_ = true;
}

Y
Yu Yang 已提交
581
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
582 583 584 585 586
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

Y
Yu Yang 已提交
587
// Returns a copy of attribute `name`. Raises NotFound when it is not set.
Attribute OpDesc::GetAttr(const std::string &name) const {
  auto found = attrs_.find(name);
  PADDLE_ENFORCE_NE(found, attrs_.end(), platform::errors::NotFound(
                                             "Attribute %s is not found.", name));
  const Attribute &value = found->second;
  return value;
}

M
minqiyang 已提交
594 595 596
// Returns the attribute declaration called `name` from the proto registered
// for this operator type. Raises NotFound when the proto declares no such
// attribute.
const proto::OpProto::Attr &OpDesc::GetProtoAttr(
    const std::string &name) const {
  const proto::OpProto &proto = OpInfoMap::Instance().Get(Type()).Proto();
  const int attr_count = proto.attrs_size();
  int idx = 0;
  while (idx != attr_count) {
    const proto::OpProto::Attr &candidate = proto.attrs(idx);
    if (candidate.name() == name) {
      return candidate;
    }
    ++idx;
  }

  PADDLE_THROW(platform::errors::NotFound(
      "Attribute %s is not found in proto %s.", name, proto.type()));
}

Y
yuyang18 已提交
608
// Returns a copy of attribute `name`, or a default-constructed Attribute
// when it is not set.
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
  auto found = attrs_.find(name);
  if (found == attrs_.end()) {
    return Attribute();
  }
  return found->second;
}

G
gongweibao 已提交
617 618
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
  auto it = attrs_.find(name);
619 620 621 622
  PADDLE_ENFORCE_NE(
      it, attrs_.end(),
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
623
  auto blocks = BOOST_GET_CONST(std::vector<BlockDesc *>, it->second);
G
gongweibao 已提交
624 625 626 627 628 629 630 631 632 633

  std::vector<int> ids;
  for (auto n : blocks) {
    ids.push_back(n->ID());
  }

  return ids;
}

int OpDesc::GetBlockAttrId(const std::string &name) const {
F
fengjiayi 已提交
634
  auto it = attrs_.find(name);
635 636 637 638
  PADDLE_ENFORCE_NE(
      it, attrs_.end(),
      platform::errors::NotFound(
          "Attribute `%s` is not found in operator `%s`.", name, desc_.type()));
639
  return BOOST_GET_CONST(BlockDesc *, it->second)->ID();
F
fengjiayi 已提交
640 641
}

Y
Yu Yang 已提交
642
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
643 644 645
  return attrs_;
}

Y
Yu Yang 已提交
646
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
Y
Yancey1989 已提交
647 648
  RenameInput(old_name, new_name);
  RenameOutput(old_name, new_name);
F
fengjiayi 已提交
649 650 651
  need_update_ = true;
}

Y
Yu Yang 已提交
652 653
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
654 655 656 657
  for (auto &output : outputs_) {
    std::replace(output.second.begin(), output.second.end(), old_name,
                 new_name);
  }
Y
yuyang18 已提交
658 659 660

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
661
    auto &op_vars = BOOST_GET(std::vector<std::string>, it->second);
Y
yuyang18 已提交
662 663 664
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
665 666 667
  need_update_ = true;
}

Y
Yu Yang 已提交
668 669
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
670 671 672
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
Y
Yancey1989 已提交
673 674 675

  auto it = attrs_.find(framework::OpProtoAndCheckerMaker::OpRoleVarAttrName());
  if (it != attrs_.end()) {
676
    auto &op_vars = BOOST_GET(std::vector<std::string>, it->second);
Y
Yancey1989 已提交
677 678 679
    std::replace(op_vars.begin(), op_vars.end(), old_name, new_name);
  }

Y
Yang Yang(Tony) 已提交
680 681 682
  need_update_ = true;
}

Y
Yu Yang 已提交
683
// Visitor that writes the value held by an Attribute variant into the
// matching field of a proto::OpDesc::Attr. There is one overload per
// alternative type; block pointers are stored as block IDs.
struct SetAttrDescVisitor : public boost::static_visitor<void> {
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  // Destination proto attribute; the visitor does not own it.
  mutable proto::OpDesc::Attr *attr_;
  void operator()(int v) const { attr_->set_i(v); }
  void operator()(float v) const { attr_->set_f(v); }
  void operator()(const std::string &v) const { attr_->set_s(v); }

  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  // The template only participates in overload resolution for exactly `bool`.
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T b) const {
    attr_->set_b(b);
  }

  void operator()(const std::vector<int> &v) const {
    VectorToRepeated(v, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &v) const {
    VectorToRepeated(v, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &v) const {
    VectorToRepeated(v, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &v) const {
    VectorToRepeated(v, attr_->mutable_bools());
  }
  // A list of blocks is serialized as the list of their IDs.
  void operator()(const std::vector<BlockDesc *> &v) const {
    std::vector<int> blocks_idx;
    for (auto blk : v) {
      blocks_idx.push_back(blk->ID());
    }
    VectorToRepeated(blocks_idx, attr_->mutable_blocks_idx());
  }

  // A single block is serialized as its ID.
  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }

  void operator()(int64_t v) const { attr_->set_l(v); }

  void operator()(const std::vector<int64_t> &v) const {
    VectorToRepeated(v, attr_->mutable_longs());
  }

  void operator()(const std::vector<double> &v) const {
    VectorToRepeated(v, attr_->mutable_float64s());
  }

  // An empty (blank) attribute cannot be serialized.
  void operator()(boost::blank) const {
    PADDLE_THROW(platform::errors::Unavailable(
        "Unsupported calling method of SetAttrDescVisitor object for "
        "`boost::blank` type."));
  }
};

Y
Yu Yang 已提交
736
void OpDesc::Flush() {
F
fengjiayi 已提交
737
  if (need_update_) {
738
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
739
    for (auto &ipt : inputs_) {
740
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
741 742 743 744
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

745
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
746
    for (auto &opt : outputs_) {
747
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
748 749 750 751
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

752
    this->desc_.mutable_attrs()->Clear();
F
fengjiayi 已提交
753
    for (auto &attr : attrs_) {
754
      auto *attr_desc = desc_.add_attrs();
F
fengjiayi 已提交
755 756
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
757
          static_cast<proto::AttrType>(attr.second.which() - 1));
Y
Yu Yang 已提交
758 759
      SetAttrDescVisitor visitor(attr_desc);
      boost::apply_visitor(visitor, attr.second);
F
fengjiayi 已提交
760 761 762 763 764
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
765

Y
Yu Yang 已提交
766
void OpDesc::CheckAttrs() {
767 768 769
  PADDLE_ENFORCE_EQ(Type().empty(), false,
                    platform::errors::PreconditionNotMet(
                        "CheckAttrs() can not be called before type is set."));
Y
Yu Yang 已提交
770 771 772 773 774 775
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
776
  VLOG(10) << "begin to check attribute of " << Type();
T
tangwei12 已提交
777
  checker->Check(&attrs_);
F
fengjiayi 已提交
778 779
}

Y
Yu Yang 已提交
780
void OpDesc::InferShape(const BlockDesc &block) const {
781 782 783
  try {
    VLOG(3) << "CompileTime infer shape on " << Type();
    auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
784 785 786 787
    PADDLE_ENFORCE_EQ(
        static_cast<bool>(infer_shape), true,
        platform::errors::NotFound(
            "Operator %s's infer_shape is not registered.", this->Type()));
788 789 790 791 792 793 794 795 796 797 798 799 800 801 802
    CompileTimeInferShapeContext ctx(*this, block);
    if (VLOG_IS_ON(10)) {
      std::ostringstream sout;
      auto inames = this->InputArgumentNames();
      sout << " From [";
      std::copy(inames.begin(), inames.end(),
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "] to [";
      auto onames = this->OutputArgumentNames();
      std::copy(onames.begin(), onames.end(),
                std::ostream_iterator<std::string>(sout, ", "));
      sout << "]";
      VLOG(10) << sout.str();
    }
    infer_shape(&ctx);
803
  } catch (platform::EnforceNotMet &exception) {
804
    framework::AppendErrorOpHint(Type(), &exception);
805 806 807 808
    throw std::move(exception);
  } catch (...) {
    std::rethrow_exception(std::current_exception());
  }
Y
Yu Yang 已提交
809 810
}

Y
Yu Yang 已提交
811
// Runs the operator's registered variable-type inference, if any, against
// the given block.
void OpDesc::InferVarType(BlockDesc *block) const {
  // There are a few places where a var type can be set:
  //  - when a VarDesc is created, it defaults to LOD_TENSOR;
  //  - when an output variable is created, it defaults to LOD_TENSOR.
  // This is intended to be the only place where an operator defines its
  // customized var type inference, hence no "default" is applied here.
  const auto &op_info = OpInfoMap::Instance().Get(this->Type());
  if (!op_info.infer_var_type_) {
    return;
  }
  InferVarTypeContext context(this, block);
  op_info.infer_var_type_(&context);
}

824
// Builds a compile-time shape-inference context bound to the given operator
// description and the block it is inferred in.
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op),
      block_(block) {
}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
829 830 831
  if (op_.Inputs().find(name) == op_.Inputs().end()) {
    return false;
  }
832 833 834 835 836
  const std::vector<std::string> &input_names = op_.Input(name);
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
837 838 839 840
  PADDLE_ENFORCE_EQ(length, 1UL, platform::errors::InvalidArgument(
                                     "Input(%s) should have only one value, "
                                     "but it has %d values now.",
                                     name, length));
841 842 843 844
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
845 846 847
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
848 849 850 851 852
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
853 854 855 856
  PADDLE_ENFORCE_EQ(length, 1UL, platform::errors::InvalidArgument(
                                     "Output(%s) should have only one value, "
                                     "but it has %d values now.",
                                     name, length));
857 858 859
  return block_.HasVarRecursive(output_names[0]);
}

860 861 862 863
// Reports whether the underlying operator description has an attribute
// called `name`.
bool CompileTimeInferShapeContext::HasAttr(const std::string &name) const {
  const bool attr_present = op_.HasAttr(name);
  return attr_present;
}

864
bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
865 866 867
  if (op_.Inputs().find(name) == op_.Inputs().end()) {
    return false;
  }
868 869 870 871 872 873 874 875 876 877 878
  const std::vector<std::string> &input_names = op_.Input(name);
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
879 880 881
  if (op_.Outputs().find(name) == op_.Outputs().end()) {
    return false;
  }
882 883 884 885 886 887 888 889 890 891 892 893 894 895
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
  for (auto &output : output_names) {
    if (!block_.HasVarRecursive(output)) return false;
  }
  return true;
}

// Returns an AttrReader constructed from the operator's attribute map.
AttrReader CompileTimeInferShapeContext::Attrs() const {
  const auto &attr_map = op_.GetAttrMap();
  return AttrReader(attr_map);
}

H
hong 已提交
896
std::vector<std::string> CompileTimeInferShapeContext::Inputs(
897 898 899 900
    const std::string &name) const {
  return op_.Input(name);
}

H
hong 已提交
901
std::vector<std::string> CompileTimeInferShapeContext::Outputs(
902 903 904 905
    const std::string &name) const {
  return op_.Output(name);
}

F
fengjiayi 已提交
906
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
F
fengjiayi 已提交
907 908
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
909 910
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
911 912 913 914
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
915
      res.push_back(s.empty() ? pten::make_ddim({0UL}) : pten::make_ddim(s));
F
fengjiayi 已提交
916 917
    }
  } catch (...) {
M
minqiyang 已提交
918
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
F
fengjiayi 已提交
919 920 921
    std::rethrow_exception(std::current_exception());
  }
  return res;
922 923 924 925
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
F
fengjiayi 已提交
926
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
927
}
F
fengjiayi 已提交
928 929 930 931

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
932 933
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
F
fengjiayi 已提交
934
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
935
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), pten::vectorize<>);
F
fengjiayi 已提交
936
  var->SetShapes(dim_vec);
937
}
F
fengjiayi 已提交
938

939 940
// This context is used for compile-time inference, so it always reports
// that it is not a runtime context.
bool CompileTimeInferShapeContext::IsRuntime() const {
  return false;
}

941 942
// Always returns false for the compile-time context.
bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const {
  return false;
}

943
// Returns the type of variable `name`, looked up recursively from the block.
//
// Raises NotFound when the variable does not exist, instead of
// dereferencing a null pointer.
proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE_NOT_NULL(
      var, platform::errors::NotFound("Variable %s is not found.", name));
  return var->GetType();
}
947

F
fengjiayi 已提交
948 949
}  // namespace framework
}  // namespace paddle