gpu_info.cc 18.0 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
L
liaogang 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

Y
Yi Wang 已提交
15
#include "paddle/fluid/platform/gpu_info.h"
16
#include <algorithm>
S
sneaxiy 已提交
17
#include <cstdlib>
18
#include <memory>
S
sneaxiy 已提交
19
#include <string>
L
liaogang 已提交
20

21
#include "gflags/gflags.h"
22
#include "paddle/fluid/platform/cuda_device_guard.h"
Y
Yi Wang 已提交
23
#include "paddle/fluid/platform/enforce.h"
24 25
#include "paddle/fluid/platform/lock_guard_ptr.h"
#include "paddle/fluid/platform/macros.h"
26
#include "paddle/fluid/string/split.h"
L
liaogang 已提交
27

28 29 30 31 32
// Command-line flags used by this file. DECLARE_* only declares them; the
// matching DEFINE_* lives in another translation unit.
DECLARE_bool(enable_cublas_tensor_op_math);
DECLARE_double(fraction_of_gpu_memory_to_use);
DECLARE_string(selected_gpus);
DECLARE_uint64(gpu_memory_limit_mb);
DECLARE_uint64(initial_gpu_memory_in_mb);
DECLARE_uint64(reallocate_gpu_memory_in_mb);
34

Z
zhhsplendid 已提交
35 36
// Share of the currently available device memory that
// GpuAvailableMemToAlloc() holds back instead of offering it for allocation.
static constexpr float fraction_reserve_gpu_memory = 0.05f;

