var_desc.cc 10.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15 16
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/platform/enforce.h"
F
fengjiayi 已提交
17 18 19 20

namespace paddle {
namespace framework {

21
// Returns the kind of this variable, read from the nested `type` message of
// the underlying proto description.
proto::VarType::Type VarDesc::GetType() const {
  const auto &var_type = desc_.type();
  return var_type.type();
}
Q
QI JUN 已提交
22

23 24 25
// Overwrites the kind of this variable (e.g. LOD_TENSOR, READER, CHANNEL)
// stored in the nested `type` message of the proto description.
void VarDesc::SetType(proto::VarType::Type type) {
  auto *var_type = desc_.mutable_type();
  var_type->set_type(type);
}
Q
QI JUN 已提交
26

Y
Yu Yang 已提交
27
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
Y
Yu Yang 已提交
28
  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
F
fengjiayi 已提交
29 30
}

F
fengjiayi 已提交
31
void VarDesc::SetTensorDescNum(size_t num) {
32 33 34 35
  switch (desc_.type().type()) {
    case proto::VarType::READER: {
      auto *lod_tensors_ptr =
          desc_.mutable_type()->mutable_reader()->mutable_lod_tensor();
F
fengjiayi 已提交
36 37 38 39 40 41 42 43 44 45 46 47 48 49
      lod_tensors_ptr->Clear();
      for (size_t i = 0; i < num; ++i) {
        lod_tensors_ptr->Add();
      }
      return;
    } break;
    default:
      PADDLE_THROW(
          "Setting 'sub_tensor_number' is not supported by the type of var %s.",
          this->Name());
  }
}

size_t VarDesc::GetTensorDescNum() const {
50 51 52
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      return desc_.type().reader().lod_tensor_size();
F
fengjiayi 已提交
53 54 55 56 57 58 59 60 61
      break;
    default:
      PADDLE_THROW(
          "Getting 'sub_tensor_number' is not supported by the type of var %s.",
          this->Name());
  }
}

void VarDesc::SetShapes(
F
fengjiayi 已提交
62
    const std::vector<std::vector<int64_t>> &multiple_dims) {
F
fengjiayi 已提交
63 64 65 66 67 68 69
  if (multiple_dims.size() != GetTensorDescNum()) {
    VLOG(3) << "WARNING: The number of given shapes(" << multiple_dims.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
    SetTensorDescNum(multiple_dims.size());
  }
70
  std::vector<proto::VarType::TensorDesc *> tensors = mutable_tensor_descs();
F
fengjiayi 已提交
71 72 73 74 75 76 77 78 79 80
  for (size_t i = 0; i < multiple_dims.size(); ++i) {
    VectorToRepeated(multiple_dims[i], tensors[i]->mutable_dims());
  }
}

std::vector<int64_t> VarDesc::GetShape() const {
  return RepeatedToVector(tensor_desc().dims());
}

std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
81
  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
F
fengjiayi 已提交
82 83 84 85 86 87 88 89
  std::vector<std::vector<int64_t>> res;
  res.reserve(descs.size());
  for (const auto &tensor_desc : descs) {
    res.push_back(RepeatedToVector(tensor_desc.dims()));
  }
  return res;
}

90
// Sets the element data type. CHANNEL vars store it in their channel desc;
// every other type goes through mutable_tensor_desc(), which throws for
// types that have no tensor desc.
void VarDesc::SetDataType(proto::VarType::Type data_type) {
  if (desc_.type().type() == proto::VarType::CHANNEL) {
    mutable_channel_desc()->set_data_type(data_type);
    return;
  }
  mutable_tensor_desc()->set_data_type(data_type);
}

F
fengjiayi 已提交
100
void VarDesc::SetDataTypes(
101
    const std::vector<proto::VarType::Type> &multiple_data_type) {
F
fengjiayi 已提交
102 103 104 105 106 107 108 109
  if (multiple_data_type.size() != GetTensorDescNum()) {
    VLOG(3) << "WARNING: The number of given data types("
            << multiple_data_type.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
    SetTensorDescNum(multiple_data_type.size());
  }
110 111
  std::vector<proto::VarType::TensorDesc *> tensor_descs =
      mutable_tensor_descs();
F
fengjiayi 已提交
112 113 114
  for (size_t i = 0; i < multiple_data_type.size(); ++i) {
    tensor_descs[i]->set_data_type(multiple_data_type[i]);
  }
F
fengjiayi 已提交
115 116
}

