op_desc.cc 16.2 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/op_desc.h"
Y
Yu Yang 已提交
16
#include <functional>
17
#include <mutex>
Y
Yu Yang 已提交
18
#include <unordered_map>
19
#include "glog/logging.h"
F
fengjiayi 已提交
20
#include "paddle/framework/block_desc.h"
Y
Yu Yang 已提交
21
#include "paddle/framework/operator.h"
22
#include "paddle/framework/program_desc.h"
23
#include "paddle/framework/shape_inference.h"
Y
Yu Yang 已提交
24

F
fengjiayi 已提交
25 26 27
namespace paddle {
namespace framework {

Y
Yu Yang 已提交
28 29
class OpDesc;
class BlockDesc;
30 31
class CompileTimeInferShapeContext : public InferShapeContext {
 public:
Y
Yu Yang 已提交
32
  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49

  bool HasInput(const std::string &name) const override;

  bool HasOutput(const std::string &name) const override;

  bool HasInputs(const std::string &name) const override;

  bool HasOutputs(const std::string &name) const override;

  AttrReader Attrs() const override;

  const std::vector<std::string> &Inputs(
      const std::string &name) const override;

  const std::vector<std::string> &Outputs(
      const std::string &name) const override;

Q
Qiao Longfei 已提交
50 51 52 53 54 55
  void ShareLoD(const std::string &in, const std::string &out, size_t i = 0,
                size_t j = 0) const override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    auto *in_var = block_.FindVarRecursive(Inputs(in)[i]);
    auto *out_var = block_.FindVarRecursive(Outputs(out)[j]);
56
    if (in_var->GetType() != proto::VarDesc::LOD_TENSOR) {
Q
Qiao Longfei 已提交
57
      VLOG(3) << "input " << in << " is not LodTensor";
Q
Qiao Longfei 已提交
58 59
      return;
    }
60
    PADDLE_ENFORCE_EQ(in_var->GetType(), proto::VarDesc::LOD_TENSOR,
Q
Qiao Longfei 已提交
61 62
                      "The %d-th output of Output(%s) must be LoDTensor.", j,
                      out);
63
    out_var->SetLoDLevel(in_var->GetLoDLevel());
Q
Qiao Longfei 已提交
64
  }
D
dzhwinter 已提交
65

66 67 68
  bool IsRuntime() const override;

 protected:
69
  proto::VarDesc::VarType GetVarType(const std::string &name) const override;
Q
Qiao Longfei 已提交
70

71 72 73 74
  DDim GetDim(const std::string &name) const override;

  void SetDim(const std::string &name, const DDim &dim) override;

F
fengjiayi 已提交
75 76 77 78
  std::vector<DDim> GetRepeatedDims(const std::string &name) const override;

  void SetRepeatedDims(const std::string &name,
                       const std::vector<DDim> &dims) override;
F
fengjiayi 已提交
79

Y
Yu Yang 已提交
80 81
  const OpDesc &op_;
  const BlockDesc &block_;
82 83
};

Y
Yu Yang 已提交
84 85
// Builds an OpDesc from in-memory maps. The protobuf form (desc_) is filled
// in lazily by Flush(), so need_update_ starts out true.
OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs) {
  need_update_ = true;
  desc_.set_type(type);
  this->inputs_ = inputs;
  this->outputs_ = outputs;
  this->attrs_ = attrs;
}

93
void OpDesc::CopyFrom(const OpDesc &op_desc) {
F
fengjiayi 已提交
94 95 96 97 98 99 100
  desc_.set_type(op_desc.Type());
  inputs_ = op_desc.inputs_;
  outputs_ = op_desc.outputs_;
  attrs_ = op_desc.attrs_;
  need_update_ = true;
}