L
liaogang 已提交
37 38 39
namespace paddle {
namespace platform {

40 41 42 43 44
/* CUDA "pro tip": cudaDeviceGetAttribute() is a much faster way to query a
single device property than filling a whole cudaDeviceProp struct with
cudaGetDeviceProperties(). For details, see
https://devblogs.nvidia.com/cuda-pro-tip-the-fast-way-to-query-device-properties/
*/

45 46 47 48 49 50
inline std::string CudaErrorWebsite() {
  return "Please see detail in https://docs.nvidia.com/cuda/cuda-runtime-api"
         "/group__CUDART__TYPES.html#group__CUDART__TYPES_1g3f51e3575c217824"
         "6db0a94a430e0038";
}

S
sneaxiy 已提交
51
static int GetCUDADeviceCountImpl() {
52 53 54 55 56 57 58 59
  int driverVersion = 0;
  cudaError_t status = cudaDriverGetVersion(&driverVersion);

  if (!(status == cudaSuccess && driverVersion != 0)) {
    // No GPU driver
    return 0;
  }

S
sneaxiy 已提交
60 61 62 63 64 65
  const auto *cuda_visible_devices = std::getenv("CUDA_VISIBLE_DEVICES");
  if (cuda_visible_devices != nullptr) {
    std::string cuda_visible_devices_str(cuda_visible_devices);
    if (std::all_of(cuda_visible_devices_str.begin(),
                    cuda_visible_devices_str.end(),
                    [](char ch) { return ch == ' '; })) {
S
sneaxiy 已提交
66
      VLOG(2) << "CUDA_VISIBLE_DEVICES is set to be empty. No GPU detected.";
S
sneaxiy 已提交
67 68 69 70
      return 0;
    }
  }

L
liaogang 已提交
71
  int count;
72
  auto error_code = cudaGetDeviceCount(&count);
L
liaogang 已提交
73
  PADDLE_ENFORCE(
74 75 76 77
      error_code,
      "cudaGetDeviceCount failed in "
      "paddle::platform::GetCUDADeviceCountImpl, error code : %d, %s",
      error_code, CudaErrorWebsite());
L
liaogang 已提交
78 79 80
  return count;
}

S
sneaxiy 已提交
81 82 83 84 85
// Returns the number of visible CUDA devices. The probe runs on the first
// call only (function-local static initialization); later calls reuse the
// cached value.
int GetCUDADeviceCount() {
  static const int kCachedDeviceCount = GetCUDADeviceCountImpl();
  return kCachedDeviceCount;
}

86 87
// Returns the compute capability of device `id`, encoded as major * 10 + minor
// (for example 7.0 becomes 70).
// Raises if id is not less than GetCUDADeviceCount(), or if either attribute
// query returns a non-zero error code.
int GetCUDAComputeCapability(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int cc_major;
  int cc_minor;

  // Both attributes are queried before either status is checked.
  const auto major_error_code =
      cudaDeviceGetAttribute(&cc_major, cudaDevAttrComputeCapabilityMajor, id);
  const auto minor_error_code =
      cudaDeviceGetAttribute(&cc_minor, cudaDevAttrComputeCapabilityMinor, id);

  PADDLE_ENFORCE_EQ(major_error_code, 0,
                    "cudaDevAttrComputeCapabilityMajor failed in paddle::"
                    "platform::GetCUDAComputeCapability, error code : %d, %s",
                    major_error_code, CudaErrorWebsite());
  PADDLE_ENFORCE_EQ(minor_error_code, 0,
                    "cudaDevAttrComputeCapabilityMinor failed in paddle::"
                    "platform::GetCUDAComputeCapability, error code : %d, %s",
                    minor_error_code, CudaErrorWebsite());
  return 10 * cc_major + cc_minor;
}

107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133
// Returns the maximum grid dimensions (x, y, z) supported by device `id`, as
// reported by cudaDevAttrMaxGridDimX/Y/Z.
// Raises if id is not less than GetCUDADeviceCount(), or if any of the three
// attribute queries returns a non-zero error code.
dim3 GetGpuMaxGridDimSize(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");

  // Queries one grid-dimension attribute of device `id` and raises on
  // failure. `attr_name` is used only to build the error message.
  auto query = [id](cudaDeviceAttr attr, const char *attr_name) {
    int value = 0;
    auto error_code = cudaDeviceGetAttribute(&value, attr, id);
    PADDLE_ENFORCE_EQ(error_code, 0,
                      "%s failed in paddle::platform::GetGpuMaxGridDimSize, "
                      "error code : %d, %s",
                      attr_name, error_code, CudaErrorWebsite());
    return value;
  };

  dim3 ret;
  ret.x = query(cudaDevAttrMaxGridDimX, "cudaDevAttrMaxGridDimX");
  ret.y = query(cudaDevAttrMaxGridDimY, "cudaDevAttrMaxGridDimY");
  ret.z = query(cudaDevAttrMaxGridDimZ, "cudaDevAttrMaxGridDimZ");
  return ret;
}

C
chengduo 已提交
134 135 136
int GetCUDARuntimeVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int runtime_version = 0;
137 138
  auto error_code = cudaRuntimeGetVersion(&runtime_version);
  PADDLE_ENFORCE(error_code,
C
chengduo 已提交
139
                 "cudaRuntimeGetVersion failed in "
140 141
                 "paddle::platform::GetCUDARuntimeVersion, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduo 已提交
142 143 144 145 146 147
  return runtime_version;
}

int GetCUDADriverVersion(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int driver_version = 0;
148 149
  auto error_code = cudaDriverGetVersion(&driver_version);
  PADDLE_ENFORCE(error_code,
C
chengduo 已提交
150
                 "cudaDriverGetVersion failed in "
151 152
                 "paddle::platform::GetCUDADriverVersion, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduo 已提交
153 154 155
  return driver_version;
}

156 157 158 159 160 161 162 163 164 165
// Reports whether the current device is treated as having tensor cores,
// i.e. whether its compute capability (major * 10 + minor) is at least 70.
// Always false when compiled with CUDA_VERSION below 9000.
bool TensorCoreAvailable() {
#if CUDA_VERSION >= 9000
  const int compute_capability =
      GetCUDAComputeCapability(GetCurrentDeviceId());
  return compute_capability >= 70;
#else
  return false;
#endif
}

C
chengduoZH 已提交
166 167 168
int GetCUDAMultiProcessors(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
169 170 171 172 173 174
  auto error_code =
      cudaDeviceGetAttribute(&count, cudaDevAttrMultiProcessorCount, id);
  PADDLE_ENFORCE(error_code,
                 "cudaDeviceGetAttribute failed in "
                 "paddle::platform::GetCUDAMultiProcess, error code : %d, %s",
                 error_code, CudaErrorWebsite());
C
chengduoZH 已提交
175 176 177 178 179 180
  return count;
}

int GetCUDAMaxThreadsPerMultiProcessor(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
  int count;
181 182 183 184 185 186 187
  auto error_code = cudaDeviceGetAttribute(
      &count, cudaDevAttrMaxThreadsPerMultiProcessor, id);
  PADDLE_ENFORCE(
      error_code,
      "cudaDeviceGetAttribute failed in paddle::"
      "platform::GetCUDAMaxThreadsPerMultiProcessor, error code : %d, %s",
      error_code, CudaErrorWebsite());
C
chengduoZH 已提交
188 189 190
  return count;
}

191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209
// Returns cudaDevAttrMaxThreadsPerBlock for device `id`.
// Raises InvalidArgument if id is not less than GetCUDADeviceCount(), or if
// the attribute query returns a non-zero error code.
int GetCUDAMaxThreadsPerBlock(int id) {
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(),
                    platform::errors::InvalidArgument(
                        "Device id must less than GPU count, but received "
                        "id is:%d, GPU count is: %d.",
                        id, GetCUDADeviceCount()));
  int max_threads;
  const auto error_code =
      cudaDeviceGetAttribute(&max_threads, cudaDevAttrMaxThreadsPerBlock, id);
  PADDLE_ENFORCE_EQ(error_code, 0,
                    platform::errors::InvalidArgument(
                        "cudaDeviceGetAttribute returned error code should "
                        "be 0, but received error code is: %d, %s",
                        error_code, CudaErrorWebsite()));
  return max_threads;
}

L
liaogang 已提交
210 211
int GetCurrentDeviceId() {
  int device_id;
212 213 214 215 216
  auto error_code = cudaGetDevice(&device_id);
  PADDLE_ENFORCE(error_code,
                 "cudaGetDevice failed in "
                 "paddle::platform::GetCurrentDeviceId, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
217 218 219
  return device_id;
}

220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237
//! Returns the device ids to use.
//! If FLAGS_selected_gpus is non-empty (single-node multi-process mode), it is
//! split on ',' and each piece is parsed with std::atoi. No validation is
//! done, so a non-numeric piece becomes 0. Otherwise every id in
//! [0, GetCUDADeviceCount()) is returned.
std::vector<int> GetSelectedDevices() {
  std::vector<int> devices;
  if (FLAGS_selected_gpus.empty()) {
    const int device_count = GetCUDADeviceCount();
    for (int dev = 0; dev < device_count; ++dev) {
      devices.push_back(dev);
    }
    return devices;
  }

  for (const auto &token : paddle::string::Split(FLAGS_selected_gpus, ',')) {
    devices.push_back(std::atoi(token.c_str()));
  }
  return devices;
}

L
liaogang 已提交
238
void SetDeviceId(int id) {
Q
qijun 已提交
239
  // TODO(qijun): find a better way to cache the cuda device count
Y
Yang Yang 已提交
240
  PADDLE_ENFORCE_LT(id, GetCUDADeviceCount(), "id must less than GPU count");
241 242 243 244 245
  auto error_code = cudaSetDevice(id);
  PADDLE_ENFORCE(error_code,
                 "cudaSetDevice failed in "
                 "paddle::platform::SetDeviced, error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
246 247
}

248
// Stores the current device's available and total memory (in bytes) into
// *available and *total, as reported by RecordedCudaMemGetInfo().
// The raw driver figures and the boolean result are fetched but discarded.
void GpuMemoryUsage(size_t *available, size_t *total) {
  const int device = platform::GetCurrentDeviceId();
  size_t unused_actual_available;
  size_t unused_actual_total;
  RecordedCudaMemGetInfo(available, total, &unused_actual_available,
                         &unused_actual_total, device);
}

254
size_t GpuAvailableMemToAlloc() {
L
liaogang 已提交
255 256
  size_t total = 0;
  size_t available = 0;
257
  GpuMemoryUsage(&available, &total);
258 259
  size_t reserving =
      static_cast<size_t>(fraction_reserve_gpu_memory * available);
260
  // If available size is less than minimum chunk size, no usable memory exists
261
  size_t available_to_alloc = available - reserving;
262
  size_t min_chunk_size = GpuMinChunkSize();
263 264 265
  if (available_to_alloc < min_chunk_size) {
    available_to_alloc = 0;
  }
266 267 268
  VLOG(10) << "GPU usage " << (available >> 20) << "M/" << (total >> 20)
           << "M, " << (available_to_alloc >> 20) << "M available to allocate";
  return available_to_alloc;
Z
zhhsplendid 已提交
269 270
}

271 272 273
// Returns the larger of the initial and the re-allocation chunk sizes.
size_t GpuMaxAllocSize() {
  const size_t init_bytes = GpuInitAllocSize();
  const size_t realloc_bytes = GpuReallocSize();
  return std::max(init_bytes, realloc_bytes);
}
Z
zhhsplendid 已提交
274

275 276 277 278 279 280 281 282 283 284
static size_t GpuAllocSize(bool realloc) {
  size_t available_to_alloc = GpuAvailableMemToAlloc();
  PADDLE_ENFORCE_GT(available_to_alloc, 0, "No enough available GPU memory");
  // If FLAGS_initial_gpu_memory_in_mb is 0, then initial memory will be
  // allocated by fraction
  size_t flag_mb = realloc ? FLAGS_reallocate_gpu_memory_in_mb
                           : FLAGS_initial_gpu_memory_in_mb;
  size_t alloc_bytes =
      (flag_mb > 0ul ? flag_mb << 20 : available_to_alloc *
                                           FLAGS_fraction_of_gpu_memory_to_use);
285
  PADDLE_ENFORCE_GE(available_to_alloc, alloc_bytes,
286 287 288 289 290
                    "No enough available GPU memory");
  VLOG(10) << "Alloc size is " << (alloc_bytes >> 20)
           << " MiB, is it Re-alloc: " << realloc;
  return alloc_bytes;
}
Z
zhhsplendid 已提交
291

292
// Chunk size derived from FLAGS_initial_gpu_memory_in_mb (or from the
// fraction flag when that is 0); see GpuAllocSize().
size_t GpuInitAllocSize() {
  constexpr bool kIsRealloc = false;
  return GpuAllocSize(kIsRealloc);
}
Z
zhhsplendid 已提交
293

294
// Chunk size derived from FLAGS_reallocate_gpu_memory_in_mb (or from the
// fraction flag when that is 0); see GpuAllocSize().
size_t GpuReallocSize() {
  constexpr bool kIsRealloc = true;
  return GpuAllocSize(kIsRealloc);
}
L
liaogang 已提交
295

L
liaogang 已提交
296 297 298 299 300 301
// Smallest chunk size the allocator is allowed to use: 256 bytes.
size_t GpuMinChunkSize() {
  constexpr size_t kMinChunkBytes = 256;
  return kMinChunkBytes;
}

size_t GpuMaxChunkSize() {
302 303 304
  size_t max_chunk_size = GpuMaxAllocSize();
  VLOG(10) << "Max chunk size " << (max_chunk_size >> 20) << "M";
  return max_chunk_size;
L
liaogang 已提交
305 306
}

L
liaogang 已提交
307 308
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
                    enum cudaMemcpyKind kind, cudaStream_t stream) {
309 310
  auto error_code = cudaMemcpyAsync(dst, src, count, kind, stream);
  PADDLE_ENFORCE(error_code,
311
                 "cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync "
312 313 314
                 "(%p -> %p, length: %d) error code : %d, %s",
                 src, dst, static_cast<int>(count), error_code,
                 CudaErrorWebsite());
L
liaogang 已提交
315 316
}

317 318
void GpuMemcpySync(void *dst, const void *src, size_t count,
                   enum cudaMemcpyKind kind) {
319 320 321 322 323 324
  auto error_code = cudaMemcpy(dst, src, count, kind);
  PADDLE_ENFORCE(error_code,
                 "cudaMemcpy failed in paddle::platform::GpuMemcpySync "
                 "(%p -> %p, length: %d) error code : %d, %s",
                 src, dst, static_cast<int>(count), error_code,
                 CudaErrorWebsite());
325 326 327 328
}

void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
                        int src_device, size_t count, cudaStream_t stream) {
329 330
  auto error_code =
      cudaMemcpyPeerAsync(dst, dst_device, src, src_device, count, stream);
L
liaogang 已提交
331
  PADDLE_ENFORCE(
332 333 334 335
      error_code,
      "cudaMemcpyPeerAsync failed in paddle::platform::GpuMemcpyPeerAsync "
      "error code : %d, %s",
      error_code, CudaErrorWebsite());
336 337 338 339
}

void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
                       int src_device, size_t count) {
340 341 342 343 344
  auto error_code = cudaMemcpyPeer(dst, dst_device, src, src_device, count);
  PADDLE_ENFORCE(error_code,
                 "cudaMemcpyPeer failed in paddle::platform::GpuMemcpyPeerSync "
                 "error code : %d, %s",
                 error_code, CudaErrorWebsite());
L
liaogang 已提交
345
}
D
dzhwinter 已提交
346 347

void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
348 349 350 351 352
  auto error_code = cudaMemsetAsync(dst, value, count, stream);
  PADDLE_ENFORCE(error_code,
                 "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync "
                 "error code : %d, %s",
                 error_code, CudaErrorWebsite());
D
dzhwinter 已提交
353
}
354

