device_context.cc 16.7 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
qijun 已提交
2 3 4 5
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
6

Q
qijun 已提交
7 8 9 10 11
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yi Wang 已提交
12
#include "paddle/fluid/platform/device_context.h"
13
#include <set>
14
#include <string>
Y
Yu Yang 已提交
15
#include <unordered_set>
#include <utility>
#include <vector>

Y
Yi Wang 已提交
18
#include "paddle/fluid/memory/memory.h"
19 20
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/framework/rw_lock.h"
21
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
S
sneaxiy 已提交
22
#include "paddle/fluid/platform/cuda_device_guard.h"
23
#endif
24

25 26
#include "glog/logging.h"

27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
namespace paddle {
namespace memory {

// Allocates `size` bytes on the place that `dev_ctx` is bound to.
//
// On a GPU place whose `dev_ctx` uses a stream other than the default
// context registered in DeviceContextPool for that place, the request is
// routed through CUDADeviceContextAllocatorPool (presumably so the memory is
// associated with `dev_ctx`'s stream). Zero-byte requests, non-GPU places and
// requests on the default stream use the plain place-based allocator.
AllocationPtr Alloc(const platform::DeviceContext& dev_ctx, size_t size) {
  auto place = dev_ctx.GetPlace();
#ifdef PADDLE_WITH_CUDA
  const bool use_place_allocator = size == 0 || !platform::is_gpu_place(place);
  if (!use_place_allocator) {
    auto& ctx = static_cast<const platform::CUDADeviceContext&>(dev_ctx);
    auto* pooled_ctx = static_cast<platform::CUDADeviceContext*>(
        platform::DeviceContextPool::Instance().Get(place));
    if (pooled_ctx->stream() != ctx.stream()) {
      return allocation::CUDADeviceContextAllocatorPool::Instance().Alloc(
          ctx, size);
    }
  }
#endif
  return Alloc(place, size);
}

}  // namespace memory
}  // namespace paddle

Q
qijun 已提交
54 55 56
namespace paddle {
namespace platform {

D
dzhwinter 已提交
57 58
// Storage for the pool singleton pointer; it starts out null. NOTE(review):
// the code that assigns it is not in this file — confirm against the header.
DeviceContextPool* DeviceContextPool::pool{nullptr};

Y
Yu Yang 已提交
59
// Returns the device context registered for `place`.
//
// Entries hold deferred shared_futures (see EmplaceDeviceContext), so the
// context object itself is constructed the first time it is fetched here.
// Throws Unimplemented if no context was registered for `place`.
platform::DeviceContext* DeviceContextPool::Get(const platform::Place& place) {
  auto entry = device_contexts_.find(place);
  const bool registered = entry != device_contexts_.end();
  if (!registered) {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Place %s is not supported. Please check that your paddle compiles "
        "with WITH_GPU option or check that your train process hold the "
        "correct gpu_id if you use Executor.",
        place));
  }
  // The first get() on a deferred future runs its factory; later calls
  // return the stored unique_ptr.
  const auto& lazy_ctx = entry->second;
  return lazy_ctx.get().get();
}

71 72 73 74 75 76 77 78 79
// Registers a lazily built DevCtx for place `p` in `*map_ptr`.
//
// The factory runs under std::launch::deferred, so the DevCtx is only
// constructed when the stored shared_future is first read (i.e. on the first
// DeviceContextPool::Get for `p`). `p` must hold a PlaceType alternative.
template <typename DevCtx, typename PlaceType>
inline void EmplaceDeviceContext(
    std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
        map_ptr,
    platform::Place p) {
  auto make_ctx = [p]() -> std::unique_ptr<DeviceContext> {
    return std::unique_ptr<DeviceContext>(
        new DevCtx(BOOST_GET_CONST(PlaceType, p)));
  };
  map_ptr->emplace(p, std::async(std::launch::deferred, make_ctx));
}