101
// Reconstructs an OpDesc from its protobuf form. Inputs/outputs are copied
// into the in-memory maps, and BLOCK attributes are resolved to BlockDesc
// pointers through the owning ProgramDesc.
OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog, BlockDesc *block)
    : desc_(desc), need_update_(false) {
  // Restore inputs_.
  for (int i = 0; i < desc_.inputs_size(); ++i) {
    const proto::OpDesc::Var &in_var = desc_.inputs(i);
    std::vector<std::string> &dst = inputs_[in_var.parameter()];
    dst.reserve(in_var.arguments_size());
    for (int j = 0; j < in_var.arguments_size(); ++j) {
      dst.push_back(in_var.arguments(j));
    }
  }
  // Restore outputs_.
  for (int i = 0; i < desc_.outputs_size(); ++i) {
    const proto::OpDesc::Var &out_var = desc_.outputs(i);
    std::vector<std::string> &dst = outputs_[out_var.parameter()];
    dst.reserve(out_var.arguments_size());
    for (int j = 0; j < out_var.arguments_size(); ++j) {
      dst.push_back(out_var.arguments(j));
    }
  }
  // Restore attrs_; a BLOCK attribute stores only a block index in the proto,
  // so it is turned back into a BlockDesc pointer here.
  for (const proto::OpDesc::Attr &attr : desc_.attrs()) {
    const std::string &attr_name = attr.name();
    if (attr.type() == proto::AttrType::BLOCK) {
      attrs_[attr_name] = prog->MutableBlock(attr.block_idx());
    } else {
      attrs_[attr_name] = GetAttrValue(attr);
    }
  }
  this->block_ = block;
}

Y
Yu Yang 已提交
138
// Returns the protobuf form, flushing first so it reflects the latest
// in-memory state.
proto::OpDesc *OpDesc::Proto() {
  Flush();
  return &desc_;
}

Y
Yu Yang 已提交
143
// Returns the argument names bound to the input slot `name`; throws if the
// slot does not exist on this op.
const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
  auto iter = inputs_.find(name);
  PADDLE_ENFORCE(iter != inputs_.end(), "Input %s cannot be found in Op %s",
                 name, Type());
  return iter->second;
}

Y
Yu Yang 已提交
150
// Flattens every input slot's argument list into one vector.
std::vector<std::string> OpDesc::InputArgumentNames() const {
  std::vector<std::string> names;
  for (const auto &slot : this->inputs_) {
    names.insert(names.end(), slot.second.begin(), slot.second.end());
  }
  return names;
}

Y
Yu Yang 已提交
158 159
void OpDesc::SetInput(const std::string &param_name,
                      const std::vector<std::string> &args) {
F
fengjiayi 已提交
160 161 162 163
  need_update_ = true;
  inputs_[param_name] = args;
}

Y
Yu Yang 已提交
164
// Returns the argument names bound to the output slot `name`; throws if the
// slot does not exist on this op.
const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
  auto iter = outputs_.find(name);
  PADDLE_ENFORCE(iter != outputs_.end(), "Output %s cannot be found in Op %s",
                 name, Type());
  return iter->second;
}

Y
Yu Yang 已提交
171
// Flattens every output slot's argument list into one vector.
std::vector<std::string> OpDesc::OutputArgumentNames() const {
  std::vector<std::string> names;
  for (const auto &slot : this->outputs_) {
    names.insert(names.end(), slot.second.begin(), slot.second.end());
  }
  return names;
}

Y
Yu Yang 已提交
179 180
void OpDesc::SetOutput(const std::string &param_name,
                       const std::vector<std::string> &args) {
F
fengjiayi 已提交
181 182 183 184
  need_update_ = true;
  this->outputs_[param_name] = args;
}

Y
Yu Yang 已提交
185
// Maps the stored attribute's variant index to its proto::AttrType value.
proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
  auto iter = attrs_.find(name);
  PADDLE_ENFORCE(iter != attrs_.end(), "Attribute %s is not found", name);
  // which() - 1: the variant's alternative 0 is presumably boost::blank (the
  // visitor below has a blank overload), so proto enum values start one
  // position later — TODO confirm the Attribute variant's alternative order.
  return static_cast<proto::AttrType>(iter->second.which() - 1);
}

Y
Yu Yang 已提交
191
// Lists the names of all attributes currently set on this op.
std::vector<std::string> OpDesc::AttrNames() const {
  std::vector<std::string> names;
  names.reserve(attrs_.size());
  for (const auto &entry : attrs_) {
    names.push_back(entry.first);
  }
  return names;
}

Y
Yu Yang 已提交
200
void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
F
fengjiayi 已提交
201 202 203 204
  this->attrs_[name] = v;
  need_update_ = true;
}

Y
Yu Yang 已提交
205
void OpDesc::SetBlockAttr(const std::string &name, BlockDesc &block) {
206
  this->attrs_[name] = &block;
F
fengjiayi 已提交
207
  need_update_ = true;
F
fengjiayi 已提交
208 209
}

