io_utils_tester.cc 5.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <glog/logging.h>
#include <gtest/gtest.h>

#include <cstring>
#include <map>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/inference/api/helper.h"
#include "paddle/fluid/inference/utils/io_utils.h"

namespace paddle {
namespace inference {
namespace {

bool pd_tensor_equal(const paddle::PaddleTensor& ref,
                     const paddle::PaddleTensor& t) {
  bool is_equal = true;
  VLOG(3) << "ref.name: " << ref.name << ", t.name: " << t.name;
  VLOG(3) << "ref.dtype: " << ref.dtype << ", t.dtype: " << t.dtype;
  VLOG(3) << "ref.lod_level: " << ref.lod.size()
          << ", t.dtype: " << t.lod.size();
  VLOG(3) << "ref.data_len: " << ref.data.length()
          << ", t.data_len: " << t.data.length();
  return is_equal && (ref.name == t.name) && (ref.lod == t.lod) &&
         (ref.dtype == t.dtype) &&
         (std::memcmp(ref.data.data(), t.data.data(), ref.data.length()) == 0);
}

template <typename T>
void test_io_utils() {
  std::vector<T> input({6, 8});
  paddle::PaddleTensor in;
  in.name = "Hello";
  in.shape = {1, 2};
  in.lod = std::vector<std::vector<size_t>>{{0, 1}};
  in.data = paddle::PaddleBuf(static_cast<void*>(input.data()),
                              input.size() * sizeof(T));
  in.dtype = paddle::inference::PaddleTensorGetDType<T>();
  std::stringstream ss;
  paddle::inference::SerializePDTensorToStream(&ss, in);
  paddle::PaddleTensor out;
  paddle::inference::DeserializePDTensorToStream(ss, &out);
  ASSERT_TRUE(pd_tensor_equal(in, out));
}
}  // namespace
}  // namespace inference
}  // namespace paddle

TEST(infer_io_utils, float32) {
  // Stream round trip of a small float tensor.
  paddle::inference::test_io_utils<float>();
}

TEST(infer_io_utils, tensors) {
  // Create a float32 tensor: shape 2x2, no LoD.
  std::vector<float> input_fp32({1.1f, 3.2f, 5.0f, 8.2f});
  paddle::PaddleTensor in_fp32;
  in_fp32.name = "Tensor.fp32_0";
  in_fp32.shape = {2, 2};
  in_fp32.data = paddle::PaddleBuf(static_cast<void*>(input_fp32.data()),
                                   input_fp32.size() * sizeof(float));
  in_fp32.dtype = paddle::inference::PaddleTensorGetDType<float>();

  // Create an int64 tensor: shape 1x2, one LoD level. The backing vector
  // must hold int64_t values, because the byte length below is computed with
  // sizeof(int64_t); a float-typed buffer would be read past its end.
  std::vector<int64_t> input_int64({5, 8});
  paddle::PaddleTensor in_int64;
  in_int64.name = "Tensor.int64_0";
  in_int64.shape = {1, 2};
  in_int64.lod = std::vector<std::vector<size_t>>{{0, 1}};
  in_int64.data = paddle::PaddleBuf(static_cast<void*>(input_int64.data()),
                                    input_int64.size() * sizeof(int64_t));
  in_int64.dtype = paddle::inference::PaddleTensorGetDType<int64_t>();

  // Serialize both tensors into one file so the multi-tensor path is covered.
  std::vector<paddle::PaddleTensor> tensors_in({in_fp32, in_int64});
  std::string file_path = "./io_utils_tensors";
  paddle::inference::SerializePDTensorsToFile(file_path, tensors_in);

  // Deserialize tensors.
  std::vector<paddle::PaddleTensor> tensors_out;
  paddle::inference::DeserializePDTensorsToFile(file_path, &tensors_out);

  // Check results: same count, and each tensor equal to its original.
  ASSERT_EQ(tensors_in.size(), tensors_out.size());
  for (size_t i = 0; i < tensors_in.size(); ++i) {
    ASSERT_TRUE(
        paddle::inference::pd_tensor_equal(tensors_in[i], tensors_out[i]));
  }
}
99 100 101 102

TEST(shape_info_io, read_and_write) {
  using ShapeMap = std::map<std::string, std::vector<int32_t>>;
  const std::string path = "test_shape_info_io";
  const std::vector<int32_t> dims_112{1, 3, 112, 112};
  const std::vector<int32_t> dims_224{1, 3, 224, 224};
  const std::vector<int32_t> dims_56{1, 3, 56, 56};
  ShapeMap min_shape, max_shape, opt_shape;
  ShapeMap min_value, max_value, opt_value;

  // Empties every map so a subsequent read starts from a clean state.
  auto clear_all = [&]() {
    min_shape.clear();
    max_shape.clear();
    opt_shape.clear();
    min_value.clear();
    max_value.clear();
    opt_value.clear();
  };

  // Write one shape range and one value range for "test1".
  min_shape.insert(std::make_pair("test1", dims_112));
  max_shape.insert(std::make_pair("test1", dims_224));
  opt_shape.insert(std::make_pair("test1", dims_224));
  min_value.insert(std::make_pair("test1", dims_112));
  max_value.insert(std::make_pair("test1", dims_224));
  opt_value.insert(std::make_pair("test1", dims_224));
  paddle::inference::SerializeShapeRangeInfo(
      path, min_shape, max_shape, opt_shape, min_value, max_value, opt_value);

  // Read the file back. An unrelated "test2" entry is pre-seeded into
  // opt_shape; "test1" must still be restored as written.
  clear_all();
  opt_shape.insert(std::make_pair("test2", dims_224));
  paddle::inference::DeserializeShapeRangeInfo(path,
                                               &min_shape,
                                               &max_shape,
                                               &opt_shape,
                                               &min_value,
                                               &max_value,
                                               &opt_value);
  ASSERT_EQ(min_shape.at("test1"), dims_112);
  ASSERT_EQ(max_shape.at("test1"), dims_224);
  ASSERT_EQ(opt_shape.at("test1"), dims_224);

  // Shrink the min shape of "test1" and rewrite only that entry. Assign
  // rather than insert: std::map::insert is a no-op for an existing key, and
  // "test1" was just read back, so insert would persist unchanged data.
  min_shape["test1"] = dims_56;
  std::vector<std::string> names{"test1"};
  paddle::inference::UpdateShapeRangeInfo(
      path, min_shape, max_shape, opt_shape, names);

  // A fresh read must return the updated min shape.
  clear_all();
  paddle::inference::DeserializeShapeRangeInfo(path,
                                               &min_shape,
                                               &max_shape,
                                               &opt_shape,
                                               &min_value,
                                               &max_value,
                                               &opt_value);
  ASSERT_EQ(min_shape.at("test1"), dims_56);

  // Reading a file that does not exist must raise EnforceNotMet.
  ASSERT_THROW(paddle::inference::DeserializeShapeRangeInfo("no_exists_file",
                                                            &min_shape,
                                                            &max_shape,
                                                            &opt_shape,
                                                            &min_value,
                                                            &max_value,
                                                            &opt_value),
               paddle::platform::EnforceNotMet);
}