D
dzhwinter 已提交
84 85
// Builds the pool for `places` (duplicates are collapsed).
//
// Each CPU, CUDA and CUDA-pinned place gets a lazily constructed context:
// MKLDNNDeviceContext or CPUDeviceContext for CPU (depending on
// PADDLE_WITH_MKLDNN), CUDADeviceContext for GPU, and CUDAPinnedDeviceContext
// for pinned memory. GPU and pinned places throw Unimplemented in builds
// without PADDLE_WITH_CUDA. Places of any other kind are skipped.
// Throws InvalidArgument if `places` is empty.
DeviceContextPool::DeviceContextPool(
    const std::vector<platform::Place>& places) {
  PADDLE_ENFORCE_GT(
      places.size(), 0,
      platform::errors::InvalidArgument("The number of platform places should "
                                        "be larger than 0. But received %d.",
                                        places.size()));
  // Deduplicate so each distinct place is registered exactly once.
  const std::set<Place> unique_places(places.begin(), places.end());
  for (const auto& p : unique_places) {
    if (platform::is_cpu_place(p)) {
#ifdef PADDLE_WITH_MKLDNN
      EmplaceDeviceContext<MKLDNNDeviceContext, CPUPlace>(&device_contexts_, p);
#else
      EmplaceDeviceContext<CPUDeviceContext, CPUPlace>(&device_contexts_, p);
#endif
    } else if (platform::is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDADeviceContext, CUDAPlace>(&device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "'CUDAPlace' is not supported. Please re-compile with WITH_GPU "
          "option."));
#endif
    } else if (platform::is_cuda_pinned_place(p)) {
#ifdef PADDLE_WITH_CUDA
      EmplaceDeviceContext<CUDAPinnedDeviceContext, CUDAPinnedPlace>(
          &device_contexts_, p);
#else
      PADDLE_THROW(platform::errors::Unimplemented(
          "'CUDAPinnedPlace' is not supported. Please re-compile with "
          "WITH_GPU option."));
#endif
    }
  }
}

123 124 125 126
// Default CPU context: delegates to the place constructor with a
// default-constructed CPUPlace, which also creates the Eigen host device.
CPUDeviceContext::CPUDeviceContext() : CPUDeviceContext(CPUPlace()) {}

D
dzhwinter 已提交
127
// CPU context bound to `place`; it owns an Eigen::DefaultDevice (Eigen's
// host device) handed out by eigen_device().
CPUDeviceContext::CPUDeviceContext(CPUPlace place) : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice);
}

// Non-owning pointer to the Eigen host device owned by this context.
Eigen::DefaultDevice* CPUDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

D
dzhwinter 已提交
135
// Returns the stored CPUPlace converted to the generic Place type.
Place CPUDeviceContext::GetPlace() const {
  return Place(place_);
}
136

137
#ifdef PADDLE_WITH_CUDA
138

Q
init  
qijun 已提交
139 140 141 142 143 144 145
class EigenCudaStreamDevice : public Eigen::StreamInterface {
 public:
  EigenCudaStreamDevice() : scratch_(nullptr), semaphore_(nullptr) {
    Eigen::initializeDeviceProp();
  }
  ~EigenCudaStreamDevice() override {}

D
dzhwinter 已提交
146
  void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
Q
init  
qijun 已提交
147 148 149 150 151 152 153 154 155 156 157 158
    stream_ = cuda_stream;
    place_ = place;
    device_prop_ = &Eigen::m_deviceProperties[place.device];
  }

  const cudaStream_t& stream() const override { return *stream_; }

  const cudaDeviceProp& deviceProperties() const override {
    return *device_prop_;
  }

  void* allocate(size_t num_bytes) const override {
S
sneaxiy 已提交
159 160 161
    if (UNLIKELY(num_bytes == 0)) {
      return nullptr;
    }
162 163 164
    auto buf = memory::Alloc(place_, num_bytes);
    VLOG(4) << "Eigen allocated at " << buf->ptr() << ", size" << buf->size()
            << " requested " << num_bytes;
165
    void* retv = buf->ptr();
S
sneaxiy 已提交
166 167 168 169
    {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.emplace(retv, std::move(buf));
    }
170
    return retv;
Q
init  
qijun 已提交
171 172
  }

S
sneaxiy 已提交
173 174 175 176 177 178
  void deallocate(void* buffer) const override {
    if (LIKELY(buffer)) {
      std::lock_guard<std::mutex> lock(mtx_);
      allocations_.erase(buffer);
    }
  }