Y
Yu Yang 已提交
210
void OpDesc::SetAttrMap(
F
fengjiayi 已提交
211 212 213 214 215
    const std::unordered_map<std::string, Attribute> &attr_map) {
  attrs_ = attr_map;
  need_update_ = true;
}

Y
Yu Yang 已提交
216
// Returns a copy of the named attribute; throws if it is absent.
Attribute OpDesc::GetAttr(const std::string &name) const {
  auto iter = attrs_.find(name);
  PADDLE_ENFORCE(iter != attrs_.end(), "Attribute %s is not found", name);
  return iter->second;
}

Y
Yu Yang 已提交
222
int OpDesc::GetBlockAttr(const std::string &name) const {
F
fengjiayi 已提交
223 224
  auto it = attrs_.find(name);
  PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
Y
Yu Yang 已提交
225
  return boost::get<BlockDesc *>(it->second)->ID();
F
fengjiayi 已提交
226 227
}

Y
Yu Yang 已提交
228
const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
F
fengjiayi 已提交
229 230 231
  return attrs_;
}

Y
Yu Yang 已提交
232
void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
F
Update  
fengjiayi 已提交
233
  for (auto &input : inputs_) {
F
fengjiayi 已提交
234 235
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
F
Update  
fengjiayi 已提交
236 237
  for (auto &output : outputs_) {
    std::replace(output.second.begin(), output.second.end(), old_name,
F
fengjiayi 已提交
238 239 240 241 242
                 new_name);
  }
  need_update_ = true;
}

Y
Yu Yang 已提交
243 244
void OpDesc::RenameOutput(const std::string &old_name,
                          const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
245 246 247 248 249 250 251
  for (auto &output : outputs_) {
    std::replace(output.second.begin(), output.second.end(), old_name,
                 new_name);
  }
  need_update_ = true;
}

Y
Yu Yang 已提交
252 253
void OpDesc::RenameInput(const std::string &old_name,
                         const std::string &new_name) {
Y
Yang Yang(Tony) 已提交
254 255 256 257 258 259
  for (auto &input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
  need_update_ = true;
}

Y
Yu Yang 已提交
260
// Writes a single attribute value into a proto::OpDesc::Attr, dispatched on
// the concrete type held by the Attribute variant.
struct SetAttrDescVisitor : public boost::static_visitor<void> {
  explicit SetAttrDescVisitor(proto::OpDesc::Attr *attr) : attr_(attr) {}
  mutable proto::OpDesc::Attr *attr_;

  void operator()(int value) const { attr_->set_i(value); }
  void operator()(int64_t value) const { attr_->set_l(value); }
  void operator()(float value) const { attr_->set_f(value); }
  void operator()(const std::string &value) const { attr_->set_s(value); }

  // Dedicated bool overload, constrained with enable_if so it can never be
  // selected for other integral types via implicit conversion.
  // Please refer to https://github.com/PaddlePaddle/Paddle/issues/7162
  template <class T,
            class = typename std::enable_if<std::is_same<bool, T>::value>::type>
  void operator()(T value) const {
    attr_->set_b(value);
  }

  void operator()(const std::vector<int> &value) const {
    VectorToRepeated(value, attr_->mutable_ints());
  }
  void operator()(const std::vector<float> &value) const {
    VectorToRepeated(value, attr_->mutable_floats());
  }
  void operator()(const std::vector<std::string> &value) const {
    VectorToRepeated(value, attr_->mutable_strings());
  }
  void operator()(const std::vector<bool> &value) const {
    VectorToRepeated(value, attr_->mutable_bools());
  }
  // Block attributes are serialized as their block index only.
  void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->ID()); }
  void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};

