var_desc.cc 15.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
F
fengjiayi 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/framework/var_desc.h"
16 17

#include "glog/logging.h"
18
#include "paddle/fluid/framework/attribute.h"
19
#include "paddle/fluid/framework/data_type.h"
Y
Yi Wang 已提交
20
#include "paddle/fluid/platform/enforce.h"
F
fengjiayi 已提交
21 22 23 24

namespace paddle {
namespace framework {

25 26 27 28 29 30 31
// Copy constructor: duplicates the proto desc, the attribute map and the
// original id, and deep-copies the distributed attribute when one exists.
VarDesc::VarDesc(const VarDesc &other)
    : desc_(other.desc_),
      attrs_(other.attrs_),
      original_id_(other.original_id_) {
  if (other.dist_attr_ != nullptr) {
    dist_attr_.reset(new TensorDistAttr(*other.dist_attr_));
  }
  // A freshly copied desc has not been flushed yet.
  need_updated_ = true;
}

// Builds a VarDesc from its protobuf form and rebuilds the in-memory
// attribute map (used by auto parallel) from the serialized attrs.
VarDesc::VarDesc(const proto::VarDesc &desc) : desc_(desc) {
  for (const proto::VarDesc::Attr &attr : desc_.attrs()) {
    attrs_[attr.name()] = GetAttrValue(attr);
  }
  need_updated_ = true;
}

44
// Returns the variable's type as recorded in the proto desc.
proto::VarType::Type VarDesc::GetType() const {
  return desc_.type().type();
}
Q
QI JUN 已提交
45

46 47
// Records the variable's type in the proto desc and marks it dirty.
void VarDesc::SetType(proto::VarType::Type type) {
  auto *type_proto = desc_.mutable_type();
  type_proto->set_type(type);
  need_updated_ = true;
}
Q
QI JUN 已提交
50

Y
Yu Yang 已提交
51
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
Y
Yu Yang 已提交
52
  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
L
Leo Chen 已提交
53
  need_updated_ = true;
F
fengjiayi 已提交
54 55
}

F
fengjiayi 已提交
56
void VarDesc::SetTensorDescNum(size_t num) {
57 58 59 60
  switch (desc_.type().type()) {
    case proto::VarType::READER: {
      auto *lod_tensors_ptr =
          desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
F
fengjiayi 已提交
61 62 63 64 65 66 67 68
      lod_tensors_ptr->Clear();
      for (size_t i = 0; i < num; ++i) {
        lod_tensors_ptr->Add();
      }
      return;
    } break;
    default:
      PADDLE_THROW(
69 70 71
          platform::errors::Unavailable("Setting 'sub_tensor_number' is not "
                                        "supported by the %s type variable.",
                                        this->Name()));
F
fengjiayi 已提交
72
  }
L
Leo Chen 已提交
73
  need_updated_ = true;
F
fengjiayi 已提交
74 75 76
}

size_t VarDesc::GetTensorDescNum() const {
77 78 79
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      return desc_.type().reader().lod_tensor_size();
F
fengjiayi 已提交
80 81 82
      break;
    default:
      PADDLE_THROW(
83 84 85
          platform::errors::Unavailable("Getting 'sub_tensor_number' is not "
                                        "supported by the %s type variable.",
                                        this->Name()));
F
fengjiayi 已提交
86 87 88 89
  }
}

void VarDesc::SetShapes(
F
fengjiayi 已提交
90
    const std::vector<std::vector<int64_t>> &multiple_dims) {
F
fengjiayi 已提交
91
  if (multiple_dims.size() != GetTensorDescNum()) {
M
minqiyang 已提交
92 93 94 95
    VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
F
fengjiayi 已提交
96 97
    SetTensorDescNum(multiple_dims.size());
  }
98
  std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
F
fengjiayi 已提交
99 100 101
  for (size_t i = 0; i < multiple_dims.size(); ++i) {
    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
  }
L
Leo Chen 已提交
102
  need_updated_ = true;
F
fengjiayi 已提交
103 104 105 106 107 108 109
}

// Returns the dims of this variable's tensor desc as a plain vector.
std::vector<int64_t> VarDesc::GetShape() const {
  const auto &dims = tensor_desc().dims();
  return RepeatedToVector(dims);
}

std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
110
  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
F
fengjiayi 已提交
111 112 113 114 115 116 117 118
  std::vector<std::vector<int64_t>> res;
  res.reserve(descs.size());
  for (const auto &tensor_desc : descs) {
    res.push_back(RepeatedToVector(tensor_desc.dims()));
  }
  return res;
}