Q
init  
qijun 已提交
179 180 181

  void* scratchpad() const override {
    if (scratch_ == NULL) {
Z
Zhang Ting 已提交
182 183 184 185
// windows use an old version of eigen that uses kCudaScratchSize,
// once windows updates eigen to a recent version, the following code
// can use kGpuScratchSize uniformly
#ifdef _WIN32
Q
init  
qijun 已提交
186
      scratch_ = allocate(Eigen::kCudaScratchSize + sizeof(unsigned int));
Z
Zhang Ting 已提交
187 188 189
#else
      scratch_ = allocate(Eigen::kGpuScratchSize + sizeof(unsigned int));
#endif
Q
init  
qijun 已提交
190 191 192 193 194 195
    }
    return scratch_;
  }

  unsigned int* semaphore() const override {
    if (semaphore_ == NULL) {
Z
Zhang Ting 已提交
196
#ifdef _WIN32
Q
init  
qijun 已提交
197 198
      char* scratch =
          static_cast<char*>(scratchpad()) + Eigen::kCudaScratchSize;
Z
Zhang Ting 已提交
199 200 201
#else
      char* scratch = static_cast<char*>(scratchpad()) + Eigen::kGpuScratchSize;
#endif
Q
init  
qijun 已提交
202
      semaphore_ = reinterpret_cast<unsigned int*>(scratch);
203
      PADDLE_ENFORCE_CUDA_SUCCESS(
Q
init  
qijun 已提交
204 205 206 207 208 209
          cudaMemsetAsync(semaphore_, 0, sizeof(unsigned int), *stream_));
    }
    return semaphore_;
  }

 private:
D
dzhwinter 已提交
210
  CUDAPlace place_;
Q
init  
qijun 已提交
211 212
  const cudaStream_t* stream_;         // not owned;
  const cudaDeviceProp* device_prop_;  // not owned;
Q
qijun 已提交
213
  mutable void* scratch_;
Q
init  
qijun 已提交
214
  mutable unsigned int* semaphore_;
S
sneaxiy 已提交
215
  mutable std::mutex mtx_;  // to protect allocations_
Y
Yu Yang 已提交
216
  mutable std::unordered_map<void*, memory::AllocationPtr> allocations_;
Q
init  
qijun 已提交
217 218
};

219 220 221 222 223 224 225 226 227
// Ensures the workspace holds at least `required_workspace_bytes`. The
// existing buffer is kept when it is already large enough.
void CudnnWorkspaceHandle::ReallocWorkspace(size_t required_workspace_bytes) {
  const bool needs_growth = required_workspace_bytes > WorkspaceSize();
  if (needs_growth) {
    // Drop the old buffer before requesting the new one, so that both are
    // never held at the same time.
    allocation_.reset();
    allocation_ = memory::Alloc(device_context_, required_workspace_bytes);
  }
}

228 229 230 231 232 233 234 235 236 237 238 239
// Out-of-line definitions of CUDADeviceContext's thread_local static members.
// thread_ctx_: per thread, maps a CUDADeviceContext (by address) to a shared
// CUDAContext.
// ctx_mtx_: a per-thread mutex. NOTE(review): its users are not in this file;
// confirm what it guards against the header.
thread_local std::unordered_map<const CUDADeviceContext*,
                                std::shared_ptr<CUDAContext>>
    CUDADeviceContext::thread_ctx_;
thread_local std::mutex CUDADeviceContext::ctx_mtx_;

// Creates this context's Eigen GPU device: an EigenCudaStreamDevice bound to
// RawStream() and place_, wrapped in an Eigen::GpuDevice.
void CUDAContext::InitEigenContext() {
  auto* stream_device = new EigenCudaStreamDevice();
  eigen_stream_.reset(stream_device);
  stream_device->Reinitialize(&RawStream(), place_);
  eigen_device_.reset(new Eigen::GpuDevice(stream_device));
}

// Creates the CUDA resources for `place`, with that device current:
// a stream with the requested priority, then the Eigen, cuBLAS, cuDNN and
// cuSOLVER contexts.
CUDAContext::CUDAContext(const CUDAPlace& place,
                         const stream::Priority& priority) {
  place_ = place;
  CUDADeviceGuard device_guard(place.device);
  stream_.reset(new stream::CUDAStream(place_, priority));
  // The stream must exist first: InitEigenContext binds to RawStream(), and
  // presumably the library contexts bind to it as well.
  InitEigenContext();
  InitCuBlasContext();
  InitCuDNNContext();
  InitCuSolverContext();
}

// Releases the cuDNN, cuBLAS and cuSOLVER contexts (in that order) while this
// context's device is current.
CUDAContext::~CUDAContext() {
  CUDADeviceGuard device_guard{place_.device};
  DestoryCuDNNContext();
  DestoryCuBlasContext();
  DestoryCuSolverContext();
}

