Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
84aea8a8
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
84aea8a8
编写于
3月 07, 2018
作者:
C
chengduo
提交者:
GitHub
3月 07, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #8669 from chengduoZH/feature/concat_op
Refine concat_op
上级
8c71adaa
c3864eab
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
883 addition
and
50 deletion
+883
-50
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/concat_op.cc
paddle/fluid/operators/concat_op.cc
+5
-4
paddle/fluid/operators/concat_op.h
paddle/fluid/operators/concat_op.h
+38
-46
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+3
-0
paddle/fluid/operators/math/concat.cc
paddle/fluid/operators/math/concat.cc
+119
-0
paddle/fluid/operators/math/concat.cu
paddle/fluid/operators/math/concat.cu
+280
-0
paddle/fluid/operators/math/concat.h
paddle/fluid/operators/math/concat.h
+63
-0
paddle/fluid/operators/math/concat_test.cc
paddle/fluid/operators/math/concat_test.cc
+336
-0
paddle/fluid/platform/device_context.cc
paddle/fluid/platform/device_context.cc
+6
-0
paddle/fluid/platform/device_context.h
paddle/fluid/platform/device_context.h
+6
-0
paddle/fluid/platform/gpu_info.cc
paddle/fluid/platform/gpu_info.cc
+20
-0
paddle/fluid/platform/gpu_info.h
paddle/fluid/platform/gpu_info.h
+6
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
84aea8a8
...
...
@@ -201,6 +201,7 @@ op_library(save_op DEPS lod_tensor)
op_library
(
load_op DEPS lod_tensor
)
op_library
(
save_combine_op DEPS lod_tensor
)
op_library
(
load_combine_op DEPS lod_tensor
)
op_library
(
concat_op DEPS concat_functor
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
paddle/fluid/operators/concat_op.cc
浏览文件 @
84aea8a8
...
...
@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_EX
(
concat
,
ops
::
ConcatOp
,
ops
::
ConcatOpMaker
,
concat_grad
,
ops
::
ConcatOpGrad
,
false
)
REGISTER_OP_CPU_KERNEL
(
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
REGISTER_OP_CPU_KERNEL
(
concat_grad
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
REGISTER_OP_CPU_KERNEL
(
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
)
REGISTER_OP_CPU_KERNEL
(
concat_grad
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
)
paddle/fluid/operators/concat_op.h
浏览文件 @
84aea8a8
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace
paddle
{
...
...
@@ -27,54 +28,30 @@ class ConcatKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
framework
::
Tensor
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
auto
place
=
ctx
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
size_t
output_offset
=
0
;
// If axis >=1, copy to out immediately need to call many times
// of cuda memcpy. Copy the input to cpu and do the stride copy,
// then copy to gpu output.
if
(
platform
::
is_gpu_place
(
place
)
&&
axis
>=
1
)
{
platform
::
CPUPlace
copy_place
;
auto
&
cpu_ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
copy_place
);
framework
::
Tensor
cpu_out
;
cpu_out
.
Resize
(
out
->
dims
());
cpu_out
.
mutable_data
<
T
>
(
copy_place
);
auto
&
dev_ctx
=
ctx
.
device_context
();
std
::
vector
<
std
::
unique_ptr
<
framework
::
Tensor
>>
cpu_ins
;
for
(
auto
*
in
:
ins
)
{
std
::
unique_ptr
<
framework
::
Tensor
>
cpu_in
(
new
framework
::
Tensor
);
framework
::
TensorCopy
(
*
in
,
copy_place
,
dev_ctx
,
cpu_in
.
get
());
cpu_ins
.
emplace_back
(
std
::
move
(
cpu_in
));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx
.
Wait
();
for
(
auto
&
in
:
cpu_ins
)
{
auto
&
cpu_in
=
*
in
.
get
();
auto
in_stride
=
framework
::
stride_numel
(
cpu_in
.
dims
());
StridedNumelCopyWithAxis
<
T
>
(
cpu_ctx
,
axis
,
cpu_out
.
data
<
T
>
()
+
output_offset
,
out_stride
,
cpu_in
.
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
framework
::
TensorCopy
(
cpu_out
,
place
,
dev_ctx
,
out
);
}
else
{
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if
(
axis
==
0
&&
ins
.
size
()
<
10
)
{
size_t
output_offset
=
0
;
for
(
auto
*
in
:
ins
)
{
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
ctx
.
device_context
(),
axis
,
out
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
->
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
}
else
{
std
::
vector
<
framework
::
Tensor
>
inputs
(
ins
.
size
());
for
(
size_t
j
=
0
;
j
<
ins
.
size
();
++
j
)
{
inputs
[
j
]
=
*
ins
[
j
];
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
paddle
::
operators
::
math
::
ConcatFunctor
<
DeviceContext
,
T
>
concat_functor
;
concat_functor
(
dev_ctx
,
inputs
,
static_cast
<
int
>
(
axis
),
out
);
}
}
};
...
...
@@ -86,16 +63,31 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
outs
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
size_t
input_offset
=
0
;
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
for
(
auto
&
out
:
outs
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
ctx
.
device_context
(),
axis
,
out
->
data
<
T
>
(),
out_stride
,
in
->
data
<
T
>
()
+
input_offset
,
in_stride
,
out_stride
[
axis
]);
input_offset
+=
out_stride
[
axis
];
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if
(
axis
==
0
&&
outs
.
size
()
<
10
)
{
size_t
input_offset
=
0
;
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
for
(
auto
&
out
:
outs
)
{
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
ctx
.
device_context
(),
axis
,
out
->
data
<
T
>
(),
out_stride
,
in
->
data
<
T
>
()
+
input_offset
,
in_stride
,
out_stride
[
axis
]);
input_offset
+=
out_stride
[
axis
];
}
}
else
{
std
::
vector
<
framework
::
Tensor
>
outputs
(
outs
.
size
());
for
(
size_t
j
=
0
;
j
<
outs
.
size
();
++
j
)
{
outs
[
j
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
outputs
[
j
]
=
*
outs
[
j
];
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
paddle
::
operators
::
math
::
ConcatGradFunctor
<
DeviceContext
,
T
>
concat_grad_functor
;
concat_grad_functor
(
dev_ctx
,
*
in
,
static_cast
<
int
>
(
axis
),
outputs
);
}
}
};
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
84aea8a8
...
...
@@ -20,6 +20,7 @@ if(WITH_GPU)
nv_library
(
unpooling SRCS unpooling.cc unpooling.cu DEPS device_context
)
nv_library
(
gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function
)
nv_library
(
cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context
)
nv_library
(
concat_functor SRCS concat.cc concat.cu DEPS device_context tensor
)
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto
)
cc_library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
...
...
@@ -37,6 +38,7 @@ else()
cc_library
(
unpooling SRCS unpooling.cc DEPS device_context
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context
)
cc_library
(
concat_functor SRCS concat.cc DEPS device_context tensor
)
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
...
@@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test
(
im2col_test SRCS im2col_test.cc DEPS math_function tensor
)
cc_test
(
vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor
)
cc_test
(
sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding
)
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_functor tensor
)
paddle/fluid/operators/math/concat.cc
0 → 100644
浏览文件 @
84aea8a8
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
// TODO(zcd): Add input data validity checking
int
num
=
input
.
size
();
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
}
int
out_rows
=
rows
,
out_cols
=
0
;
std
::
vector
<
int64_t
>
input_cols
(
input
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
out_cols
+=
t_cols
;
input_cols
[
i
]
=
t_cols
;
}
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
context
.
GetPlace
());
// computation
for
(
int
k
=
0
;
k
<
out_rows
;
++
k
)
{
T
*
dst_ptr
=
output
->
data
<
T
>
()
+
k
*
out_cols
;
int
col_idx
=
0
;
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
input_cols
[
j
];
const
T
*
src_prt
=
input
[
j
].
data
<
T
>
()
+
k
*
col_len
;
memory
::
Copy
(
cpu_place
,
dst_ptr
+
col_idx
,
cpu_place
,
src_prt
,
sizeof
(
T
)
*
col_len
);
col_idx
+=
col_len
;
}
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
)
{
// TODO(zcd): Add input data validity checking
int
num
=
outputs
.
size
();
int
input_rows
=
1
;
auto
dim_0
=
outputs
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
input_rows
*=
dim_0
[
i
];
}
int
input_cols
=
0
;
std
::
vector
<
int64_t
>
output_cols
(
outputs
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
outputs
[
i
].
numel
()
/
input_rows
;
input_cols
+=
t_cols
;
output_cols
[
i
]
=
t_cols
;
}
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
context
.
GetPlace
());
// computation
for
(
int
k
=
0
;
k
<
input_rows
;
++
k
)
{
const
T
*
src_ptr
=
input
.
data
<
T
>
()
+
k
*
input_cols
;
int
col_idx
=
0
;
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
output_cols
[
j
];
T
*
dst_ptr
=
outputs
[
j
].
data
<
T
>
()
+
k
*
col_len
;
memory
::
Copy
(
cpu_place
,
dst_ptr
,
cpu_place
,
src_ptr
+
col_idx
,
sizeof
(
T
)
*
col_len
);
col_idx
+=
col_len
;
}
}
}
};
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
int
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
int
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/concat.cu
0 → 100644
浏览文件 @
84aea8a8
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
__device__
T
upper_bound
(
const
T
*
first
,
T
count
,
T
val
)
{
const
T
*
orig
=
first
;
const
T
*
it
=
nullptr
;
T
step
=
0
;
while
(
count
>
0
)
{
it
=
first
;
step
=
count
/
2
;
it
+=
step
;
if
(
!
(
val
<
*
it
))
{
first
=
++
it
;
count
-=
step
+
1
;
}
else
{
count
=
step
;
}
}
return
first
-
orig
;
}
template
<
typename
T
>
__global__
void
KernelConcat
(
T
**
inputs
,
const
int
*
input_cols
,
int
col_size
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
input_cols
,
col_size
,
tid_x
)
-
1
;
int
curr_offset
=
input_cols
[
segment
];
int
curr_segment
=
segment
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
((
curr_col_offset
=
input_cols
[
curr_segment
+
1
])
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
input_ptr
=
inputs
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
segment_width
+
local_col
];
}
}
template
<
typename
T
>
__global__
void
KernelConcat
(
T
**
inputs
,
const
int
input_col
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
double
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
T
*
input_ptr
=
inputs
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
input_col
+
in_offset
];
}
}
}
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
const
int
*
output_cols
,
int
col_size
,
T
**
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
output_cols
,
col_size
,
tid_x
)
-
1
;
int
curr_offset
=
output_cols
[
segment
];
int
curr_segment
=
segment
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
((
curr_col_offset
=
output_cols
[
curr_segment
+
1
])
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
output_ptr
=
outputs
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
segment_width
+
local_col
]
=
input
[
tid_y
*
input_col
+
tid_x
];
}
}
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
const
int
output_cols
,
T
**
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
double
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
T
*
output_ptr
=
outputs
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
output_cols
+
in_offset
]
=
input
[
tid_y
*
input_col
+
tid_x
];
}
}
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
// TODO(zcd): Add input data validity checking
int
num
=
input
.
size
();
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
}
int
cols
=
input
[
0
].
numel
()
/
rows
;
int
out_rows
=
rows
,
out_cols
=
0
;
framework
::
Vector
<
int16_t
>
inputs_data
(
num
*
sizeof
(
T
*
)
/
2
);
framework
::
Vector
<
int
>
inputs_cols
(
num
+
1
);
inputs_cols
[
0
]
=
0
;
T
**
inputs_ptr
=
reinterpret_cast
<
T
**>
(
inputs_data
.
data
());
bool
sameShape
=
true
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
if
(
sameShape
)
{
if
(
t_cols
!=
cols
)
sameShape
=
false
;
}
out_cols
+=
t_cols
;
inputs_cols
[
i
+
1
]
=
out_cols
;
inputs_ptr
[
i
]
=
const_cast
<
T
*>
(
input
[
i
].
data
<
T
>
());
}
T
**
ins_gpu
=
reinterpret_cast
<
T
**>
(
inputs_data
.
CUDAMutableData
(
context
.
GetPlace
()));
const
int
*
ins_col_gpu
=
inputs_cols
.
CUDAData
(
context
.
GetPlace
());
// computation
// set the thread block and grid according to CurrentDeviceId
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
kThreadsPerBlock
;
if
(
out_cols
<
kThreadsPerBlock
)
{
// block_cols is aligned by 32.
block_cols
=
((
out_cols
+
31
)
>>
5
)
<<
5
;
}
int
block_rows
=
kThreadsPerBlock
/
block_cols
;
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
grid_cols
=
std
::
min
((
out_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
int
grid_rows
=
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
out_rows
/
block_rows
,
1
));
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
if
(
sameShape
)
{
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
ins_gpu
,
cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
}
else
{
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
ins_gpu
,
ins_col_gpu
,
static_cast
<
int
>
(
inputs_cols
.
size
()),
out_rows
,
out_cols
,
output
->
data
<
T
>
());
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension are the same, except the axis dimension.
*/
template
<
typename
T
>
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
)
{
// TODO(zcd): Add input data validity checking
int
num
=
outputs
.
size
();
int
input_row
=
1
;
auto
dim_0
=
outputs
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
input_row
*=
dim_0
[
i
];
}
int
output_col_0
=
outputs
[
0
].
numel
()
/
input_row
;
int
input_col
=
0
;
bool
sameShape
=
true
;
framework
::
Vector
<
int16_t
>
outputs_data
(
num
*
sizeof
(
T
*
)
/
2
);
framework
::
Vector
<
int
>
outputs_cols
(
num
+
1
);
outputs_cols
[
0
]
=
0
;
T
**
outputs_ptr
=
reinterpret_cast
<
T
**>
(
outputs_data
.
data
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_col
=
outputs
[
i
].
numel
()
/
input_row
;
if
(
sameShape
)
{
if
(
t_col
!=
output_col_0
)
sameShape
=
false
;
}
input_col
+=
t_col
;
outputs_cols
[
i
+
1
]
=
input_col
;
outputs_ptr
[
i
]
=
outputs
[
i
].
data
<
T
>
();
}
T
**
outs_gpu
=
reinterpret_cast
<
T
**>
(
outputs_data
.
CUDAMutableData
(
context
.
GetPlace
()));
const
int
*
outs_col_gpu
=
outputs_cols
.
CUDAData
(
context
.
GetPlace
());
// computation
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
kThreadsPerBlock
;
if
(
input_col
<
kThreadsPerBlock
)
{
// block_cols is aligned by 32.
block_cols
=
((
input_col
+
31
)
>>
5
)
<<
5
;
}
int
block_rows
=
kThreadsPerBlock
/
block_cols
;
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
grid_cols
=
std
::
min
((
input_col
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
int
grid_rows
=
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
input_row
/
block_rows
,
1
));
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
if
(
sameShape
)
{
KernelConcatGrad
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
input_row
,
input_col
,
output_col_0
,
outs_gpu
);
}
else
{
KernelConcatGrad
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
input_row
,
input_col
,
outs_col_gpu
,
static_cast
<
int
>
(
outputs_cols
.
size
()),
outs_gpu
);
}
}
};
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
int
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
int
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/concat.h
0 → 100644
浏览文件 @
84aea8a8
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
);
};
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatGradFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/concat_test.cc
0 → 100644
浏览文件 @
84aea8a8
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
template
<
typename
DeviceContext
,
typename
Place
>
void
testConcat
()
{
Tensor
input_a_cpu
;
Tensor
input_b_cpu
;
Tensor
out_cpu
;
Tensor
input_a
;
Tensor
input_b
;
Tensor
out
;
DeviceContext
*
context
=
new
DeviceContext
(
Place
());
// DeviceContext context(Place());
/**
* cast1:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [3, 3, 4]
* output:
* out.shape: [5, 3, 4]
*/
auto
dim_a
=
make_ddim
({
2
,
3
,
4
});
auto
dim_b
=
make_ddim
({
3
,
3
,
4
});
auto
dim_out
=
make_ddim
({
5
,
3
,
4
});
input_a
.
mutable_data
<
int
>
(
dim_a
,
Place
());
input_b
.
mutable_data
<
int
>
(
dim_b
,
Place
());
out
.
mutable_data
<
int
>
(
dim_out
,
Place
());
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
mutable_data
<
int
>
(
dim_a
,
CPUPlace
());
input_b_cpu
.
mutable_data
<
int
>
(
dim_b
,
CPUPlace
());
out_cpu
.
mutable_data
<
int
>
(
dim_out
,
CPUPlace
());
}
int
*
a_ptr
;
int
*
b_ptr
;
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
3
*
3
*
4
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
std
::
vector
<
Tensor
>
input
;
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
paddle
::
operators
::
math
::
ConcatFunctor
<
DeviceContext
,
int
>
concat_functor
;
concat_functor
(
*
context
,
input
,
0
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
int
*
out_ptr
;
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
int
cols
=
2
*
3
*
4
;
int
idx_a
=
0
,
idx_b
=
0
;
for
(
int
j
=
0
;
j
<
5
*
3
*
4
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
//
/**
* cast2:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 4, 4]
* output:
* out.shape: [2, 7, 4]
*/
dim_a
=
make_ddim
({
2
,
3
,
4
});
dim_b
=
make_ddim
({
2
,
4
,
4
});
dim_out
=
make_ddim
({
2
,
7
,
4
});
input_a
.
Resize
(
dim_a
);
input_b
.
Resize
(
dim_b
);
out
.
Resize
(
dim_out
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
Resize
(
dim_a
);
input_b_cpu
.
Resize
(
dim_b
);
out_cpu
.
Resize
(
dim_out
);
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
2
*
4
*
4
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
input
.
clear
();
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
concat_functor
(
*
context
,
input
,
1
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
cols
=
3
*
4
;
idx_a
=
0
,
idx_b
=
0
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
for
(
int
j
=
0
;
j
<
28
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
28
+
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
28
+
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
}
/**
* cast3:
* inputs:
* t_a.shape: [2, 3, 5]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 3, 9]
*/
dim_a
=
make_ddim
({
2
,
3
,
4
});
dim_b
=
make_ddim
({
2
,
3
,
5
});
dim_out
=
make_ddim
({
2
,
3
,
9
});
input_a
.
Resize
(
dim_a
);
input_b
.
Resize
(
dim_b
);
out
.
Resize
(
dim_out
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
Resize
(
dim_a
);
input_b_cpu
.
Resize
(
dim_b
);
out_cpu
.
Resize
(
dim_out
);
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
2
*
3
*
5
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
input
.
clear
();
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
concat_functor
(
*
context
,
input
,
2
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
// check the data
cols
=
4
;
idx_a
=
0
,
idx_b
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
for
(
int
j
=
0
;
j
<
9
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
9
+
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
9
+
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
}
/**
* cast4:
* inputs:
* axis = 1
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 6, 4]
*/
dim_a
=
make_ddim
({
2
,
3
,
4
});
dim_b
=
make_ddim
({
2
,
3
,
4
});
dim_out
=
make_ddim
({
2
,
6
,
4
});
input_a
.
Resize
(
dim_a
);
input_b
.
Resize
(
dim_b
);
out
.
Resize
(
dim_out
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
Resize
(
dim_a
);
input_b_cpu
.
Resize
(
dim_b
);
out_cpu
.
Resize
(
dim_out
);
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
input
.
clear
();
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
concat_functor
(
*
context
,
input
,
1
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
// check the data
cols
=
12
;
idx_a
=
0
,
idx_b
=
0
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
for
(
int
j
=
0
;
j
<
24
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
24
+
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
24
+
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
}
}
TEST
(
math
,
concat
)
{
testConcat
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
CPUPlace
>
();
#ifdef PADDLE_WITH_CUDA
testConcat
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
CUDAPlace
>
();
#endif
}
paddle/fluid/platform/device_context.cc
浏览文件 @
84aea8a8
...
...
@@ -127,6 +127,8 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
CUDADeviceContext
::
CUDADeviceContext
(
CUDAPlace
place
)
:
place_
(
place
)
{
SetDeviceId
(
place_
.
device
);
multi_process
=
GetCUDAMultiProcessors
(
place_
.
device
);
max_threads_per_mp
=
GetCUDAMaxThreadsPerMultiProcessor
(
place_
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
...
...
@@ -160,6 +162,10 @@ void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE
(
cudaGetLastError
());
}
int
CUDADeviceContext
::
GetMaxPhysicalThreadCount
()
const
{
return
multi_process
*
max_threads_per_mp
;
}
Eigen
::
GpuDevice
*
CUDADeviceContext
::
eigen_device
()
const
{
return
eigen_device_
.
get
();
}
...
...
paddle/fluid/platform/device_context.h
浏览文件 @
84aea8a8
...
...
@@ -79,6 +79,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return place in the device context. */
Place
GetPlace
()
const
override
;
/*! \brief Return the max physical thread count in the device context */
int
GetMaxPhysicalThreadCount
()
const
;
/*! \brief Return eigen device in the device context. */
Eigen
::
GpuDevice
*
eigen_device
()
const
;
...
...
@@ -100,6 +103,9 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t
stream_
;
cudnnHandle_t
cudnn_handle_
;
cublasHandle_t
cublas_handle_
;
int
multi_process
;
int
max_threads_per_mp
;
};
template
<
>
...
...
paddle/fluid/platform/gpu_info.cc
浏览文件 @
84aea8a8
...
...
@@ -33,6 +33,26 @@ int GetCUDADeviceCount() {
return
count
;
}
int
GetCUDAMultiProcessors
(
int
id
)
{
PADDLE_ENFORCE_LT
(
id
,
GetCUDADeviceCount
(),
"id must less than GPU count"
);
int
count
;
PADDLE_ENFORCE
(
cudaDeviceGetAttribute
(
&
count
,
cudaDevAttrMultiProcessorCount
,
id
),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMultiProcessors"
);
return
count
;
}
int
GetCUDAMaxThreadsPerMultiProcessor
(
int
id
)
{
PADDLE_ENFORCE_LT
(
id
,
GetCUDADeviceCount
(),
"id must less than GPU count"
);
int
count
;
PADDLE_ENFORCE
(
cudaDeviceGetAttribute
(
&
count
,
cudaDevAttrMaxThreadsPerMultiProcessor
,
id
),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMaxThreadsPerMultiProcessor"
);
return
count
;
}
int
GetCurrentDeviceId
()
{
int
device_id
;
PADDLE_ENFORCE
(
...
...
paddle/fluid/platform/gpu_info.h
浏览文件 @
84aea8a8
...
...
@@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse =
//! Get the total number of GPU devices in system.
int
GetCUDADeviceCount
();
//! Get the MultiProcessors of the ith GPU.
int
GetCUDAMultiProcessors
(
int
i
);
//! Get the MaxThreads of each MultiProcessor of the ith GPU.
int
GetCUDAMaxThreadsPerMultiProcessor
(
int
i
);
//! Get the current GPU device id in system.
int
GetCurrentDeviceId
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录