119
// Records the element data type in the tensor desc and marks the desc dirty.
void VarDesc::SetDataType(proto::VarType::Type data_type) {
  auto *tensor = mutable_tensor_desc();
  tensor->set_data_type(data_type);
  need_updated_ = true;
}

F
fengjiayi 已提交
124
void VarDesc::SetDataTypes(
125
    const std::vector<proto::VarType::Type> &multiple_data_type) {
F
fengjiayi 已提交
126
  if (multiple_data_type.size() != GetTensorDescNum()) {
M
minqiyang 已提交
127 128 129 130 131
    VLOG(3) << "WARNING: The number of given data types("
            << multiple_data_type.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
F
fengjiayi 已提交
132 133
    SetTensorDescNum(multiple_data_type.size());
  }
134 135
  std::vector<proto::VarType::TensorDesc *> tensor_descs =
      mutable_tensor_descs();
F
fengjiayi 已提交
136 137 138
  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
    tensor_descs[i]->set_data_type(multiple_data_type[i]);
  }
L
Leo Chen 已提交
139
  need_updated_ = true;
F
fengjiayi 已提交
140 141
}

142
// Returns the element data type stored in the tensor desc.
proto::VarType::Type VarDesc::GetDataType() const {
  const auto &desc = tensor_desc();
  return desc.data_type();
}
Y
Stash  
Yu Yang 已提交
145

146 147 148 149
size_t VarDesc::ElementSize() const {
  return framework::SizeOfType(GetDataType());
}

150
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
151
  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
152
  std::vector<proto::VarType::Type> res;
F
fengjiayi 已提交
153 154 155 156 157 158 159
  res.reserve(descs.size());
  for (const auto &tensor_desc : descs) {
    res.push_back(tensor_desc.data_type());
  }
  return res;
}

Y
Yu Yang 已提交
160
void VarDesc::SetLoDLevel(int32_t lod_level) {
161 162 163
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
      desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
Y
Yu Yang 已提交
164
      break;
165 166
    case proto::VarType::LOD_TENSOR_ARRAY:
      desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
Y
Yu Yang 已提交
167 168
      break;
    default:
169 170 171
      PADDLE_THROW(platform::errors::Unavailable(
          "Setting 'lod_level' is not supported by the %s type variable.",
          this->Name()));
F
fengjiayi 已提交
172
  }
L
Leo Chen 已提交
173
  need_updated_ = true;
F
fengjiayi 已提交
174 175 176
}

void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
F
fengjiayi 已提交
177
  if (multiple_lod_level.size() != GetTensorDescNum()) {
M
minqiyang 已提交
178 179 180 181 182
    VLOG(3) << "WARNING: The number of given lod_levels("
            << multiple_lod_level.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
F
fengjiayi 已提交
183 184
    SetTensorDescNum(multiple_lod_level.size());
  }
185 186
  switch (desc_.type().type()) {
    case proto::VarType::READER: {
F
fengjiayi 已提交
187
      size_t i = 0;
188 189
      for (auto &lod_tensor :
           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
F
fengjiayi 已提交
190 191 192 193
        lod_tensor.set_lod_level(multiple_lod_level[i++]);
      }
    } break;
    default:
194 195 196
      PADDLE_THROW(platform::errors::Unavailable(
          "Setting 'lod_levels' is not supported by the %s type variable",
          this->Name()));
Y
Yu Yang 已提交
197
  }
L
Leo Chen 已提交
198
  need_updated_ = true;
Y
Stash  
Yu Yang 已提交
199 200
}

201
int32_t VarDesc::GetLoDLevel() const {
202 203 204 205 206
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
      return desc_.type().lod_tensor().lod_level();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return desc_.type().tensor_array().lod_level();
Y
Yu Yang 已提交
207
    default:
208 209 210
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'lod_level' is not supported by the %s type variable.",
          this->Name()));
F
fengjiayi 已提交
211 212 213 214 215
  }
}

std::vector<int32_t> VarDesc::GetLoDLevels() const {
  std::vector<int32_t> res;
216 217 218 219
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      res.reserve(desc_.type().reader().lod_tensor_size());
      for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
F
fengjiayi 已提交
220 221 222 223 224
        res.push_back(lod_tensor.lod_level());
      }
      return res;
      break;
    default:
225 226 227
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'lod_levels' is not supported by the %s type variable.",
          this->Name()));
Y
Yu Yang 已提交
228
  }
Y
Stash  
Yu Yang 已提交
229
}
Y
Yu Yang 已提交
230

