/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <unistd.h>
#include <cstring>
#include <string>
#include <thread>  // NOLINT
#include <vector>

#include "google/protobuf/text_format.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/operators/detail/sendrecvop_utils.h"
#include "paddle/fluid/operators/detail/variable_response.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace operators = paddle::operators;
namespace math = paddle::operators::math;
namespace memory = paddle::memory;

36
// Round-trips a SelectedRows variable through the gRPC serde path on `place`.
//
// A 564 x 128 float SelectedRows value filled with 32.7 (rows 0..563,
// height 1000) is serialized with SerializeToByteBuffer and then verified
// twice:
//   1. by parsing the flattened bytes as a sendrecv::VariableMessage, and
//   2. by parsing the ByteBuffer with VariableResponse into a scope variable.
//
// place: the place the source tensor is allocated on; its device context is
//        used for filling, serializing and deserializing.
void RunSerdeTestSelectedRows(platform::Place place) {
  const int kRowCount = 564;
  const int kRowWidth = 128;
  const int tensor_numel = kRowCount * kRowWidth;
  const float kFillValue = 32.7;

  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);

  // Build the source variable.
  framework::Variable var;
  auto* slr = var.GetMutable<framework::SelectedRows>();
  slr->set_height(1000);
  auto* tensor = slr->mutable_value();
  auto* rows = slr->mutable_rows();
  tensor->Resize(framework::make_ddim({kRowCount, kRowWidth}));
  tensor->mutable_data<float>(place);
  math::set_constant(ctx, tensor, kFillValue);
  for (int i = 0; i < kRowCount; ++i) rows->push_back(i);

  // Serialize var to ByteBuffer.
  ::grpc::ByteBuffer msg;
  operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
  EXPECT_GT(msg.Length(), static_cast<size_t>(0));

  // Flatten the ByteBuffer slices into one contiguous string so it can be
  // parsed as a plain protobuf message.
  std::vector<::grpc::Slice> slices;
  (void)msg.Dump(&slices);
  std::string tmp;
  for (const auto& s : slices) {
    tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
  }

  // A failed parse makes every later field access meaningless, so abort.
  sendrecv::VariableMessage varmsg;
  ASSERT_TRUE(varmsg.ParseFromString(tmp));
  EXPECT_EQ(varmsg.varname(), "myvar");
  EXPECT_EQ(varmsg.type(), 1);

  // The raw byte fields are read through reinterpreted pointers below; make
  // sure they are long enough before touching them.
  ASSERT_GE(varmsg.serialized().size(), tensor_numel * sizeof(float));
  ASSERT_GE(varmsg.rows().size(), kRowCount * sizeof(int64_t));

  const float* tensor_data =
      reinterpret_cast<const float*>(varmsg.serialized().data());
  const int64_t* rows_data =
      reinterpret_cast<const int64_t*>(varmsg.rows().data());
  for (int i = 0; i < tensor_numel; ++i) {
    EXPECT_FLOAT_EQ(tensor_data[i], kFillValue);
  }
  for (int i = 0; i < kRowCount; ++i) {
    EXPECT_EQ(rows_data[i], i);
  }

  // Deserialize zero-copy: parse the ByteBuffer straight into a scope
  // variable named like the serialized one.
  framework::Scope scope;
  scope.Var("myvar");
  operators::detail::VariableResponse resp(&scope, &ctx);
  ASSERT_EQ(resp.Parse(msg), 0);

  framework::Variable* var2 = resp.GetVar();
  ASSERT_TRUE(var2 != nullptr);

  auto* slr2 = var2->GetMutable<framework::SelectedRows>();
  auto* tensor2 = slr2->mutable_value();
  auto* rows2 = slr2->mutable_rows();

  // Without these the element loops below could read out of bounds, and the
  // row loop would pass vacuously on an empty result.
  ASSERT_EQ(tensor2->numel(), static_cast<int64_t>(tensor_numel));
  ASSERT_EQ(rows2->size(), static_cast<size_t>(kRowCount));

  const float* tensor_data2 = nullptr;
  framework::Tensor tmp_tensor;
  if (platform::is_gpu_place(ctx.GetPlace())) {
    // Device memory cannot be read directly from the host; stage it on CPU.
    // NOTE(review): if this TensorCopy overload is asynchronous on GPU, a
    // ctx.Wait() may be needed before reading tmp_tensor -- confirm.
    platform::CPUPlace cpu;
    framework::TensorCopy(*tensor2, cpu, &tmp_tensor);
    tensor_data2 = tmp_tensor.data<float>();
  } else {
    tensor_data2 = tensor2->data<float>();
  }
  const int64_t* rows_data2 = rows2->data();

  for (int i = 0; i < tensor_numel; ++i) {
    EXPECT_FLOAT_EQ(tensor_data2[i], kFillValue);
  }
  for (size_t i = 0; i < rows2->size(); ++i) {
    EXPECT_EQ(rows_data2[i], static_cast<int64_t>(i));
  }
  EXPECT_EQ(slr2->height(), 1000);
}

