/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/var_desc.h"

#include "glog/logging.h"
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {

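// The copy constructor copies the proto desc, the attribute map, and the
// original id; the distributed tensor attribute is cloned only when the
// source VarDesc has one.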
VarDesc::VarDesc(const VarDesc &other)
    : desc_(other.desc_),
      attrs_(other.attrs_),
      original_id_(other.original_id_) {
  if (other.dist_attr_) {
    dist_attr_.reset(new TensorDistAttr(*other.dist_attr_));
  }
}

proto::VarType::Type VarDesc::GetType() const { return desc_.type().type(); }

void VarDesc::SetType(proto::VarType::Type type) {
  desc_.mutable_type()->set_type(type);
  need_updated_ = true;
}

void VarDesc::SetShape(const std::vector<int64_t> &dims) {
  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
  need_updated_ = true;
}

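// Only READER variables hold multiple sub-tensor descriptions; resizing
// clears the reader's lod_tensor list and re-creates num empty entries.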
void VarDesc::SetTensorDescNum(size_t num) {
  switch (desc_.type().type()) {
    case proto::VarType::READER: {
      auto *lod_tensors_ptr =
          desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
      lod_tensors_ptr->Clear();
      for (size_t i = 0; i < num; ++i) {
        lod_tensors_ptr->Add();
      }
      return;
    } break;
    default:
      PADDLE_THROW(
          platform::errors::Unavailable("Setting 'sub_tensor_number' is not "
                                        "supported by the %s type variable.",
                                        this->Name()));
  }
  need_updated_ = true;
}

size_t VarDesc::GetTensorDescNum() const {
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      return desc_.type().reader().lod_tensor_size();
      break;
    default:
      PADDLE_THROW(
          platform::errors::Unavailable("Getting 'sub_tensor_number' is not "
                                        "supported by the %s type variable.",
                                        this->Name()));
  }
}

void VarDesc::SetShapes(
    const std::vector<std::vector<int64_t>> &multiple_dims) {
  if (multiple_dims.size() != GetTensorDescNum()) {
    VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
    SetTensorDescNum(multiple_dims.size());
  }
  std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
  for (size_t i = 0; i < multiple_dims.size(); ++i) {
    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
  }
  need_updated_ = true;
}

std::vector<int64_t> VarDesc::GetShape() const {
  return RepeatedToVector(tensor_desc().dims());
}

std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
  std::vector<std::vector<int64_t>> res;
  res.reserve(descs.size());
  for (const auto &tensor_desc : descs) {
    res.push_back(RepeatedToVector(tensor_desc.dims()));
  }
  return res;
}

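// Sets the data type of the variable's single tensor description; READER
// variables use SetDataTypes instead.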
void VarDesc::SetDataType(proto::VarType::Type data_type) {
  mutable_tensor_desc()->set_data_type(data_type);
  need_updated_ = true;
}

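// Sets one data type per sub-tensor of a READER variable; the tensor list is
// rebuilt first if the given count differs from the current one.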
void VarDesc::SetDataTypes(
    const std::vector<proto::VarType::Type> &multiple_data_type) {
  if (multiple_data_type.size() != GetTensorDescNum()) {
    VLOG(3) << "WARNING: The number of given data types("
            << multiple_data_type.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
    SetTensorDescNum(multiple_data_type.size());
  }
  std::vector<proto::VarType::TensorDesc *> tensor_descs =
      mutable_tensor_descs();
  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
    tensor_descs[i]->set_data_type(multiple_data_type[i]);
  }
  need_updated_ = true;
}

proto::VarType::Type VarDesc::GetDataType() const {
  return tensor_desc().data_type();
}

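// Size in bytes of a single element of this variable's data type.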
size_t VarDesc::ElementSize() const {
  return framework::SizeOfType(GetDataType());
}

std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
  std::vector<proto::VarType::Type> res;
  res.reserve(descs.size());
  for (const auto &tensor_desc : descs) {
    res.push_back(tensor_desc.data_type());
  }
  return res;
}

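// The LoD level is only defined for LOD_TENSOR and LOD_TENSOR_ARRAY
// variables; other types throw.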
void VarDesc::SetLoDLevel(int32_t lod_level) {
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
      desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
      break;
    case proto::VarType::LOD_TENSOR_ARRAY:
      desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Setting 'lod_level' is not supported by the %s type variable.",
          this->Name()));
  }
  need_updated_ = true;
}

void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
  if (multiple_lod_level.size() != GetTensorDescNum()) {
    VLOG(3) << "WARNING: The number of given lod_levels("
            << multiple_lod_level.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
    SetTensorDescNum(multiple_lod_level.size());
  }
  switch (desc_.type().type()) {
    case proto::VarType::READER: {
      size_t i = 0;
      for (auto &lod_tensor :
           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
        lod_tensor.set_lod_level(multiple_lod_level[i++]);
      }
    } break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Setting 'lod_levels' is not supported by the %s type variable",
          this->Name()));
  }
  need_updated_ = true;
}

int32_t VarDesc::GetLoDLevel() const {
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
      return desc_.type().lod_tensor().lod_level();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return desc_.type().tensor_array().lod_level();
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'lod_level' is not supported by the %s type variable.",
          this->Name()));
  }
}

std::vector<int32_t> VarDesc::GetLoDLevels() const {
  std::vector<int32_t> res;
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      res.reserve(desc_.type().reader().lod_tensor_size());
      for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
        res.push_back(lod_tensor.lod_level());
      }
      return res;
      break;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'lod_levels' is not supported by the %s type variable.",
          this->Name()));
  }
}

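// Returns the TensorDesc embedded in the concrete type message; supported for
// SELECTED_ROWS, LOD_TENSOR, LOD_TENSOR_ARRAY, STRINGS, and VOCAB variables.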
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
  PADDLE_ENFORCE_EQ(
      desc_.has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  PADDLE_ENFORCE_EQ(
      desc_.type().has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  switch (desc_.type().type()) {
    case proto::VarType::SELECTED_ROWS:
      return desc_.type().selected_rows();
    case proto::VarType::LOD_TENSOR:
      return desc_.type().lod_tensor().tensor();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return desc_.type().tensor_array().tensor();
    case proto::VarType::STRINGS:
      return desc_.type().strings();
    case proto::VarType::VOCAB:
      return desc_.type().vocab();
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'tensor_desc' is not supported by the %s type variable.",
          this->Name()));
  }
}

std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
  PADDLE_ENFORCE_EQ(
      desc_.has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  std::vector<proto::VarType::TensorDesc> res;
  res.reserve(GetTensorDescNum());
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
        res.push_back(lod_tensor.tensor());
      }
      return res;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'tensor_descs' is not supported by the %s type variable.",
          this->Name()));
  }
}

proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
  PADDLE_ENFORCE_EQ(
      desc_.has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  PADDLE_ENFORCE_EQ(
      desc_.type().has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  switch (desc_.type().type()) {
    case proto::VarType::SELECTED_ROWS:
      return desc_.mutable_type()->mutable_selected_rows();
    case proto::VarType::LOD_TENSOR:
      return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
    case proto::VarType::STRINGS:
      return desc_.mutable_type()->mutable_strings();
    case proto::VarType::VOCAB:
      return desc_.mutable_type()->mutable_vocab();
    default:
      PADDLE_THROW(
          platform::errors::Unavailable("Getting 'mutable_tensor_desc' is not "
                                        "supported by the %s type variable.",
                                        this->Name()));
  }
  need_updated_ = true;
}

std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
  PADDLE_ENFORCE_EQ(
      desc_.has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  PADDLE_ENFORCE_EQ(
      desc_.type().has_type(),
      true,
      platform::errors::NotFound("The variable's type was not be set."));
  std::vector<proto::VarType::TensorDesc *> res;
  res.reserve(GetTensorDescNum());
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      for (auto &lod_tensor :
           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
        res.push_back(lod_tensor.mutable_tensor());
      }
      return res;
    default:
      PADDLE_THROW(platform::errors::Unavailable(
          "Getting 'tensor_descs' is not supported by the %s type variable.",
          this->Name()));
  }
  need_updated_ = true;
}

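// Variable attributes are stored in the attrs_ map, separately from the proto
// description.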
std::vector<std::string> VarDesc::AttrNames() const {
  std::vector<std::string> retv;
  retv.reserve(attrs_.size());
  for (auto &attr : attrs_) {
    retv.push_back(attr.first);
  }
  return retv;
}

void VarDesc::RemoveAttr(const std::string &name) { attrs_.erase(name); }

void VarDesc::SetAttr(const std::string &name, const Attribute &v) {
  // NOTICE(sandyhouse): pybind11 maps an empty Python list to the
  // std::vector<int> type in C++, so the attr's type has to be corrected
  // here when that happens.
  proto::AttrType attr_type = static_cast<proto::AttrType>(v.index() - 1);
  if (attr_type == proto::AttrType::INTS &&
      PADDLE_GET_CONST(std::vector<int>, v).size() == 0u) {
    // Store an empty std::vector<int> for this attr and return.
    this->attrs_[name] = std::vector<int>();
    return;
  }
  bool valid = attr_type == proto::AttrType::INT ||
               attr_type == proto::AttrType::STRING ||
               attr_type == proto::AttrType::INTS;
  PADDLE_ENFORCE_EQ(
      valid,
      true,
      platform::errors::InvalidArgument("The value for attr (%s) must be "
                                        "one of list or int or string.",
                                        name));

  this->attrs_[name] = v;
}

Attribute VarDesc::GetAttr(const std::string &name) const {
  auto it = attrs_.find(name);
  PADDLE_ENFORCE_NE(
      it,
      attrs_.end(),
      platform::errors::NotFound("Attribute %s is not found.", name));
  return it->second;
}

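// The distributed tensor attribute is created lazily and owned by this
// VarDesc.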
TensorDistAttr *VarDesc::MutableDistAttr() {
  // If dist_attr_ is nullptr, construct a new one and return.
  if (dist_attr_) {
    return dist_attr_.get();
  } else {
    dist_attr_.reset(new TensorDistAttr(*this));
    return dist_attr_.get();
  }
}

void VarDesc::SetDistAttr(const TensorDistAttr &dist_attr) {
  // Make sure the dist attr has been created.
  MutableDistAttr();
  *dist_attr_ = dist_attr;
}

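// Two VarDescs compare equal when their serialized proto descriptions are
// byte-identical.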
bool operator==(const VarDesc &left, const VarDesc &right) {
  return left.Proto()->SerializeAsString() ==
         right.Proto()->SerializeAsString();
}

}  // namespace framework
}  // namespace paddle