石晓伟 已提交
355 356 357 358 359 360 361 362 363 364
// Waits on `stream` via cudaStreamSynchronize().
// Raises an External error if the call does not succeed.
void GpuStreamSync(cudaStream_t stream) {
  const auto error_code = cudaStreamSynchronize(stream);
  PADDLE_ENFORCE_CUDA_SUCCESS(
      error_code,
      platform::errors::External("cudaStreamSynchronize failed in "
                                 "paddle::platform::GpuStreamSync error code "
                                 ": %d, %s",
                                 error_code, CudaErrorWebsite()));
}

365
// Tolerates out-of-memory but raises on any other CUDA error.
//
// Two statuses are checked: *status as passed in, and then the value returned
// by cudaGetLastError(), which is stored back into *status. In each case
// cudaErrorMemoryAllocation is rewritten to cudaSuccess before the check.
// On a normal return, *status == cudaSuccess.
static void RaiseNonOutOfMemoryError(cudaError_t *status) {
  // Maps OOM to success, then raises if anything else is left.
  auto tolerate_oom = [status] {
    if (*status == cudaErrorMemoryAllocation) {
      *status = cudaSuccess;
    }
    PADDLE_ENFORCE_CUDA_SUCCESS(*status);
  };

  tolerate_oom();  // status returned by the caller's CUDA call
  *status = cudaGetLastError();
  tolerate_oom();  // status reported by cudaGetLastError()
}
石晓伟 已提交
379