Y
Yu Yang 已提交
291
void OpDesc::Flush() {
F
fengjiayi 已提交
292
  if (need_update_) {
293
    this->desc_.mutable_inputs()->Clear();
F
fengjiayi 已提交
294
    for (auto &ipt : inputs_) {
295
      auto *input = desc_.add_inputs();
F
fengjiayi 已提交
296 297 298 299
      input->set_parameter(ipt.first);
      VectorToRepeated(ipt.second, input->mutable_arguments());
    }

300
    this->desc_.mutable_outputs()->Clear();
F
fengjiayi 已提交
301
    for (auto &opt : outputs_) {
302
      auto *output = desc_.add_outputs();
F
fengjiayi 已提交
303 304 305 306
      output->set_parameter(opt.first);
      VectorToRepeated(opt.second, output->mutable_arguments());
    }

307
    this->desc_.mutable_attrs()->Clear();
F
fengjiayi 已提交
308
    for (auto &attr : attrs_) {
309
      auto *attr_desc = desc_.add_attrs();
F
fengjiayi 已提交
310 311
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
312
          static_cast<proto::AttrType>(attr.second.which() - 1));
Y
Yu Yang 已提交
313 314
      SetAttrDescVisitor visitor(attr_desc);
      boost::apply_visitor(visitor, attr.second);
F
fengjiayi 已提交
315 316 317 318 319
    }

    need_update_ = false;
  }
}
Y
Yu Yang 已提交
320

321 322 323 324 325 326 327 328 329 330
static std::once_flag init_infer_shape_funcs;

static void InitInferShapeFuncs() {
  std::call_once(init_infer_shape_funcs, [] {
    auto &map = OpInfoMap::Instance();
    auto &info_map = *map.mutable_map();

    for (auto &kern_pair : OperatorWithKernel::AllOpKernels()) {
      auto op_type = kern_pair.first;
      auto &op_info = info_map.at(op_type);
Y
Yiqun Liu 已提交
331 332
      auto op = static_cast<OperatorWithKernel *>(op_info.Creator()(
          "", VariableNameMap{}, VariableNameMap{}, AttributeMap{}));
333 334 335 336 337 338
      if (op_info.infer_shape_) {  // infer_shape has been registered.
        continue;
      }
      op_info.infer_shape_ = [op](InferShapeContext *ctx) {
        op->InferShape(ctx);
      };
Y
Yu Yang 已提交
339
    }
340
  });
Y
Yu Yang 已提交
341 342
}

Y
Yu Yang 已提交
343
void OpDesc::CheckAttrs() {
F
fengjiayi 已提交
344 345
  PADDLE_ENFORCE(!Type().empty(),
                 "CheckAttr() can not be called before type is setted.");
Y
Yu Yang 已提交
346 347 348 349 350 351
  auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
  if (checker == nullptr) {
    // checker is not configured. That operator could be generated by Paddle,
    // not by users.
    return;
  }
F
fengjiayi 已提交
352 353 354
  checker->Check(attrs_);
}

Y
Yu Yang 已提交
355
void OpDesc::InferShape(const BlockDesc &block) const {
Y
Yu Yang 已提交
356
  VLOG(3) << "CompileTime infer shape on " << Type();
357 358 359 360
  InitInferShapeFuncs();
  auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
  PADDLE_ENFORCE(static_cast<bool>(infer_shape),
                 "%s's infer_shape has not been registered", this->Type());
Y
Yu Yang 已提交
361
  CompileTimeInferShapeContext ctx(*this, block);
Y
Yu Yang 已提交
362 363 364 365 366 367 368 369 370 371 372 373 374
  if (VLOG_IS_ON(10)) {
    std::ostringstream sout;
    auto inames = this->InputArgumentNames();
    sout << " From [";
    std::copy(inames.begin(), inames.end(),
              std::ostream_iterator<std::string>(sout, ", "));
    sout << "] to [";
    auto onames = this->OutputArgumentNames();
    std::copy(onames.begin(), onames.end(),
              std::ostream_iterator<std::string>(sout, ", "));
    sout << "]";
    VLOG(10) << sout.str();
  }
375
  infer_shape(&ctx);
Y
Yu Yang 已提交
376 377
}

Y
Yu Yang 已提交
378
// Infers output variable types. Falls back to LOD_TENSOR for every output
// when the op has not registered an InferVarType function.
void OpDesc::InferVarType(BlockDesc *block) const {
  auto &info = OpInfoMap::Instance().Get(this->Type());
  if (info.infer_var_type_) {
    info.infer_var_type_(*this, block);
    return;
  }
  VLOG(10) << this->Type()
           << " has not registered InferVarType. Set output variables to "
              "LOD_TENSOR";
  for (auto &out_pair : this->outputs_) {
    for (auto &out_var_name : out_pair.second) {
      block->FindRecursiveOrCreateVar(out_var_name)
          .SetType(proto::VarDesc::LOD_TENSOR);
    }
  }
}

