var_desc.cc 3.1 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/framework/var_desc.h"
Y
Yu Yang 已提交
16
#include "paddle/platform/enforce.h"
F
fengjiayi 已提交
17 18 19 20

namespace paddle {
namespace framework {

21
// Returns the variable kind recorded in the underlying protobuf descriptor.
proto::VarDesc::VarType VarDescBind::GetType() const {
  const proto::VarDesc::VarType var_type = desc_.type();
  return var_type;
}
Q
QI JUN 已提交
22

23 24 25
// Records the variable kind in the protobuf descriptor. The stored kind is
// what tensor_desc()/mutable_tensor_desc() switch on to pick a sub-message.
void VarDescBind::SetType(proto::VarDesc::VarType type) {
  auto &message = desc_;
  message.set_type(type);
}
Q
QI JUN 已提交
26

F
fengjiayi 已提交
27
void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
Y
Yu Yang 已提交
28
  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
F
fengjiayi 已提交
29 30
}

31
// Stores the element data type in this variable's tensor descriptor.
// The type must already be set: mutable_tensor_desc() enforces has_type().
void VarDescBind::SetDataType(proto::DataType data_type) {
  proto::TensorDesc *tensor = mutable_tensor_desc();
  tensor->set_data_type(data_type);
}

std::vector<int64_t> VarDescBind::Shape() const {
Y
Yu Yang 已提交
36
  return RepeatedToVector(tensor_desc().dims());
F
fengjiayi 已提交
37 38
}

39 40 41
// Returns the element data type stored in this variable's tensor descriptor.
proto::DataType VarDescBind::GetDataType() const {
  const proto::TensorDesc &tensor = tensor_desc();
  return tensor.data_type();
}
Y
Stash  
Yu Yang 已提交
42 43

void VarDescBind::SetLoDLevel(int32_t lod_level) {
Y
Yu Yang 已提交
44
  switch (desc_.type()) {
45
    case proto::VarDesc::LOD_TENSOR:
Y
Yu Yang 已提交
46 47
      desc_.mutable_lod_tensor()->set_lod_level(lod_level);
      break;
48
    case proto::VarDesc::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
49 50 51
      desc_.mutable_tensor_array()->set_lod_level(lod_level);
      break;
    default:
52 53
      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
                   desc_.tensor_array().lod_level());
Y
Yu Yang 已提交
54
  }
Y
Stash  
Yu Yang 已提交
55 56 57
}

int32_t VarDescBind::GetLodLevel() const {
Y
Yu Yang 已提交
58
  switch (desc_.type()) {
59
    case proto::VarDesc::LOD_TENSOR:
Y
Yu Yang 已提交
60
      return desc_.lod_tensor().lod_level();
61
    case proto::VarDesc::LOD_TENSOR_ARRAY:
Y
Yu Yang 已提交
62 63
      return desc_.tensor_array().lod_level();
    default:
64 65
      PADDLE_THROW("Tensor type=%d does not support LoDLevel",
                   desc_.tensor_array().lod_level());
Y
Yu Yang 已提交
66
  }
Y
Stash  
Yu Yang 已提交
67
}
Y
Yu Yang 已提交
68

69
// Read-only access to the TensorDesc that belongs to this variable's kind:
//   SELECTED_ROWS    -> desc_.selected_rows()
//   LOD_TENSOR       -> desc_.lod_tensor().tensor()
//   LOD_TENSOR_ARRAY -> desc_.tensor_array().tensor()
// Fails if no type has been set; any other kind throws.
const proto::TensorDesc &VarDescBind::tensor_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
  const proto::VarDesc::VarType var_type = desc_.type();
  if (var_type == proto::VarDesc::SELECTED_ROWS) {
    return desc_.selected_rows();
  }
  if (var_type == proto::VarDesc::LOD_TENSOR) {
    return desc_.lod_tensor().tensor();
  }
  if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
    return desc_.tensor_array().tensor();
  }
  PADDLE_THROW("Unexpected branch.");
}

83
// Mutable access to the TensorDesc that belongs to this variable's kind:
//   SELECTED_ROWS    -> desc_.mutable_selected_rows()
//   LOD_TENSOR       -> desc_.mutable_lod_tensor()->mutable_tensor()
//   LOD_TENSOR_ARRAY -> desc_.mutable_tensor_array()->mutable_tensor()
// Fails if no type has been set; any other kind throws.
proto::TensorDesc *VarDescBind::mutable_tensor_desc() {
  PADDLE_ENFORCE(desc_.has_type(),
                 "invoke MutableTensorDesc must after set type");
  const proto::VarDesc::VarType var_type = desc_.type();
  if (var_type == proto::VarDesc::SELECTED_ROWS) {
    return desc_.mutable_selected_rows();
  }
  if (var_type == proto::VarDesc::LOD_TENSOR) {
    return desc_.mutable_lod_tensor()->mutable_tensor();
  }
  if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
    return desc_.mutable_tensor_array()->mutable_tensor();
  }
  PADDLE_THROW("Unexpected branch.");
}
F
fengjiayi 已提交
97 98
}  // namespace framework
}  // namespace paddle