380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532
/**
 * Per-device wrapper around cudaMalloc/cudaFree that can cap the number of
 * bytes allocated through it.
 *
 * One helper exists per device reported by GetCUDADeviceCount() (see
 * Instance()). When the helper's limit is non-zero ("recording" mode), each
 * successful Malloc adds its size to cur_size_ and each Free subtracts it.
 * A Malloc that would take cur_size_ past the limit fails with
 * cudaErrorMemoryAllocation without calling cudaMalloc. A limit of 0
 * disables all bookkeeping, and no mutex is created in that case.
 */
class RecordedCudaMallocHelper {
 private:
  explicit RecordedCudaMallocHelper(int dev_id, uint64_t limit_size = 0)
      : dev_id_(dev_id), limit_size_(limit_size) {
    // mtx_ only guards cur_size_, which is only maintained when recording.
    if (NeedRecord()) {
      mtx_.reset(new std::mutex());
    }
  }

  DISABLE_COPY_AND_ASSIGN(RecordedCudaMallocHelper);

 public:
  /**
   * Returns the helper for device `dev_id`.
   * On the first call, one helper per device is created, each with a limit
   * of FLAGS_gpu_memory_limit_mb MiB (0 means no limit). Raises OutOfRange
   * when dev_id is negative or not less than the number of helpers.
   */
  static RecordedCudaMallocHelper *Instance(int dev_id) {
    std::call_once(once_flag_, [] {
      int dev_cnt = GetCUDADeviceCount();
      instances_.reserve(dev_cnt);
      for (int i = 0; i < dev_cnt; ++i) {
        instances_.emplace_back(
            new RecordedCudaMallocHelper(i, FLAGS_gpu_memory_limit_mb << 20));
      }
    });

    PADDLE_ENFORCE_GE(
        dev_id, 0,
        platform::errors::OutOfRange(
            "Device id must be not less than 0, but got %d", dev_id));
    PADDLE_ENFORCE_LT(
        dev_id, instances_.size(),
        platform::errors::OutOfRange("Device id %d exceeds gpu card number %d",
                                     dev_id, instances_.size()));
    return instances_[dev_id].get();
  }