396
// Binds the context to an op and its enclosing block; both references must
// outlive this context.
CompileTimeInferShapeContext::CompileTimeInferShapeContext(
    const OpDesc &op, const BlockDesc &block)
    : op_(op), block_(block) {}

bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
  const std::vector<std::string> &input_names = op_.Input(name);
  auto length = input_names.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Input(%s) should have only one value, "
                    "but it have %d now",
                    name, length);
  return block_.HasVarRecursive(input_names[0]);
}

bool CompileTimeInferShapeContext::HasOutput(const std::string &name) const {
  const std::vector<std::string> &output_names = op_.Output(name);
  auto length = output_names.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Output(%s) should have only one value, "
                    "but it have %d now",
                    name, length);
  return block_.HasVarRecursive(output_names[0]);
}

bool CompileTimeInferShapeContext::HasInputs(const std::string &name) const {
  const std::vector<std::string> &input_names = op_.Input(name);
  if (input_names.empty()) {
    return false;
  }
  for (auto &input : input_names) {
    if (!block_.HasVarRecursive(input)) return false;
  }
  return true;
}

bool CompileTimeInferShapeContext::HasOutputs(const std::string &name) const {
  const std::vector<std::string> &output_names = op_.Output(name);
  if (output_names.empty()) {
    return false;
  }
  for (auto &output : output_names) {
    if (!block_.HasVarRecursive(output)) return false;
  }
  return true;
}

// Exposes the op's attribute map through the read-only AttrReader adapter.
AttrReader CompileTimeInferShapeContext::Attrs() const {
  return AttrReader(op_.GetAttrMap());
}

// Forwards to OpDesc::Input — argument names of input slot `name`.
const std::vector<std::string> &CompileTimeInferShapeContext::Inputs(
    const std::string &name) const {
  return op_.Input(name);
}

// Forwards to OpDesc::Output — argument names of output slot `name`.
const std::vector<std::string> &CompileTimeInferShapeContext::Outputs(
    const std::string &name) const {
  return op_.Output(name);
}

// Reads the compile-time shape of variable `name` as a DDim.
DDim CompileTimeInferShapeContext::GetDim(const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
  DDim res;
  try {
    const auto &shape = var->GetShape();
    if (shape.empty()) {
      // Normalize an empty shape to {0} so callers always get rank >= 1.
      res = make_ddim({0UL});
    } else {
      res = make_ddim(shape);
    }
  } catch (...) {
    // Log at high verbosity, then let the original exception propagate.
    VLOG(5) << "GetDim of variable " << name << " error";
    std::rethrow_exception(std::current_exception());
  }
  return res;
}

F
fengjiayi 已提交
476
std::vector<DDim> CompileTimeInferShapeContext::GetRepeatedDims(
F
fengjiayi 已提交
477 478 479 480 481 482 483 484 485 486 487 488 489 490
    const std::string &name) const {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
  std::vector<DDim> res;
  try {
    auto shapes = var->GetShapes();
    for (const auto &s : shapes) {
      res.push_back(s.empty() ? make_ddim({0UL}) : make_ddim(s));
    }
  } catch (...) {
    VLOG(5) << "GetRepeatedDim of variable " << name << " error.";
    std::rethrow_exception(std::current_exception());
  }
  return res;
491 492 493 494
}

void CompileTimeInferShapeContext::SetDim(const std::string &name,
                                          const DDim &dim) {
F
fengjiayi 已提交
495
  block_.FindVarRecursive(name)->SetShape(vectorize(dim));
496
}
F
fengjiayi 已提交
497 498 499 500 501 502 503 504 505 506

void CompileTimeInferShapeContext::SetRepeatedDims(
    const std::string &name, const std::vector<DDim> &dims) {
  auto var = block_.FindVarRecursive(name);
  PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize);
  var->SetShapes(dim_vec);
}

507 508
// Compile-time context: shapes come from VarDesc entries, not live tensors.
bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

509
// Looks up the declared type of variable `name` in the block hierarchy.
proto::VarDesc::VarType CompileTimeInferShapeContext::GetVarType(
    const std::string &name) const {
  auto *var = block_.FindVarRecursive(name);
  return var->GetType();
}
513

F
fengjiayi 已提交
514 515
}  // namespace framework
}  // namespace paddle