231
// Returns the (single) tensor desc embedded in this variable's type.
// Only tensor-like variable types carry one; others raise Unavailable.
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
  PADDLE_ENFORCE_EQ(
      desc_.has_type(),
      true,
      platform::errors::NotFound("The variable's type was not set."));
  PADDLE_ENFORCE_EQ(
      desc_.type().has_type(),
      true,
      platform::errors::NotFound("The variable's type was not set."));
  const proto::VarType &var_type = desc_.type();
  switch (var_type.type()) {
    case proto::VarType::SELECTED_ROWS:
      return var_type.selected_rows();
    case proto::VarType::LOD_TENSOR:
      return var_type.lod_tensor().tensor();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return var_type.tensor_array().tensor();
    case proto::VarType::STRINGS:
      return var_type.strings();
    case proto::VarType::VOCAB:
      return var_type.vocab();
    case proto::VarType::SPARSE_COO:
      return var_type.sparse_coo();
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'tensor_desc' is not supported by the %s type variable.",
          this->Name()));
  }
}

260
std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
261
  PADDLE_ENFORCE_EQ(
262 263
      desc_.has_type(),
      true,
264
      platform::errors::NotFound("The variable's type was not be set."));
265
  std::vector<proto::VarType::TensorDesc> res;
F
fengjiayi 已提交
266
  res.reserve(GetTensorDescNum());
267 268 269
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
F
fengjiayi 已提交
270 271 272 273
        res.push_back(lod_tensor.tensor());
      }
      return res;
    default:
274 275 276
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'tensor_descs' is not supported by the %s type variable.",
          this->Name()));
Y
Yu Yang 已提交
277 278 279
  }
}

280
// Returns a mutable pointer to the tensor desc embedded in this variable's
// type. Handing out write access implies the desc may change, so the dirty
// flag is set on success.
// BUGFIX: the old trailing `need_updated_ = true;` was unreachable because
// every switch path returned or threw; restructured so it actually runs.
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
  PADDLE_ENFORCE_EQ(
      desc_.has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  PADDLE_ENFORCE_EQ(
      desc_.type().has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  proto::VarType::TensorDesc *tensor_desc = nullptr;
  switch (desc_.type().type()) {
    case proto::VarType::SELECTED_ROWS:
      tensor_desc = desc_.mutable_type()->mutable_selected_rows();
      break;
    case proto::VarType::LOD_TENSOR:
      tensor_desc = desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
      break;
    case proto::VarType::LOD_TENSOR_ARRAY:
      tensor_desc =
          desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
      break;
    case proto::VarType::STRINGS:
      tensor_desc = desc_.mutable_type()->mutable_strings();
      break;
    case proto::VarType::VOCAB:
      tensor_desc = desc_.mutable_type()->mutable_vocab();
      break;
    case proto::VarType::SPARSE_COO:
      tensor_desc = desc_.mutable_type()->mutable_sparse_coo();
      break;
    default:
      PADDLE_THROW(
          platform::errors::Unavailable("Getting 'mutable_tensor_desc' is not "
                                        "supported by the %s type variable.",
                                        this->Name()));
  }
  need_updated_ = true;
  return tensor_desc;
}
F
fengjiayi 已提交
310

311
std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
312
  PADDLE_ENFORCE_EQ(
313 314
      desc_.has_type(),
      true,
315 316
      platform::errors::NotFound("The variable's type was not be set."));
  PADDLE_ENFORCE_EQ(
317 318
      desc_.type().has_type(),
      true,
319
      platform::errors::NotFound("The variable's type was not be set."));
320
  std::vector<proto::VarType::TensorDesc *> res;
F
fengjiayi 已提交
321
  res.reserve(GetTensorDescNum());
322 323 324 325
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      for (auto &lod_tensor :
           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
F
fengjiayi 已提交
326 327 328 329
        res.push_back(lod_tensor.mutable_tensor());
      }
      return res;
    default:
330 331 332
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'tensor_descs' is not supported by the %s type variable.",
          this->Name()));
F
fengjiayi 已提交
333
  }
L
Leo Chen 已提交
334
  need_updated_ = true;
F
fengjiayi 已提交
335 336
}

337 338 339 340 341 342 343 344 345 346 347 348 349 350 351
// Lists the names of all attributes currently set on this variable.
std::vector<std::string> VarDesc::AttrNames() const {
  std::vector<std::string> names;
  names.reserve(attrs_.size());
  for (const auto &entry : attrs_) {
    names.push_back(entry.first);
  }
  return names;
}

void VarDesc::RemoveAttr(const std::string &name) { attrs_.erase(name); }

void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
  // NOTICE(sandyhouse): pybind11 will take the empty list in python as
  // the std::vector<int> type in C++; so we have to change the attr's type
  // here if we meet this issue
