var_desc.cc 3.1 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/var_desc.h"
Y
Yu Yang 已提交
16
#include "paddle/platform/enforce.h"
F
fengjiayi 已提交
17 18 19 20

namespace paddle {
namespace framework {

Y
Yu Yang 已提交
21
proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); }
Q
QI JUN 已提交
22

Y
Yu Yang 已提交
23
void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); }
Q
QI JUN 已提交
24

Y
Yu Yang 已提交
25
void VarDesc::SetShape(const std::vector<int64_t> &dims) {
Y
Yu Yang 已提交
26
  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
F
fengjiayi 已提交
27 28
}

Y
Yu Yang 已提交
29
void VarDesc::SetDataType(proto::DataType data_type) {
Y
Yu Yang 已提交
30
  mutable_tensor_desc()->set_data_type(data_type);
F
fengjiayi 已提交
31 32
}

Y
Yu Yang 已提交
33
std::vector<int64_t> VarDesc::Shape() const {
Y
Yu Yang 已提交
34
  return RepeatedToVector(tensor_desc().dims());
F
fengjiayi 已提交
35 36
}

Y
Yu Yang 已提交
37
proto::DataType VarDesc::GetDataType() const {
38 39
  return tensor_desc().data_type();
}
Y
Stash  
Yu Yang 已提交
40

Y
Yu Yang 已提交
41
void VarDesc::SetLoDLevel(int32_t lod_level) {
Y
Yu Yang 已提交
42
  switch (desc_.type()) {
43
    case proto::VarDesc::LOD_TENSOR:
Y
Yu Yang 已提交
44 45
      desc_.mutable_lod_tensor()->set_lod_level(lod_level);
      break;
46
    case proto::VarDesc::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
47 48 49
      desc_.mutable_tensor_array()->set_lod_level(lod_level);
      break;
    default:
50 51
      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
                   desc_.tensor_array().lod_level());
Y
Yu Yang 已提交
52
  }
Y
Stash  
Yu Yang 已提交
53 54
}

Y
Yu Yang 已提交
55
int32_t VarDesc::GetLodLevel() const {
Y
Yu Yang 已提交
56
  switch (desc_.type()) {
57
    case proto::VarDesc::LOD_TENSOR:
Y
Yu Yang 已提交
58
      return desc_.lod_tensor().lod_level();
59
    case proto::VarDesc::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
60 61
      return desc_.tensor_array().lod_level();
    default:
62 63
      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
                   desc_.tensor_array().lod_level());
Y
Yu Yang 已提交
64
  }
Y
Stash  
Yu Yang 已提交
65
}
Y
Yu Yang 已提交
66

Y
Yu Yang 已提交
67
const proto::TensorDesc &VarDesc::tensor_desc() const {
Y
Yu Yang 已提交
68 69
  PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
  switch (desc_.type()) {
70
    case proto::VarDesc::SELECTED_ROWS:
Y
Yu Yang 已提交
71
      return desc_.selected_rows();
72
    case proto::VarDesc::LOD_TENSOR:
Y
Yu Yang 已提交
73
      return desc_.lod_tensor().tensor();
74
    case proto::VarDesc::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
75
      return desc_.tensor_array().tensor();
Y
Yu Yang 已提交
76
    default:
F
fengjiayi 已提交
77
      PADDLE_THROW("The type of var '", this->Name(), "' is unsupported.");
Y
Yu Yang 已提交
78 79 80
  }
}

Y
Yu Yang 已提交
81
proto::TensorDesc *VarDesc::mutable_tensor_desc() {
Y
Yu Yang 已提交
82 83 84
  PADDLE_ENFORCE(desc_.has_type(),
                 "invoke MutableTensorDesc must after set type");
  switch (desc_.type()) {
85
    case proto::VarDesc::SELECTED_ROWS:
Y
Yu Yang 已提交
86
      return desc_.mutable_selected_rows();
87
    case proto::VarDesc::LOD_TENSOR:
Y
Yu Yang 已提交
88
      return desc_.mutable_lod_tensor()->mutable_tensor();
89
    case proto::VarDesc::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
90
      return desc_.mutable_tensor_array()->mutable_tensor();
Y
Yu Yang 已提交
91 92 93 94
    default:
      PADDLE_THROW("Unexpected branch.");
  }
}
F
fengjiayi 已提交
95 96
}  // namespace framework
}  // namespace paddle