  /**
   * Try to allocate `size` bytes on this helper's device. Only
   * cudaErrorMemoryAllocation or cudaSuccess is returned. When cudaMalloc
   * fails, the cudaGetLastError() flag is read (and thereby cleared);
   * any error other than out-of-memory is raised.
   */
  cudaError_t Malloc(void **ptr, size_t size) {
    LockGuardPtr<std::mutex> lock(mtx_);
    // Compare against the remaining budget instead of computing
    // `cur_size_ + size`, which could wrap around for a huge `size` and
    // slip past the limit. While recording, cur_size_ only grows after this
    // check passes, so cur_size_ <= limit_size_ and the subtraction cannot
    // underflow.
    if (UNLIKELY(NeedRecord() && size > limit_size_ - cur_size_)) {
      return cudaErrorMemoryAllocation;
    }

    CUDADeviceGuard guard(dev_id_);
    auto result = cudaMalloc(ptr, size);
    if (result == cudaSuccess) {
      if (NeedRecord()) {
        cur_size_ += size;
      }
      return cudaSuccess;
    }
    // Any error other than out-of-memory is raised inside
    // RaiseNonOutOfMemoryError, so reaching the return means OOM.
    RaiseNonOutOfMemoryError(&result);
    return cudaErrorMemoryAllocation;
  }