117
// Returns the element data type. CHANNEL vars read it from their channel
// desc; every other type goes through tensor_desc(), which throws for types
// that have no tensor desc.
proto::VarType::Type VarDesc::GetDataType() const {
  return desc_.type().type() == proto::VarType::CHANNEL
             ? channel_desc().data_type()
             : tensor_desc().data_type();
}
Y
Stash  
Yu Yang 已提交
126

127
std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
128
  std::vector<proto::VarType::TensorDesc> descs = tensor_descs();
129
  std::vector<proto::VarType::Type> res;
F
fengjiayi 已提交
130 131 132 133 134 135 136
  res.reserve(descs.size());
  for (const auto &tensor_desc : descs) {
    res.push_back(tensor_desc.data_type());
  }
  return res;
}

137 138 139 140 141 142 143 144 145 146 147
// Sets the capacity of a CHANNEL var. Throws for every other var type.
void VarDesc::SetCapacity(int64_t capacity) {
  if (desc_.type().type() != proto::VarType::CHANNEL) {
    PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
                 this->Name());
  }
  desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
}

Y
Yu Yang 已提交
148
void VarDesc::SetLoDLevel(int32_t lod_level) {
149 150 151
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
      desc_.mutable_type()->mutable_lod_tensor()->set_lod_level(lod_level);
Y
Yu Yang 已提交
152
      break;
153 154
    case proto::VarType::LOD_TENSOR_ARRAY:
      desc_.mutable_type()->mutable_tensor_array()->set_lod_level(lod_level);
Y
Yu Yang 已提交
155 156
      break;
    default:
F
fengjiayi 已提交
157 158 159 160 161 162 163
      PADDLE_THROW(
          "Setting 'lod_level' is not supported by the type of var %s.",
          this->Name());
  }
}

void VarDesc::SetLoDLevels(const std::vector<int32_t> &multiple_lod_level) {
F
fengjiayi 已提交
164 165 166 167 168 169 170 171
  if (multiple_lod_level.size() != GetTensorDescNum()) {
    VLOG(3) << "WARNING: The number of given lod_levels("
            << multiple_lod_level.size()
            << ") doesn't match the existing tensor number("
            << GetTensorDescNum()
            << "). The Reader is going to be reinitialized.";
    SetTensorDescNum(multiple_lod_level.size());
  }
172 173
  switch (desc_.type().type()) {
    case proto::VarType::READER: {
F
fengjiayi 已提交
174
      size_t i = 0;
175 176
      for (auto &lod_tensor :
           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
F
fengjiayi 已提交
177 178 179 180 181 182 183
        lod_tensor.set_lod_level(multiple_lod_level[i++]);
      }
    } break;
    default:
      PADDLE_THROW(
          "Setting 'lod_levels' is not supported by the type of var %s.",
          this->Name());
Y
Yu Yang 已提交
184
  }
Y
Stash  
Yu Yang 已提交
185 186
}

187
int32_t VarDesc::GetLoDLevel() const {
188 189 190 191 192
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
      return desc_.type().lod_tensor().lod_level();
    case proto::VarType::LOD_TENSOR_ARRAY:
      return desc_.type().tensor_array().lod_level();
Y
Yu Yang 已提交
193
    default:
F
fengjiayi 已提交
194 195 196 197 198 199 200 201
      PADDLE_THROW(
          "Getting 'lod_level' is not supported by the type of var %s.",
          this->Name());
  }
}

std::vector<int32_t> VarDesc::GetLoDLevels() const {
  std::vector<int32_t> res;
202 203 204 205
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      res.reserve(desc_.type().reader().lod_tensor_size());
      for (auto &lod_tensor : desc_.type().reader().lod_tensor()) {
F
fengjiayi 已提交
206 207 208 209 210 211 212 213
        res.push_back(lod_tensor.lod_level());
      }
      return res;
      break;
    default:
      PADDLE_THROW(
          "Getting 'lod_levels' is not supported by the type of var %s.",
          this->Name());
Y
Yu Yang 已提交
214
  }