257
// Creates the context for GPU `place`, with that device current.
//
// The constructor does the following:
// - caches the device limits and the CUDA driver/runtime versions;
// - logs them together with the cuDNN version (each LOG_FIRST_N statement
//   fires only on its first execution);
// - warns when the driver-reported CUDA version is older than the
//   CUDA_VERSION Paddle was compiled with;
// - builds the default CUDAContext.
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
  CUDADeviceGuard guard(place_.device);
  const auto dev_id = place_.device;

  // Hardware limits, queried once and returned by the Get* accessors.
  compute_capability_ = GetCUDAComputeCapability(dev_id);
  multi_process_ = GetCUDAMultiProcessors(dev_id);
  max_threads_per_mp_ = GetCUDAMaxThreadsPerMultiProcessor(dev_id);
  max_grid_dim_size_ = GetGpuMaxGridDimSize(dev_id);
  max_threads_per_block_ = GetCUDAMaxThreadsPerBlock(dev_id);

  // Both are decoded below as major = v / 1000, minor = (v % 100) / 10.
  driver_version_ = GetCUDADriverVersion(dev_id);
  runtime_version_ = GetCUDARuntimeVersion(dev_id);

  LOG_FIRST_N(WARNING, 1) << "Please NOTE: device: " << dev_id
                          << ", CUDA Capability: " << compute_capability_
                          << ", Driver API Version: " << driver_version_ / 1000
                          << "." << (driver_version_ % 100) / 10
                          << ", Runtime API Version: "
                          << runtime_version_ / 1000 << "."
                          << (runtime_version_ % 100) / 10;
  const size_t cudnn_dso_version = dynload::cudnnGetVersion();
  LOG_FIRST_N(WARNING, 1) << "device: " << dev_id << ", cuDNN Version: "
                          << cudnn_dso_version / 1000 << "."
                          << (cudnn_dso_version % 1000) / 100 << ".";

  // Compatibility check: compare "10 * major + minor" of the driver-reported
  // CUDA version with that of the toolkit this binary was compiled against.
  const auto installed_cuda =
      (driver_version_ / 1000) * 10 + (driver_version_ % 100) / 10;
  const auto compiled_cuda =
      (CUDA_VERSION / 1000) * 10 + (CUDA_VERSION % 100) / 10;
  if (installed_cuda < compiled_cuda) {
    LOG_FIRST_N(WARNING, 1)
        << "WARNING: device: " << dev_id
        << ". The installed Paddle is compiled with CUDA "
        << compiled_cuda / 10 << "." << compiled_cuda % 10
        << ", but CUDA runtime version in your machine is "
        << installed_cuda / 10 << "." << installed_cuda % 10
        << ", which may cause serious incompatible bug. "
        << "Please recompile or reinstall Paddle with compatible CUDA "
           "version.";
  }

  default_ctx_.reset(new CUDAContext(place_));
}

// Makes this context's device current. In NCCL builds it then destroys the
// attached NCCL communicator, if there is one.
// NOTE(review): PADDLE_ENFORCE_CUDA_SUCCESS presumably throws on failure.
// Destructors are implicitly noexcept, so such a failure would terminate.
CUDADeviceContext::~CUDADeviceContext() {
  SetDeviceId(place_.device);
#if defined(PADDLE_WITH_NCCL)
  if (nccl_comm_ != nullptr) {
    PADDLE_ENFORCE_CUDA_SUCCESS(dynload::ncclCommDestroy(nccl_comm_));
  }
#endif
}

L
liaogang 已提交
310
// Returns the stored CUDAPlace converted to the generic Place type.
Place CUDADeviceContext::GetPlace() const {
  return Place(place_);
}
311

312
// Calls Wait() on the stream of the CUDAContext returned by context().
void CUDADeviceContext::Wait() const {
  const auto& active_ctx = context();
  active_ctx->Stream()->Wait();
}
313

K
Kexin Zhao 已提交
314
int CUDADeviceContext::GetComputeCapability() const {
C
chengduo 已提交
315
  return compute_capability_;
K
Kexin Zhao 已提交
316 317
}

318
int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
C
chengduo 已提交
319
  return multi_process_ * max_threads_per_mp_;
320 321
}