  /**
   * Free gpu memory. Usually, free is not allowed to raise error.
   * If it does raise error, the process should be crashed.
   * `size` should match the size given to Malloc, because it is subtracted
   * from the recorded total.
   */
  void Free(void *ptr, size_t size) {
    // Purposefully allow cudaErrorCudartUnloading, because
    // that is returned if you ever call cudaFree after the
    // driver has already shutdown. This happens only if the
    // process is terminating, in which case we don't care if
    // cudaFree succeeds.
    CUDADeviceGuard guard(dev_id_);
    auto err = cudaFree(ptr);
    if (err != cudaErrorCudartUnloading) {
      PADDLE_ENFORCE_CUDA_SUCCESS(
          err, platform::errors::External("cudaFree raises unexpected error"));
      if (NeedRecord()) {
        // Named `lock` so it does not shadow the device guard above.
        std::lock_guard<std::mutex> lock(*mtx_);
        cur_size_ -= size;
      }
    } else {
      cudaGetLastError();  // clear the error flag when cudaErrorCudartUnloading
    }
  }

  /**
   * Queries this device's memory.
   * *actual_avail and *actual_total receive the raw cudaMemGetInfo figures.
   * *avail and *total receive the same figures, capped by the limit when
   * recording. Returns true iff the limit lowers the reported total.
   * An out-of-memory failure of cudaMemGetInfo is tolerated, and all four
   * outputs are then 0; any other error is raised.
   */
  bool GetMemInfo(size_t *avail, size_t *total, size_t *actual_avail,
                  size_t *actual_total) {
    {
      CUDADeviceGuard guard(dev_id_);
      auto result = cudaMemGetInfo(actual_avail, actual_total);
      if (result != cudaSuccess) {
        // Do not trust either out-param after a failed query. Zero both so a
        // tolerated OOM cannot leave *actual_total unset (callers may pass
        // uninitialized storage) and feed it into the std::min below.
        *actual_avail = 0;
        *actual_total = 0;
      }
      RaiseNonOutOfMemoryError(&result);
    }

    if (NeedRecord()) {
      std::lock_guard<std::mutex> lock(*mtx_);
      *avail = std::min(*actual_avail, limit_size_ - cur_size_);
      *total = std::min(*actual_total, limit_size_);
      return *total < *actual_total;
    } else {
      *avail = *actual_avail;
      *total = *actual_total;
      return false;
    }
  }

