Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
18eb7730
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
18eb7730
编写于
3月 26, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add CUDAPinnedPlace
上级
f3dc3112
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
125 addition
and
86 deletion
+125
-86
paddle/fluid/framework/tensor.h
paddle/fluid/framework/tensor.h
+10
-19
paddle/fluid/framework/tensor_impl.h
paddle/fluid/framework/tensor_impl.h
+9
-14
paddle/fluid/memory/detail/system_allocator.cc
paddle/fluid/memory/detail/system_allocator.cc
+8
-8
paddle/fluid/memory/detail/system_allocator.h
paddle/fluid/memory/detail/system_allocator.h
+1
-3
paddle/fluid/memory/memory.cc
paddle/fluid/memory/memory.cc
+46
-39
paddle/fluid/memory/memory.h
paddle/fluid/memory/memory.h
+1
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+12
-0
paddle/fluid/platform/place.cc
paddle/fluid/platform/place.cc
+9
-2
paddle/fluid/platform/place.h
paddle/fluid/platform/place.h
+29
-1
未找到文件。
paddle/fluid/framework/tensor.h
浏览文件 @
18eb7730
...
...
@@ -45,11 +45,10 @@ class Tensor {
friend
struct
EigenVector
;
public:
Tensor
()
:
offset_
(
0
)
,
is_pinned_
(
false
)
{}
Tensor
()
:
offset_
(
0
)
{}
/*! Constructor with place should only be used in pybind. */
explicit
Tensor
(
const
platform
::
Place
&
place
)
:
offset_
(
0
),
is_pinned_
(
false
)
{
explicit
Tensor
(
const
platform
::
Place
&
place
)
:
offset_
(
0
)
{
holder_
->
set_place
(
place
);
}
...
...
@@ -70,12 +69,11 @@ class Tensor {
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
(
platform
::
Place
place
,
bool
is_pinned
=
false
);
inline
T
*
mutable_data
(
platform
::
Place
place
);
inline
void
*
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
,
bool
is_pinned
=
false
);
inline
void
*
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
);
inline
void
*
mutable_data
(
platform
::
Place
place
,
bool
is_pinned
=
false
);
inline
void
*
mutable_data
(
platform
::
Place
place
);
/**
* @brief Return a pointer to mutable memory block.
...
...
@@ -86,8 +84,7 @@ class Tensor {
* @note If not exist, then allocation.
*/
template
<
typename
T
>
inline
T
*
mutable_data
(
DDim
dims
,
platform
::
Place
place
,
bool
is_pinned
=
false
);
inline
T
*
mutable_data
(
DDim
dims
,
platform
::
Place
place
);
/*! Return the dimensions of the memory block. */
inline
const
DDim
&
dims
()
const
;
...
...
@@ -152,14 +149,12 @@ class Tensor {
template
<
typename
Place
>
struct
PlaceholderImpl
:
public
Placeholder
{
PlaceholderImpl
(
Place
place
,
size_t
size
,
std
::
type_index
type
,
bool
is_pinned
=
false
)
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
place
,
size
,
is_pinned
)),
memory
::
PODDeleter
<
uint8_t
,
Place
>
(
place
,
is_pinned
)),
PlaceholderImpl
(
Place
place
,
size_t
size
,
std
::
type_index
type
)
:
ptr_
(
static_cast
<
uint8_t
*>
(
memory
::
Alloc
(
place
,
size
)),
memory
::
PODDeleter
<
uint8_t
,
Place
>
(
place
)),
place_
(
place
),
size_
(
size
),
type_
(
type
),
is_pinned_
(
is_pinned
)
{
type_
(
type
)
{
PADDLE_ENFORCE_NOT_NULL
(
ptr_
,
"Insufficient %s memory to allocation."
,
(
is_cpu_place
(
place_
)
?
"CPU"
:
"GPU"
));
}
...
...
@@ -182,9 +177,6 @@ class Tensor {
/* the current type of memory */
std
::
type_index
type_
;
/*! use pinned memory or not. */
bool
is_pinned_
;
};
/*! holds the memory block if allocated. */
...
...
@@ -219,7 +211,6 @@ class Tensor {
* PlaceHolder::ptr_ and where the tensor data really begins.
*/
size_t
offset_
;
bool
is_pinned_
;
};
inline
void
Tensor
::
switch_place
(
platform
::
Place
new_place
)
{
...
...
paddle/fluid/framework/tensor_impl.h
浏览文件 @
18eb7730
...
...
@@ -101,21 +101,19 @@ inline T* Tensor::data() {
}
template
<
typename
T
>
inline
T
*
Tensor
::
mutable_data
(
DDim
dims
,
platform
::
Place
place
,
bool
is_pinned
)
{
inline
T
*
Tensor
::
mutable_data
(
DDim
dims
,
platform
::
Place
place
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
Resize
(
dims
);
return
mutable_data
<
T
>
(
place
,
is_pinned
);
return
mutable_data
<
T
>
(
place
);
}
template
<
typename
T
>
inline
T
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
bool
is_pinned
)
{
inline
T
*
Tensor
::
mutable_data
(
platform
::
Place
place
)
{
static_assert
(
std
::
is_pod
<
T
>::
value
,
"T must be POD"
);
return
reinterpret_cast
<
T
*>
(
mutable_data
(
place
,
typeid
(
T
)
,
is_pinned
));
return
reinterpret_cast
<
T
*>
(
mutable_data
(
place
,
typeid
(
T
)));
}
inline
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
,
bool
is_pinned
)
{
inline
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
std
::
type_index
type
)
{
if
(
holder_
!=
nullptr
)
{
holder_
->
set_type
(
type
);
}
...
...
@@ -129,27 +127,26 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type,
holder_
->
size
()
<
size
+
offset_
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CPUPlace
>
(
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
size
,
type
,
is_pinned
));
boost
::
get
<
platform
::
CPUPlace
>
(
place
),
size
,
type
));
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW
(
"'CUDAPlace' is not supported in CPU only device."
);
}
#else
holder_
.
reset
(
new
PlaceholderImpl
<
platform
::
CUDAPlace
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
size
,
type
,
is_pinned
));
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
size
,
type
));
}
#endif
offset_
=
0
;
is_pinned_
=
is_pinned
;
}
return
reinterpret_cast
<
void
*>
(
reinterpret_cast
<
uintptr_t
>
(
holder_
->
ptr
())
+
offset_
);
}
inline
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
,
bool
is_pinned
)
{
inline
void
*
Tensor
::
mutable_data
(
platform
::
Place
place
)
{
PADDLE_ENFORCE
(
this
->
holder_
!=
nullptr
,
"Cannot invoke mutable data if current hold nothing"
);
return
mutable_data
(
place
,
holder_
->
type
()
,
is_pinned
);
return
mutable_data
(
place
,
holder_
->
type
());
}
inline
Tensor
&
Tensor
::
ShareDataWith
(
const
Tensor
&
src
)
{
...
...
@@ -191,8 +188,6 @@ inline const DDim& Tensor::dims() const { return dims_; }
inline
int64_t
Tensor
::
numel
()
const
{
return
product
(
dims_
);
}
inline
bool
Tensor
::
isPinned
()
const
{
return
is_pinned_
;
}
inline
Tensor
ReshapeToMatrix
(
const
Tensor
&
src
,
int
num_col_dims
)
{
Tensor
res
;
res
.
ShareDataWith
(
src
);
...
...
paddle/fluid/memory/detail/system_allocator.cc
浏览文件 @
18eb7730
...
...
@@ -123,20 +123,20 @@ bool GPUAllocator::UseGpu() const { return true; }
// memory. It’s locked to a physical address.
void
*
CUDAPinnedAllocator
::
Alloc
(
size_t
&
index
,
size_t
size
)
{
if
(
size
<=
0
)
return
nullptr
;
void
*
p
;
// NOTE: here, we use
GpuMaxAllocSize()
as the maximum memory size
// NOTE: here, we use
CpuMaxAllocSize()/2
as the maximum memory size
// of host pinned allocation. Allocates too much would reduce
// the amount of memory available to the underlying system for paging.
size_t
usable
=
paddle
::
platform
::
GpuMaxAllocSize
()
-
fallback_alloc_size_
;
size_t
usable
=
CpuMaxAllocSize
()
/
2
-
cuda_pinnd_alloc_size_
;
if
(
size
>
usable
)
return
nullptr
;
// PINNED memory is visible to all CUDA contexts.
cudaError_t
result
=
cudaMallocHost
(
&
p
,
size
);
if
(
result
==
cudaSuccess
)
{
index
=
1
;
fallback
_alloc_size_
+=
size
;
index
=
1
;
// PINNED memory
cuda_pinnd
_alloc_size_
+=
size
;
return
p
;
}
...
...
@@ -147,8 +147,8 @@ void CUDAPinnedAllocator::Free(void* p, size_t size, size_t index) {
cudaError_t
err
;
PADDLE_ASSERT
(
index
==
1
);
PADDLE_ASSERT
(
fallback
_alloc_size_
>=
size
);
fallback
_alloc_size_
-=
size
;
PADDLE_ASSERT
(
cuda_pinnd
_alloc_size_
>=
size
);
cuda_pinnd
_alloc_size_
-=
size
;
err
=
cudaFreeHost
(
p
);
// Purposefully allow cudaErrorCudartUnloading, because
...
...
paddle/fluid/memory/detail/system_allocator.h
浏览文件 @
18eb7730
...
...
@@ -59,9 +59,7 @@ class CUDAPinnedAllocator : public SystemAllocator {
virtual
bool
UseGpu
()
const
;
private:
size_t
gpu_alloc_size_
=
0
;
// TODO(zcd): how to define the upper limit of CUDAPinnedMemory?
size_t
fallback_alloc_size_
=
0
;
size_t
cuda_pinnd_alloc_size_
=
0
;
};
#endif
...
...
paddle/fluid/memory/memory.cc
浏览文件 @
18eb7730
...
...
@@ -38,8 +38,7 @@ BuddyAllocator* GetCPUBuddyAllocator() {
}
template
<
>
void
*
Alloc
<
platform
::
CPUPlace
>
(
platform
::
CPUPlace
place
,
size_t
size
,
bool
is_pinned
)
{
void
*
Alloc
<
platform
::
CPUPlace
>
(
platform
::
CPUPlace
place
,
size_t
size
)
{
VLOG
(
10
)
<<
"Allocate "
<<
size
<<
" bytes on "
<<
platform
::
Place
(
place
);
void
*
p
=
GetCPUBuddyAllocator
()
->
Alloc
(
size
);
VLOG
(
10
)
<<
" pointer="
<<
p
;
...
...
@@ -47,8 +46,7 @@ void* Alloc<platform::CPUPlace>(platform::CPUPlace place, size_t size,
}
template
<
>
void
Free
<
platform
::
CPUPlace
>
(
platform
::
CPUPlace
place
,
void
*
p
,
bool
is_pinned
)
{
void
Free
<
platform
::
CPUPlace
>
(
platform
::
CPUPlace
place
,
void
*
p
)
{
VLOG
(
10
)
<<
"Free pointer="
<<
p
<<
" on "
<<
platform
::
Place
(
place
);
GetCPUBuddyAllocator
()
->
Free
(
p
);
}
...
...
@@ -85,27 +83,13 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
}
BuddyAllocator
*
GetCUDAPinnedBuddyAllocator
(
int
gpu_id
)
{
static
BuddyAllocator
*
*
as
=
NULL
;
static
BuddyAllocator
*
as
=
NULL
;
if
(
as
==
NULL
)
{
int
gpu_num
=
platform
::
GetCUDADeviceCount
();
as
=
new
BuddyAllocator
*
[
gpu_num
];
for
(
int
gpu
=
0
;
gpu
<
gpu_num
;
gpu
++
)
{
as
[
gpu
]
=
nullptr
;
}
}
platform
::
SetDeviceId
(
gpu_id
);
if
(
!
as
[
gpu_id
])
{
as
[
gpu_id
]
=
new
BuddyAllocator
(
new
detail
::
CUDAPinnedAllocator
,
platform
::
GpuMinChunkSize
(),
platform
::
GpuMaxChunkSize
());
VLOG
(
10
)
<<
"
\n\n
NOTE: each GPU device use "
<<
FLAGS_fraction_of_gpu_memory_to_use
*
100
<<
"% of GPU memory.
\n
"
<<
"You can set GFlags environment variable '"
<<
"FLAGS_fraction_of_gpu_memory_to_use"
<<
"' to change the fraction of GPU usage.
\n\n
"
;
as
=
new
BuddyAllocator
(
new
detail
::
CUDAPinnedAllocator
,
platform
::
CpuMinChunkSize
(),
platform
::
CpuMaxChunkSize
());
}
return
as
[
gpu_id
]
;
return
as
;
}
template
<
>
...
...
@@ -114,16 +98,9 @@ size_t Used<platform::CUDAPlace>(platform::CUDAPlace place) {
}
template
<
>
void
*
Alloc
<
platform
::
CUDAPlace
>
(
platform
::
CUDAPlace
place
,
size_t
size
,
bool
is_pinned
)
{
void
*
ptr
;
if
(
is_pinned
)
{
auto
*
buddy_allocator
=
GetCUDAPinnedBuddyAllocator
(
place
.
device
);
ptr
=
buddy_allocator
->
Alloc
(
size
);
}
else
{
void
*
Alloc
<
platform
::
CUDAPlace
>
(
platform
::
CUDAPlace
place
,
size_t
size
)
{
auto
*
buddy_allocator
=
GetGPUBuddyAllocator
(
place
.
device
);
ptr
=
buddy_allocator
->
Alloc
(
size
);
}
void
*
ptr
=
buddy_allocator
->
Alloc
(
size
);
if
(
ptr
==
nullptr
)
{
int
cur_dev
=
platform
::
GetCurrentDeviceId
();
...
...
@@ -142,13 +119,39 @@ void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size,
}
template
<
>
void
Free
<
platform
::
CUDAPlace
>
(
platform
::
CUDAPlace
place
,
void
*
p
,
bool
is_pinned
)
{
if
(
is_pinned
)
{
GetCUDAPinnedBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
}
else
{
void
Free
<
platform
::
CUDAPlace
>
(
platform
::
CUDAPlace
place
,
void
*
p
)
{
GetGPUBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
}
size_t
Used
<
platform
::
CUDAPinnedPlace
>
(
platform
::
CUDAPinnedPlace
place
)
{
return
GetGPUBuddyAllocator
(
place
.
device
)
->
Used
();
}
template
<
>
void
*
Alloc
<
platform
::
CUDAPinnedPlace
>
(
platform
::
CUDAPinnedPlace
place
,
size_t
size
)
{
auto
*
buddy_allocator
=
GetCUDAPinnedBuddyAllocator
(
place
.
device
);
void
*
ptr
=
buddy_allocator
->
Alloc
(
size
);
if
(
ptr
==
nullptr
)
{
int
cur_dev
=
platform
::
GetCurrentDeviceId
();
platform
::
SetDeviceId
(
place
.
device
);
size_t
avail
,
total
;
platform
::
GpuMemoryUsage
(
avail
,
total
);
LOG
(
WARNING
)
<<
"Cannot allocate "
<<
size
<<
" bytes in GPU "
<<
place
.
device
<<
", available "
<<
avail
<<
" bytes"
;
LOG
(
WARNING
)
<<
"total "
<<
total
;
LOG
(
WARNING
)
<<
"GpuMinChunkSize "
<<
platform
::
GpuMinChunkSize
();
LOG
(
WARNING
)
<<
"GpuMaxChunkSize "
<<
platform
::
GpuMaxChunkSize
();
LOG
(
WARNING
)
<<
"GPU memory used: "
<<
Used
<
platform
::
CUDAPlace
>
(
place
);
platform
::
SetDeviceId
(
cur_dev
);
}
return
ptr
;
}
template
<
>
void
Free
<
platform
::
CUDAPinnedPlace
>
(
platform
::
CUDAPinnedPlace
place
,
void
*
p
)
{
GetCUDAPinnedBuddyAllocator
(
place
.
device
)
->
Free
(
p
);
}
#endif
...
...
@@ -165,6 +168,10 @@ size_t Usage::operator()(const platform::CUDAPlace& gpu) const {
#endif
}
size_t
Usage
::
operator
()(
const
platform
::
CUDAPinnedPlace
&
cuda_pinned
)
const
{
return
Used
(
cuda_pinned
);
}
size_t
memory_usage
(
const
platform
::
Place
&
p
)
{
return
boost
::
apply_visitor
(
Usage
(),
p
);
}
...
...
paddle/fluid/memory/memory.h
浏览文件 @
18eb7730
...
...
@@ -57,6 +57,7 @@ size_t Used(Place place);
struct
Usage
:
public
boost
::
static_visitor
<
size_t
>
{
size_t
operator
()(
const
platform
::
CPUPlace
&
cpu
)
const
;
size_t
operator
()(
const
platform
::
CUDAPlace
&
gpu
)
const
;
size_t
operator
()(
const
platform
::
CUDAPinnedPlace
&
cuda_pinned
)
const
;
};
size_t
memory_usage
(
const
platform
::
Place
&
p
);
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
18eb7730
...
...
@@ -118,6 +118,18 @@ struct DefaultDeviceContextType<platform::CUDAPlace> {
using
TYPE
=
CUDADeviceContext
;
};
// Currently, CUDAPinnedDeviceContext is only used to data copying.
// class CUDAPinnedDeviceContext : public DeviceContext {
// public:
// CUDAPinnedDeviceContext();
// explicit CUDAPinnedDeviceContext(CUDAPinnedPlace place);
//
// Place GetPlace() const override;
//
// private:
// CUDAPinnedPlace place_;
//};
#endif
#ifdef PADDLE_WITH_MKLDNN
...
...
paddle/fluid/platform/place.cc
浏览文件 @
18eb7730
...
...
@@ -40,12 +40,19 @@ const Place &get_place() { return the_default_place; }
const
CUDAPlace
default_gpu
()
{
return
CUDAPlace
(
0
);
}
const
CPUPlace
default_cpu
()
{
return
CPUPlace
();
}
const
CUDAPinnedPlace
default_cuda_pinned
()
{
return
CUDAPinnedPlace
();
}
bool
is_gpu_place
(
const
Place
&
p
)
{
return
boost
::
apply_visitor
(
IsCUDAPlace
(),
p
);
}
bool
is_cpu_place
(
const
Place
&
p
)
{
return
!
is_gpu_place
(
p
);
}
bool
is_cpu_place
(
const
Place
&
p
)
{
return
boost
::
apply_visitor
(
IsCPUPlace
(),
p
);
}
bool
is_cuda_pinned_place
(
const
Place
&
p
)
{
return
boost
::
apply_visitor
(
IsCUDAPinnedPlace
(),
p
);
}
bool
places_are_same_class
(
const
Place
&
p1
,
const
Place
&
p2
)
{
return
p1
.
which
()
==
p2
.
which
();
...
...
@@ -53,7 +60,7 @@ bool places_are_same_class(const Place &p1, const Place &p2) {
bool
is_same_place
(
const
Place
&
p1
,
const
Place
&
p2
)
{
if
(
places_are_same_class
(
p1
,
p2
))
{
if
(
is_cpu_place
(
p1
))
{
if
(
is_cpu_place
(
p1
)
||
is_cuda_pinned_place
(
p1
)
)
{
return
true
;
}
else
{
return
boost
::
get
<
CUDAPlace
>
(
p1
)
==
boost
::
get
<
CUDAPlace
>
(
p2
);
...
...
paddle/fluid/platform/place.h
浏览文件 @
18eb7730
...
...
@@ -45,12 +45,33 @@ struct CUDAPlace {
int
device
;
};
struct
CUDAPinnedPlace
{
CUDAPinnedPlace
()
{}
// needed for variant equality comparison
inline
bool
operator
==
(
const
CUDAPinnedPlace
&
)
const
{
return
true
;
}
inline
bool
operator
!=
(
const
CUDAPinnedPlace
&
)
const
{
return
false
;
}
};
struct
IsCUDAPlace
:
public
boost
::
static_visitor
<
bool
>
{
bool
operator
()(
const
CPUPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDAPlace
&
gpu
)
const
{
return
true
;
}
bool
operator
()(
const
CUDAPinnedPlace
&
)
const
{
return
false
;
}
};
typedef
boost
::
variant
<
CUDAPlace
,
CPUPlace
>
Place
;
struct
IsCPUPlace
:
public
boost
::
static_visitor
<
bool
>
{
bool
operator
()(
const
CPUPlace
&
cpu
)
const
{
return
true
;
}
bool
operator
()(
const
CUDAPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDAPinnedPlace
&
)
const
{
return
false
;
}
};
struct
IsCUDAPinnedPlace
:
public
boost
::
static_visitor
<
bool
>
{
bool
operator
()(
const
CPUPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDAPlace
&
)
const
{
return
false
;
}
bool
operator
()(
const
CUDAPinnedPlace
&
cuda_pinned
)
const
{
return
true
;
}
};
typedef
boost
::
variant
<
CUDAPlace
,
CPUPlace
,
CUDAPinnedPlace
>
Place
;
using
PlaceList
=
std
::
vector
<
Place
>
;
...
...
@@ -59,9 +80,11 @@ const Place &get_place();
const
CUDAPlace
default_gpu
();
const
CPUPlace
default_cpu
();
const
CUDAPinnedPlace
default_cuda_pinned
();
bool
is_gpu_place
(
const
Place
&
);
bool
is_cpu_place
(
const
Place
&
);
bool
is_cuda_pinned_place
(
const
Place
&
);
bool
places_are_same_class
(
const
Place
&
,
const
Place
&
);
bool
is_same_place
(
const
Place
&
,
const
Place
&
);
...
...
@@ -97,6 +120,11 @@ struct PlaceVisitorWrapper
return
typename
Visitor
::
result_type
();
#endif
}
typename
Visitor
::
result_type
operator
()(
const
CUDAPinnedPlace
&
cuda_pinned
)
const
{
return
visitor_
(
cuda_pinned
);
}
};
template
<
typename
Visitor
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录