322 323 324 325 326 327
// Number of multiprocessors, cached at construction.
int CUDADeviceContext::GetSMCount() const {
  const auto sm_count = multi_process_;
  return sm_count;
}

// Per-block thread limit, cached at construction.
int CUDADeviceContext::GetMaxThreadsPerBlock() const {
  const auto per_block = max_threads_per_block_;
  return per_block;
}

328
// Non-owning pointer to the Eigen GPU device of the active CUDAContext.
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
  const auto& active_ctx = context();
  return active_ctx->EigenDevice().get();
}

332
bool CUDADeviceContext::tensor_core_available() const {
333
  return context()->CublasTensorCoreHandle() != nullptr;
S
sneaxiy 已提交
334 335
}

336 337 338 339
// Maximum grid dimensions, cached at construction via GetGpuMaxGridDimSize.
dim3 CUDADeviceContext::GetCUDAMaxGridDimSize() const {
  const dim3 grid_limit = max_grid_dim_size_;
  return grid_limit;
}

340 341 342
// cuDNN handle of the active CUDAContext.
cudnnHandle_t CUDADeviceContext::cudnn_handle() const {
  const auto& active_ctx = context();
  return active_ctx->CudnnHandle();
}
343

S
sneaxiy 已提交
344
// Returns a workspace handle tied to this context that uses
// cudnn_handle_mtx_ as its mutex.
CudnnWorkspaceHandle CUDADeviceContext::cudnn_workspace_handle() const {
  std::mutex* workspace_mutex = &cudnn_handle_mtx_;
  return CudnnWorkspaceHandle(*this, workspace_mutex);
}
347

G
Guo Sheng 已提交
348 349 350 351
// Dense cuSOLVER handle of the active CUDAContext.
cusolverDnHandle_t CUDADeviceContext::cusolver_dn_handle() const {
  const auto& active_ctx = context();
  return active_ctx->CusolverDnHandle();
}

352 353 354
// Raw CUDA stream of the active CUDAContext.
cudaStream_t CUDADeviceContext::stream() const {
  const auto& active_ctx = context();
  return active_ctx->RawStream();
}
Q
qijun 已提交
355

C
chengduoZH 已提交
356 357 358 359 360 361 362 363 364 365 366 367 368 369
// Default pinned-memory context: delegates to the place constructor with a
// default-constructed CUDAPinnedPlace, which also creates the Eigen device.
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext()
    : CUDAPinnedDeviceContext(CUDAPinnedPlace()) {}

// Pinned-memory context bound to `place`. It owns an Eigen::DefaultDevice
// (Eigen's host device).
CUDAPinnedDeviceContext::CUDAPinnedDeviceContext(CUDAPinnedPlace place)
    : place_(place) {
  eigen_device_.reset(new Eigen::DefaultDevice);
}

// Non-owning pointer to the Eigen host device owned by this context.
Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
  Eigen::DefaultDevice* device = eigen_device_.get();
  return device;
}

// Returns the stored CUDAPinnedPlace converted to the generic Place type.
Place CUDAPinnedDeviceContext::GetPlace() const {
  return Place(place_);
}
L
Luo Tao 已提交
370
#endif
Q
qijun 已提交
371

T
tensor-tang 已提交
372 373
#ifdef PADDLE_WITH_MKLDNN
// CPU context that adds an MKL-DNN CPU engine (index 0), an empty blob cache,
// and the mutex that SetBlob/GetBlob lock while touching that cache.
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
    : CPUDeviceContext(place), engine_(mkldnn::engine::kind::cpu, 0) {
  p_blobmap_.reset(new BlobMap());
  p_mutex_.reset(new std::mutex());
}

381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397
// Initial per-thread state: NCHW layout, shape-cache capacity 1, empty input
// shape key, and the default session id.
MKLDNNDeviceContextThreadLocals::Body::Body() {
  cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
  cur_input_shape_cache_capacity = 1;
  cur_input_shape_str = "";
  cur_mkldnn_session_id = kMKLDNNSessionID_Default;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
    size_t sid) {
  cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
  return cur_mkldnn_session_id;
}

void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
    std::string input_shape_str) {
398 399
  cur_input_shape_str = input_shape_str;
}
400 401
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
    int input_shape_cache_capacity) {
402 403
  cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
S
Sylwester Fraczek 已提交
404

405 406
// Stores the Paddle data layout returned by get_cur_paddle_data_layout().
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
    framework::DataLayout dl) {
  this->cur_paddle_data_layout = dl;
}