  // True when a non-zero limit was configured, i.e. allocations are tracked.
  inline bool NeedRecord() const { return limit_size_ != 0; }

  // Bytes currently allocated through this helper, or 0 when not recording.
  uint64_t RecordedSize() const {
    LockGuardPtr<std::mutex> lock(mtx_);
    return NeedRecord() ? cur_size_ : 0;
  }

  // The configured limit in bytes (0 means no limit).
  uint64_t LimitSize() const { return limit_size_; }

 private:
  const int dev_id_;
  const uint64_t limit_size_;
  uint64_t cur_size_{0};

  // Created only when NeedRecord(); guards cur_size_.
  mutable std::unique_ptr<std::mutex> mtx_;

  static std::once_flag once_flag_;
  static std::vector<std::unique_ptr<RecordedCudaMallocHelper>> instances_;
};

std::once_flag RecordedCudaMallocHelper::once_flag_;
std::vector<std::unique_ptr<RecordedCudaMallocHelper>>
    RecordedCudaMallocHelper::instances_;

// Allocates `size` bytes on device `dev_id` through that device's
// RecordedCudaMallocHelper. Returns cudaSuccess or cudaErrorMemoryAllocation.
cudaError_t RecordedCudaMalloc(void **ptr, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->Malloc(ptr, size);
}

// Frees `p` on device `dev_id` through that device's RecordedCudaMallocHelper.
// `size` should be the size that was allocated, so the recorded total stays
// accurate.
void RecordedCudaFree(void *p, size_t size, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  helper->Free(p, size);
}

// Forwards to RecordedCudaMallocHelper::GetMemInfo for device `dev_id`.
// Returns true when the configured limit lowers the reported total.
bool RecordedCudaMemGetInfo(size_t *avail, size_t *total, size_t *actual_avail,
                            size_t *actual_total, int dev_id) {
  auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->GetMemInfo(avail, total, actual_avail, actual_total);
}

// Bytes currently recorded as allocated on device `dev_id` (0 when no limit
// is configured).
uint64_t RecordedCudaMallocSize(int dev_id) {
  const auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->RecordedSize();
}

// True when device `dev_id` has a non-zero allocation limit, so its
// allocations are tracked.
bool IsCudaMallocRecorded(int dev_id) {
  const auto *helper = RecordedCudaMallocHelper::Instance(dev_id);
  return helper->NeedRecord();
}

L
liaogang 已提交
533 534
}  // namespace platform
}  // namespace paddle