Y
Stash  
Yu Yang 已提交
215
}
Y
Yu Yang 已提交
216

217 218 219 220 221 222 223 224 225 226 227 228 229
// Read-only access to the channel descriptor of a CHANNEL var.
// Throws if the type is unset or is anything other than CHANNEL.
const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  if (desc_.type().type() != proto::VarType::CHANNEL) {
    PADDLE_THROW(
        "Getting 'channel_desc' is not supported by the type of var %s.",
        this->Name());
  }
  return desc_.type().channel();
}

230
// Read-only access to the single tensor descriptor of a SELECTED_ROWS,
// LOD_TENSOR or LOD_TENSOR_ARRAY var.
// Throws if the type is unset or is any other type.
const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  const auto &var_type = desc_.type();
  const proto::VarType::Type kind = var_type.type();
  if (kind == proto::VarType::SELECTED_ROWS) {
    return var_type.selected_rows();
  }
  if (kind == proto::VarType::LOD_TENSOR) {
    return var_type.lod_tensor().tensor();
  }
  if (kind == proto::VarType::LOD_TENSOR_ARRAY) {
    return var_type.tensor_array().tensor();
  }
  PADDLE_THROW(
      "Getting 'tensor_desc' is not supported by the type of var %s.",
      this->Name());
}

247
std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
F
fengjiayi 已提交
248
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
249
  std::vector<proto::VarType::TensorDesc> res;
F
fengjiayi 已提交
250
  res.reserve(GetTensorDescNum());
251 252 253
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      for (const auto &lod_tensor : desc_.type().reader().lod_tensor()) {
F
fengjiayi 已提交
254 255 256 257 258 259 260 261
        res.push_back(lod_tensor.tensor());
      }
      return res;
    default:
      PADDLE_THROW(
          "Getting 'tensor_descs' is not supported by the type of var "
          "%s.",
          this->Name());
Y
Yu Yang 已提交
262 263 264
  }
}

265 266 267 268 269 270 271 272 273 274 275 276 277 278
// Mutable access to the channel descriptor of a CHANNEL var.
// Throws if the type is unset or is anything other than CHANNEL.
proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  if (desc_.type().type() != proto::VarType::CHANNEL) {
    PADDLE_THROW(
        "Getting 'mutable_channel_desc' is not supported by the type of var "
        "%s.",
        this->Name());
  }
  return desc_.mutable_type()->mutable_channel();
}

279
// Mutable access to the single tensor descriptor of a SELECTED_ROWS,
// LOD_TENSOR or LOD_TENSOR_ARRAY var.
// Throws if the type is unset or is any other type.
proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  const proto::VarType::Type kind = desc_.type().type();
  if (kind == proto::VarType::SELECTED_ROWS) {
    return desc_.mutable_type()->mutable_selected_rows();
  }
  if (kind == proto::VarType::LOD_TENSOR) {
    return desc_.mutable_type()->mutable_lod_tensor()->mutable_tensor();
  }
  if (kind == proto::VarType::LOD_TENSOR_ARRAY) {
    return desc_.mutable_type()->mutable_tensor_array()->mutable_tensor();
  }
  PADDLE_THROW(
      "Getting 'mutable_tensor_desc' is not supported by the type of var "
      "%s.",
      this->Name());
}
F
fengjiayi 已提交
296

297
std::vector<proto::VarType::TensorDesc *> VarDesc::mutable_tensor_descs() {
F
fengjiayi 已提交
298
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
299 300
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  std::vector<proto::VarType::TensorDesc *> res;
F
fengjiayi 已提交
301
  res.reserve(GetTensorDescNum());
302 303 304 305
  switch (desc_.type().type()) {
    case proto::VarType::READER:
      for (auto &lod_tensor :
           *desc_.mutable_type()->mutable_reader()->mutable_lod_tensor()) {
F
fengjiayi 已提交
306 307 308 309 310 311 312 313 314 315 316
        res.push_back(lod_tensor.mutable_tensor());
      }
      return res;
    default:
      PADDLE_THROW(
          "Getting 'tensor_descs' is not supported by the type of var "
          "%s.",
          this->Name());
  }
}

F
fengjiayi 已提交
317 318
}  // namespace framework
}  // namespace paddle