116
// Round-trips a LoDTensor variable through the gRPC serde path on `place`.
//
// A 512 x 8 x 4 x 2 float LoDTensor filled with 31.9 and carrying a
// single-level LoD {1, 3, 8} is serialized with SerializeToByteBuffer. The
// result is checked as a parsed sendrecv::VariableMessage and then
// deserialized again through VariableResponse.
//
// place:     the place the source tensor is allocated on; its device context
//            is used for filling, serializing and deserializing.
// from_type: 0 parses the ByteBuffer produced by SerializeToByteBuffer; any
//            other value parses a ByteBuffer rebuilt from the protobuf
//            SerializeToString output of the parsed message.
void RunTestLodTensor(platform::Place place, int from_type = 0) {
  const int tensor_numel = 512 * 8 * 4 * 2;
  const float kFillValue = 31.9;

  // Build the source variable.
  framework::Variable var;
  auto* tensor = var.GetMutable<framework::LoDTensor>();
  tensor->Resize(framework::make_ddim({512, 8, 4, 2}));
  framework::LoD lod;
  lod.push_back(framework::Vector<size_t>({1, 3, 8}));
  tensor->set_lod(lod);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto& ctx = *pool.Get(place);
  tensor->mutable_data<float>(place);
  math::set_constant(ctx, tensor, kFillValue);

  // Serialize var to ByteBuffer.
  ::grpc::ByteBuffer msg;
  operators::detail::SerializeToByteBuffer("myvar", &var, ctx, &msg);
  EXPECT_GT(msg.Length(), static_cast<size_t>(0));

  // Flatten the ByteBuffer slices into one contiguous string so it can be
  // parsed as a plain protobuf message.
  std::vector<::grpc::Slice> slices;
  (void)msg.Dump(&slices);
  std::string tmp;
  for (const auto& s : slices) {
    tmp.append(reinterpret_cast<const char*>(s.begin()), s.size());
  }

  // A failed parse makes every later field access meaningless, so abort.
  sendrecv::VariableMessage varmsg;
  ASSERT_TRUE(varmsg.ParseFromString(tmp));
  EXPECT_EQ(varmsg.varname(), "myvar");
  EXPECT_EQ(varmsg.type(), 0);

  // Guard the repeated fields before indexing into them.
  ASSERT_EQ(varmsg.dims_size(), 4);
  EXPECT_EQ(varmsg.dims()[0], 512);
  EXPECT_EQ(varmsg.dims()[1], 8);
  EXPECT_EQ(varmsg.dims()[2], 4);
  EXPECT_EQ(varmsg.dims()[3], 2);
  EXPECT_EQ(varmsg.lod_level(), 1);
  ASSERT_EQ(varmsg.lod_size(), 1);
  ASSERT_EQ(varmsg.lod(0).lod_data_size(), 3);
  EXPECT_EQ(varmsg.lod(0).lod_data(0), 1);
  EXPECT_EQ(varmsg.lod(0).lod_data(1), 3);
  EXPECT_EQ(varmsg.lod(0).lod_data(2), 8);

  // The payload is read through a reinterpreted pointer; make sure it is
  // long enough first.
  ASSERT_GE(varmsg.serialized().size(), tensor_numel * sizeof(float));
  const float* tensor_data =
      reinterpret_cast<const float*>(varmsg.serialized().data());
  for (int i = 0; i < tensor_numel; ++i) {
    EXPECT_FLOAT_EQ(tensor_data[i], kFillValue);
  }

  // Message binary: re-encode the parsed message with protobuf itself.
  std::string str;
  ASSERT_TRUE(varmsg.SerializeToString(&str));

  // Message bytebuffer: wrap the re-encoded bytes in a single-slice
  // ByteBuffer for the from_type != 0 path.
  ::grpc::Slice slices_2[1];
  const int num_slices = 1;
  slices_2[0] = ::grpc::Slice(str.length());
  std::memcpy(const_cast<uint8_t*>(slices_2[0].begin()), str.c_str(),
              str.length());
  ::grpc::ByteBuffer bytebuffer2(&slices_2[0], num_slices);

  // Deserialize zero-copy: parse into a scope variable named like the
  // serialized one.
  framework::Scope scope;
  scope.Var("myvar");
  operators::detail::VariableResponse resp(&scope, &ctx);
  if (from_type == 0) {
    ASSERT_EQ(resp.Parse(msg), 0);
  } else {
    ASSERT_EQ(resp.Parse(bytebuffer2), 0);
  }

  framework::Variable* var2 = resp.GetVar();
  ASSERT_TRUE(var2 != nullptr);

  // Bind by reference; the test only reads the deserialized tensor.
  const auto& tensor2 = var2->Get<framework::LoDTensor>();
  ASSERT_EQ(tensor2.numel(), static_cast<int64_t>(tensor_numel));

  const float* tensor_data2 = nullptr;
  framework::Tensor tmp_tensor;
  if (platform::is_gpu_place(ctx.GetPlace())) {
    // Device memory cannot be read directly from the host; stage it on CPU.
    // NOTE(review): if this TensorCopy overload is asynchronous on GPU, a
    // ctx.Wait() may be needed before reading tmp_tensor -- confirm.
    platform::CPUPlace cpu;
    framework::TensorCopy(tensor2, cpu, &tmp_tensor);
    tensor_data2 = tmp_tensor.data<float>();
  } else {
    tensor_data2 = tensor2.data<float>();
  }

  // Verify the LoD of the *deserialized* tensor. The message-level LoD was
  // already checked above, so re-checking varmsg here would prove nothing
  // about the round trip.
  const framework::LoD& lod2 = tensor2.lod();
  ASSERT_EQ(lod2.size(), static_cast<size_t>(1));
  ASSERT_EQ(lod2[0].size(), static_cast<size_t>(3));
  EXPECT_EQ(lod2[0][0], static_cast<size_t>(1));
  EXPECT_EQ(lod2[0][1], static_cast<size_t>(3));
  EXPECT_EQ(lod2[0][2], static_cast<size_t>(8));

  for (int i = 0; i < tensor_numel; ++i) {
    EXPECT_FLOAT_EQ(tensor_data2[i], kFillValue);
  }
}

Y
Yancey 已提交
202 203
// Exercises the LoDTensor round trip on CPU, and on GPU device 0 when built
// with CUDA, once for each ByteBuffer source (from_type 0 and 1).
TEST(LodTensor, Run) {
  platform::CPUPlace cpu_place;
  for (int from_type = 0; from_type <= 1; ++from_type) {
    RunTestLodTensor(cpu_place, from_type);
  }
#ifdef PADDLE_WITH_CUDA
  platform::CUDAPlace gpu_place(0);
  for (int from_type = 0; from_type <= 1; ++from_type) {
    RunTestLodTensor(gpu_place, from_type);
  }
#endif
}

Y
Yancey 已提交
213
// Exercises the SelectedRows round trip on CPU, and on GPU when built with
// CUDA.
TEST(SelectedRows, Run) {
  platform::CPUPlace place;
  RunSerdeTestSelectedRows(place);

#ifdef PADDLE_WITH_CUDA
  // Name the device explicitly, matching the LodTensor test in this file.
  // NOTE(review): assumes a default-constructed CUDAPlace also refers to
  // device 0 -- confirm.
  platform::CUDAPlace gpu(0);
  RunSerdeTestSelectedRows(gpu);
#endif
}