R
Ruibiao Chen 已提交
352
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
353
  if (attr_type == proto::AttrType::INTS &&
R
Ruibiao Chen 已提交
354
      PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) {
355 356 357 358 359 360 361
    // Find current attr via attr name and set the correct attribute value
    this->attrs_[name] = std::vector<int>();
    return;
  }
  bool valid = attr_type == proto::AttrType::INT ||
               attr_type == proto::AttrType::STRING ||
               attr_type == proto::AttrType::INTS;
362 363 364 365 366 367
  PADDLE_ENFORCE_EQ(valid,
                    true,
                    platform::errors::InvalidArgument(
                        "The value for attr (%s) must be "
                        "one of int, string, list of int for now.",
                        name));
368 369

  this->attrs_[name] = v;
370
  need_updated_ = true;
371 372 373 374
}

// Looks up an attribute by name; raises NotFound when it is absent.
Attribute VarDesc::GetAttr(const std::string &name) const {
  auto iter = attrs_.find(name);
  PADDLE_ENFORCE_NE(
      iter,
      attrs_.end(),
      platform::errors::NotFound("Attribute %s is not found.", name));
  return iter->second;
}

382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438
// Visitor that writes a variant Attribute value into a proto VarDesc::Attr.
// Supported alternatives: int, std::string, std::vector<int>; anything else
// raises Unavailable at runtime.
struct SetVarAttrDescVisitor {
  explicit SetVarAttrDescVisitor(proto::VarDesc::Attr *attr) : attr_(attr) {}
  mutable proto::VarDesc::Attr *attr_;

  template <typename T>
  void operator()(T &&v) {
    using U = std::decay_t<decltype(v)>;
    // `if constexpr` discards the unsupported branches at compile time, so
    // only the matching overload is instantiated for each variant
    // alternative. This replaces the old runtime `if (std::is_same...)`
    // chain plus an undefined "pass the compilation" template overload.
    if constexpr (std::is_same<U, int>::value ||
                  std::is_same<U, std::string>::value ||
                  std::is_same<U, std::vector<int>>::value) {
      set_attr_value(v);
    } else {
      PADDLE_THROW(platform::errors::Unavailable(
          "Unsupported calling method of SetAttrDescVisitor object."));
    }
  }

  void set_attr_value(int v) { attr_->set_i(v); }

  void set_attr_value(const std::string &v) { attr_->set_s(v); }

  void set_attr_value(const std::vector<int> &v) {
    VectorToRepeated(v, attr_->mutable_ints());
  }
};

// Only need to flush the attrs for auto parallel for now
void VarDesc::Flush() {
  VLOG(4) << "Flush "
          << " " << Name() << " " << need_updated_;
  if (need_updated_) {
    this->desc_.mutable_attrs()->Clear();
    std::vector<std::pair<std::string, Attribute>> sorted_attrs{attrs_.begin(),
                                                                attrs_.end()};
    std::sort(
        sorted_attrs.begin(),
        sorted_attrs.end(),
        [](std::pair<std::string, Attribute> a,
           std::pair<std::string, Attribute> b) { return a.first < b.first; });
    for (auto &attr : sorted_attrs) {
      auto *attr_desc = desc_.add_attrs();
      attr_desc->set_name(attr.first);
      attr_desc->set_type(
          static_cast<proto::AttrType>(attr.second.index() - 1));
      SetVarAttrDescVisitor visitor(attr_desc);
      paddle::visit(visitor, attr.second);
    }
    need_updated_ = false;
  }
}

439 440 441 442 443
// Returns the distributed attribute, lazily constructing it from the
// variable's shape on first access.
// BUGFIX: the old `need_updated_ = true;` after the if/else was unreachable
// (both branches returned); it is now set on the creation path, where the
// desc actually changes.
TensorDistAttr *VarDesc::MutableDistAttr() {
  if (dist_attr_) {
    return dist_attr_.get();
  }
  auto shape = paddle::distributed::auto_parallel::get_tensor_shape(this);
  dist_attr_.reset(new TensorDistAttr(shape));
  need_updated_ = true;
  return dist_attr_.get();
}

void VarDesc::SetDistAttr(const TensorDistAttr &dist_attr) {
  // Make sure this dist attr be created
  MutableDistAttr();
  *dist_attr_ = dist_attr;
455
  need_updated_ = true;
456 457
}

458 459 460 461 462
// Two VarDescs are equal iff their flushed protos serialize identically.
bool operator==(const VarDesc &left, const VarDesc &right) {
  const std::string left_bytes = left.Proto()->SerializeAsString();
  const std::string right_bytes = right.Proto()->SerializeAsString();
  return left_bytes == right_bytes;
}

F
fengjiayi 已提交
463 464
}  // namespace framework
}  // namespace paddle