410 411
// Returns the stored Paddle data layout (kNCHW unless changed).
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
  const framework::DataLayout layout = this->cur_paddle_data_layout;
  return layout;
}

415 416 417 418
void MKLDNNDeviceContext::ResetBlobMap() const {
  VLOG(3) << "Clearing DNNL cache.";
  p_blobmap_->clear();
}
419

420
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
421
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
422
  BlobMap* pMap = p_blobmap_.get();
423
  auto map_it = pMap->find(tls().cur_mkldnn_session_id);
424
  if (map_it == pMap->end()) {
425 426 427
    PADDLE_THROW(platform::errors::NotFound(
        "MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
        tls().cur_mkldnn_session_id));
428 429 430 431
  }
  return map_it->second->size();
}

432
void MKLDNNDeviceContext::SetBlob(const std::string& name,
433
                                  BlobPtr_t<void> data) const {
434
  BlobMap* pMap = p_blobmap_.get();
435 436
  BlobPtr_t<ShapeBlob> sBlob = nullptr;
  BlobPtr_t<KeyBlob> pBlob = nullptr;
437

438
  int sid = tls().get_cur_mkldnn_session_id();
T
tensor-tang 已提交
439

440
  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
T
tensor-tang 已提交
441

442 443
  // Find ShapeBlob for current mkldnn session id.
  auto map_it = pMap->find(sid);
444 445 446

  if (map_it == pMap->end()) {
    // 1st time to set blob in current thread
447
    sBlob = std::make_shared<ShapeBlob>();
448 449
    (*pMap)[sid] = sBlob;
    VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
450
  } else {
451
    sBlob = map_it->second;
452
  }
T
tensor-tang 已提交
453

454
  // Find KeyBlob for current input shape
455
  auto key_it = sBlob->find(tls().cur_input_shape_str);
456

457
  if (key_it == sBlob->end()) {
458 459
    // In cache clearing mode, cur_input_shape_cache_capacity defines
    // max pblob capacity
460 461
    if ((static_cast<size_t>(sid) ==
         MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
462
        sBlob->size() &&
463
        (sBlob->size() >=
464
         static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
465 466 467 468
      VLOG(2) << "sid=" << sid
              << ", remove all blobs of shape: " << sBlob->begin()->first;
      sBlob->erase(sBlob->begin()->first);
    }
469 470
    pBlob = std::make_shared<KeyBlob>();
    (*sBlob)[tls().cur_input_shape_str] = pBlob;
471
  } else {
472
    pBlob = key_it->second;
473 474
  }

475 476 477 478 479 480 481
  // Find Blob via name
  auto blob_it = pBlob->find(name);
  if (blob_it == pBlob->end()) {
    (*pBlob)[name] = data;
  } else {
    blob_it->second = data;  // set data to existing blob
  }
482
  VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
483
  // lock will be automatically released when out of scope
484
  return;
T
tensor-tang 已提交
485 486
}

487
// Looks up `name` in the cache for the calling thread's current session id
// and input-shape key. Returns nullptr (after a VLOG(2) note) when the
// session, the shape key, or the name is not cached. Holds p_mutex_
// throughout.
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
    const std::string& name) const {
  BlobMap* blob_map = p_blobmap_.get();
  const int sid = tls().get_cur_mkldnn_session_id();

  std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);

  // Level 1: per-session map. References below stay valid because the
  // maps are not modified while the lock is held.
  auto session_it = blob_map->find(sid);
  if (session_it == blob_map->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
    return nullptr;
  }
  const BlobPtr_t<ShapeBlob>& shape_blob = session_it->second;

  // Level 2: per-shape map.
  const std::string& shape_key = tls().cur_input_shape_str;
  auto shape_it = shape_blob->find(shape_key);
  if (shape_it == shape_blob->end()) {
    VLOG(2) << "GetBlob: sid=" << sid << ", miss input_shape_str="
            << shape_key << "\n";
    return nullptr;
  }
  const BlobPtr_t<KeyBlob>& key_blob = shape_it->second;

  // Level 3: the named blob.
  auto blob_it = key_blob->find(name);
  if (blob_it == key_blob->end()) {
    VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
    return nullptr;
  }

  VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
  return blob_it->second;
}

#endif

Q
qijun 已提交
529
}  // namespace platform
Q
qijun 已提交
530
}  // namespace paddle