Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
00e596ed
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
00e596ed
编写于
3月 02, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
get max threads of GPU
上级
60e7ee06
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
320 addition
and
55 deletion
+320
-55
paddle/fluid/operators/concat_op.h
paddle/fluid/operators/concat_op.h
+10
-9
paddle/fluid/operators/math/concat.cc
paddle/fluid/operators/math/concat.cc
+54
-21
paddle/fluid/operators/math/concat.cu
paddle/fluid/operators/math/concat.cu
+148
-22
paddle/fluid/operators/math/concat.h
paddle/fluid/operators/math/concat.h
+8
-3
paddle/fluid/operators/math/concat_test.cc
paddle/fluid/operators/math/concat_test.cc
+74
-0
paddle/fluid/platform/gpu_info.cc
paddle/fluid/platform/gpu_info.cc
+20
-0
paddle/fluid/platform/gpu_info.h
paddle/fluid/platform/gpu_info.h
+6
-0
未找到文件。
paddle/fluid/operators/concat_op.h
浏览文件 @
00e596ed
...
@@ -32,6 +32,7 @@ class ConcatKernel : public framework::OpKernel<T> {
...
@@ -32,6 +32,7 @@ class ConcatKernel : public framework::OpKernel<T> {
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
out
->
mutable_data
<
T
>
(
place
);
std
::
vector
<
framework
::
Tensor
>
inputs
(
ins
.
size
());
std
::
vector
<
framework
::
Tensor
>
inputs
(
ins
.
size
());
for
(
size_t
j
=
0
;
j
<
ins
.
size
();
++
j
)
{
for
(
size_t
j
=
0
;
j
<
ins
.
size
();
++
j
)
{
inputs
[
j
]
=
*
ins
[
j
];
inputs
[
j
]
=
*
ins
[
j
];
...
@@ -49,17 +50,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
...
@@ -49,17 +50,17 @@ class ConcatGradKernel : public framework::OpKernel<T> {
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
outs
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
outs
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
size_t
input_offset
=
0
;
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
for
(
auto
&
out
:
outs
)
{
std
::
vector
<
framework
::
Tensor
>
outputs
(
outs
.
size
());
out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
for
(
size_t
j
=
0
;
j
<
outs
.
size
();
++
j
)
{
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
outs
[
j
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
StridedNumelCopyWithAxis
<
T
>
(
ctx
.
device_context
(),
axis
,
out
->
data
<
T
>
(),
outputs
[
j
]
=
*
outs
[
j
];
out_stride
,
in
->
data
<
T
>
()
+
input_offset
,
in_stride
,
out_stride
[
axis
]);
input_offset
+=
out_stride
[
axis
];
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
paddle
::
operators
::
math
::
ConcatGradFunctor
<
DeviceContext
,
T
>
concat_grad_functor
;
concat_grad_functor
(
dev_ctx
,
*
in
,
static_cast
<
int
>
(
axis
),
outputs
);
}
}
};
};
...
...
paddle/fluid/operators/math/concat.cc
浏览文件 @
00e596ed
...
@@ -25,16 +25,12 @@ template <typename T>
...
@@ -25,16 +25,12 @@ template <typename T>
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
framework
::
Tensor
*
output
)
{
// assume the the max size of input is less than 8 and see the performance
// assume the the max size of input is less than 8 and see the performance
// save origin dim
// save origin dim
int
num
=
input
.
size
();
int
num
=
input
.
size
();
std
::
vector
<
paddle
::
framework
::
DDim
>
origin_dim
(
num
);
std
::
vector
<
paddle
::
framework
::
DDim
>
origin_dim
(
num
);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto
out_dim
=
output
->
dims
();
// get the matrix size
// get the matrix size
int
rows
=
1
;
int
rows
=
1
;
...
@@ -42,40 +38,72 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
...
@@ -42,40 +38,72 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
rows
*=
dim_0
[
i
];
}
}
int
cols
=
input
[
0
].
numel
()
/
rows
;
int
out_rows
=
rows
,
out_cols
=
0
;
int
out_rows
=
rows
,
out_cols
=
0
;
bool
sameShape
=
true
;
// reshape to matrix
// get input's cols
std
::
vector
<
int64_t
>
input_cols
(
input
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
if
(
sameShape
)
{
if
(
t_cols
!=
cols
)
sameShape
=
false
;
}
out_cols
+=
t_cols
;
out_cols
+=
t_cols
;
input
[
i
].
Resize
({
rows
,
t_cols
})
;
input
_cols
[
i
]
=
t_cols
;
}
}
output
->
Resize
({
out_rows
,
out_cols
});
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
context
.
GetPlace
());
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
context
.
GetPlace
());
// computation
// computation
for
(
int
k
=
0
;
k
<
rows
;
++
k
)
{
for
(
int
k
=
0
;
k
<
out_rows
;
++
k
)
{
// offset k * out_cols
T
*
dst_ptr
=
output
->
data
<
T
>
()
+
k
*
out_cols
;
T
*
dst_ptr
=
output
->
data
<
T
>
()
+
k
*
out_cols
;
int
col_idx
=
0
;
int
col_idx
=
0
;
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
input
[
j
].
dims
()[
1
];
int
col_len
=
input
_cols
[
j
];
const
T
*
src_prt
=
input
[
j
].
data
<
T
>
()
+
k
*
col_len
;
const
T
*
src_prt
=
input
[
j
].
data
<
T
>
()
+
k
*
col_len
;
memory
::
Copy
(
cpu_place
,
dst_ptr
+
col_idx
,
cpu_place
,
src_prt
,
memory
::
Copy
(
cpu_place
,
dst_ptr
+
col_idx
,
cpu_place
,
src_prt
,
sizeof
(
T
)
*
col_len
);
sizeof
(
T
)
*
col_len
);
col_idx
+=
col_len
;
col_idx
+=
col_len
;
}
}
}
}
}
};
template
<
typename
T
>
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int
num
=
outputs
.
size
();
std
::
vector
<
paddle
::
framework
::
DDim
>
origin_dim
(
num
);
// recover origin dim
// get the matrix size
// for (int j = 0; j < num; ++j) {
int
input_rows
=
1
;
// input[j]->Resize(origin_dim[j]);
auto
dim_0
=
outputs
[
0
].
dims
();
// }
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
output
->
Resize
(
out_dim
);
input_rows
*=
dim_0
[
i
];
}
int
input_cols
=
0
;
// get outputs' cols
std
::
vector
<
int64_t
>
output_cols
(
outputs
.
size
());
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
outputs
[
i
].
numel
()
/
input_rows
;
input_cols
+=
t_cols
;
output_cols
[
i
]
=
t_cols
;
}
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
context
.
GetPlace
());
// computation
for
(
int
k
=
0
;
k
<
input_rows
;
++
k
)
{
const
T
*
src_ptr
=
input
.
data
<
T
>
()
+
k
*
input_cols
;
int
col_idx
=
0
;
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
output_cols
[
j
];
T
*
dst_ptr
=
outputs
[
j
].
data
<
T
>
()
+
k
*
col_len
;
memory
::
Copy
(
cpu_place
,
dst_ptr
,
cpu_place
,
src_ptr
+
col_idx
,
sizeof
(
T
)
*
col_len
);
col_idx
+=
col_len
;
}
}
}
}
};
};
...
@@ -84,6 +112,11 @@ template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
...
@@ -84,6 +112,11 @@ template class ConcatFunctor<platform::CPUDeviceContext, int64_t>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
double
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
int
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
ConcatGradFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/math/concat.cu
浏览文件 @
00e596ed
...
@@ -22,7 +22,7 @@ namespace math {
...
@@ -22,7 +22,7 @@ namespace math {
// TODO(zcd): This can be replaced by tensor,
// TODO(zcd): This can be replaced by tensor,
// if that, maybe we should add int8 to VarType::Type.
// if that, maybe we should add int8 to VarType::Type.
// Or replaced by tensorArray.
// Or replaced by tensorArray.
static
constexpr
int
MaxSize
=
32
;
static
constexpr
int
MaxSize
=
8
;
template
<
typename
T
>
template
<
typename
T
>
struct
CUDADeviceArray
{
struct
CUDADeviceArray
{
T
data
[
MaxSize
];
T
data
[
MaxSize
];
...
@@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
...
@@ -54,7 +54,6 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
const
int
output_rows
,
const
int
output_cols
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
segment
=
upper_bound
<
int
>
(
input_cols
.
data
,
input_cols
.
size
,
tid_x
)
-
1
;
int
segment
=
upper_bound
<
int
>
(
input_cols
.
data
,
input_cols
.
size
,
tid_x
)
-
1
;
int
curr_offset
=
input_cols
.
data
[
segment
];
int
curr_offset
=
input_cols
.
data
[
segment
];
...
@@ -69,13 +68,73 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
...
@@ -69,13 +68,73 @@ __global__ void KernelConcat(const CUDADeviceArray<const T*> inputs,
int
local_col
=
tid_x
-
curr_offset
;
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
const
T
*
input_ptr
=
inputs
.
data
[
curr_segment
];
const
T
*
input_ptr
=
inputs
.
data
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output
[
tid_y
*
output_cols
+
tid_x
]
=
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
segment_width
+
local_col
];
input_ptr
[
tid_y
*
segment_width
+
local_col
];
}
}
}
}
template
<
typename
T
>
__global__
void
KernelConcat
(
const
CUDADeviceArray
<
const
T
*>
inputs
,
const
int
input_col
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
float
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
const
T
*
input_ptr
=
inputs
.
data
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
input_col
+
in_offset
];
}
}
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
CUDADeviceArray
<
int
>
output_cols
,
CUDADeviceArray
<
T
*>
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
segment
=
upper_bound
<
int
>
(
output_cols
.
data
,
output_cols
.
size
,
tid_x
)
-
1
;
int
curr_offset
=
output_cols
.
data
[
segment
];
int
curr_segment
=
segment
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
((
curr_col_offset
=
output_cols
.
data
[
curr_segment
+
1
])
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
output_ptr
=
outputs
.
data
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
segment_width
+
local_col
]
=
input
[
tid_y
*
input_col
+
tid_x
];
}
}
template
<
typename
T
>
__global__
void
KernelConcatGrad
(
const
T
*
input
,
const
int
input_row
,
const
int
input_col
,
const
int
output_cols
,
CUDADeviceArray
<
T
*>
outputs
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
float
inv_input_col
=
1.0
/
input_col
;
for
(;
tid_x
<
input_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
inv_input_col
;
int
in_offset
=
tid_x
-
split
*
input_col
;
T
*
output_ptr
=
outputs
.
data
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
input_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
output_cols
+
in_offset
]
=
input
[
tid_y
*
input_col
+
tid_x
];
}
}
/*
/*
* All tensors' dimension should be the same.
* All tensors' dimension should be the same.
*/
*/
...
@@ -83,17 +142,13 @@ template <typename T>
...
@@ -83,17 +142,13 @@ template <typename T>
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
framework
::
Tensor
*
output
)
{
// assume the the max size of input is less than 8 and see the performance
// assume the the max size of input is less than 8 and see the performance
// save origin dim
// save origin dim
int
num
=
input
.
size
();
int
num
=
input
.
size
();
// std::vector<paddle::framework::DDim> origin_dim(num);
PADDLE_ENFORCE_LT
(
num
,
MaxSize
,
"input number should be less than %d"
,
// for (int j = 0; j < num; ++j) {
MaxSize
);
// origin_dim[j] = input[j].dims();
// }
auto
out_dim
=
output
->
dims
();
// get the matrix size
// get the matrix size
int
rows
=
1
;
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
auto
dim_0
=
input
[
0
].
dims
();
...
@@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
...
@@ -117,30 +172,96 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
if
(
t_cols
!=
cols
)
sameShape
=
false
;
if
(
t_cols
!=
cols
)
sameShape
=
false
;
}
}
out_cols
+=
t_cols
;
out_cols
+=
t_cols
;
input
[
i
].
Resize
({
rows
,
t_cols
});
inputs_cols
.
data
[
i
+
1
]
=
out_cols
;
inputs_cols
.
data
[
i
+
1
]
=
out_cols
;
inputs_data
.
data
[
i
]
=
input
[
i
].
data
<
T
>
();
inputs_data
.
data
[
i
]
=
input
[
i
].
data
<
T
>
();
}
}
output
->
Resize
({
out_rows
,
out_cols
});
// computation
// computation
const
int
kThreadsPerBlock
=
256
;
// set the thread block and grid according to CurrentDeviceId
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
std
::
min
(
out_cols
,
kThreadsPerBlock
);
int
block_cols
=
std
::
min
(
out_cols
,
kThreadsPerBlock
);
int
block_rows
=
std
::
max
(
kThreadsPerBlock
/
block_cols
,
1
);
int
block_rows
=
std
::
max
(
kThreadsPerBlock
/
block_cols
,
1
);
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
grid_cols
=
(
out_cols
+
block_cols
-
1
)
/
block_cols
;
int
dev_id
=
paddle
::
platform
::
GetCurrentDeviceId
();
int
grid_rows
=
(
out_rows
+
block_rows
-
1
)
/
block_rows
;
int
multi_process
=
paddle
::
platform
::
GetCUDAMultiProcessors
(
dev_id
);
int
max_threads_per_mp
=
paddle
::
platform
::
GetCUDAMaxThreadsPerMultiProcessor
(
dev_id
);
int
max_threads
=
multi_process
*
max_threads_per_mp
;
int
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
grid_cols
=
std
::
min
((
out_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
int
grid_rows
=
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
out_rows
/
block_rows
,
1
));
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
if
(
sameShape
)
{
inputs_data
,
inputs_cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
inputs_data
,
cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
}
else
{
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
inputs_data
,
inputs_cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
}
}
};
template
<
typename
T
>
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int
num
=
outputs
.
size
();
PADDLE_ENFORCE_LT
(
num
,
MaxSize
,
"input number should be less than %d"
,
MaxSize
);
// get the matrix size
int
input_row
=
1
;
auto
dim_0
=
outputs
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
input_row
*=
dim_0
[
i
];
}
int
output_col_0
=
outputs
[
0
].
numel
()
/
input_row
;
int
input_col
=
0
;
bool
sameShape
=
true
;
CUDADeviceArray
<
T
*>
outputs_data
;
CUDADeviceArray
<
int
>
outputs_cols
;
outputs_data
.
size
=
num
;
outputs_cols
.
size
=
num
+
1
;
outputs_cols
.
data
[
0
]
=
0
;
// recover origin dim
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
// for (int j = 0; j < num; ++j) {
int
t_col
=
outputs
[
i
].
numel
()
/
input_row
;
// input[j].Resize(origin_dim[j]);
if
(
sameShape
)
{
// }
if
(
t_col
!=
output_col_0
)
sameShape
=
false
;
output
->
Resize
(
out_dim
);
}
input_col
+=
t_col
;
outputs_cols
.
data
[
i
+
1
]
=
input_col
;
outputs_data
.
data
[
i
]
=
outputs
[
i
].
data
<
T
>
();
}
// computation
const
int
kThreadsPerBlock
=
256
;
int
block_cols
=
std
::
min
(
input_col
,
kThreadsPerBlock
);
int
block_rows
=
std
::
max
(
kThreadsPerBlock
/
block_cols
,
1
);
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
grid_cols
=
(
input_col
+
block_cols
-
1
)
/
block_cols
;
int
grid_rows
=
(
input_row
+
block_rows
-
1
)
/
block_rows
;
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
if
(
sameShape
)
{
KernelConcatGrad
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
input_row
,
input_col
,
output_col_0
,
outputs_data
);
}
else
{
KernelConcatGrad
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
input_row
,
input_col
,
outputs_cols
,
outputs_data
);
}
}
}
};
};
...
@@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
...
@@ -149,6 +270,11 @@ template class ConcatFunctor<platform::CUDADeviceContext, int64_t>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
double
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
int
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
ConcatGradFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/math/concat.h
浏览文件 @
00e596ed
...
@@ -20,18 +20,23 @@ namespace operators {
...
@@ -20,18 +20,23 @@ namespace operators {
namespace
math
{
namespace
math
{
/*
/*
* the tensor's shape of input will be changed,
* so the second parameter is not const.
*
*
*/
*/
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatFunctor
{
class
ConcatFunctor
{
public:
public:
void
operator
()(
const
DeviceContext
&
context
,
void
operator
()(
const
DeviceContext
&
context
,
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
);
framework
::
Tensor
*
output
);
};
};
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatGradFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
>&
outputs
);
};
}
// namespace math
}
// namespace math
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/math/concat_test.cc
浏览文件 @
00e596ed
...
@@ -251,6 +251,80 @@ void testConcat() {
...
@@ -251,6 +251,80 @@ void testConcat() {
}
}
}
}
}
}
/**
* cast4:
* inputs:
* axis = 1
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 6, 4]
*/
dim_a
=
make_ddim
({
2
,
3
,
4
});
dim_b
=
make_ddim
({
2
,
3
,
4
});
dim_out
=
make_ddim
({
2
,
6
,
4
});
input_a
.
Resize
(
dim_a
);
input_b
.
Resize
(
dim_b
);
out
.
Resize
(
dim_out
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
Resize
(
dim_a
);
input_b_cpu
.
Resize
(
dim_b
);
out_cpu
.
Resize
(
dim_out
);
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
input
.
clear
();
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
concat_functor
(
*
context
,
input
,
1
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
// check the data
cols
=
12
;
idx_a
=
0
,
idx_b
=
0
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
for
(
int
j
=
0
;
j
<
24
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
24
+
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
24
+
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
}
}
}
TEST
(
math
,
concat
)
{
TEST
(
math
,
concat
)
{
...
...
paddle/fluid/platform/gpu_info.cc
浏览文件 @
00e596ed
...
@@ -33,6 +33,26 @@ int GetCUDADeviceCount() {
...
@@ -33,6 +33,26 @@ int GetCUDADeviceCount() {
return
count
;
return
count
;
}
}
int
GetCUDAMultiProcessors
(
int
id
)
{
PADDLE_ENFORCE_LT
(
id
,
GetCUDADeviceCount
(),
"id must less than GPU count"
);
int
count
;
PADDLE_ENFORCE
(
cudaDeviceGetAttribute
(
&
count
,
cudaDevAttrMultiProcessorCount
,
id
),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMultiProcessors"
);
return
count
;
}
int
GetCUDAMaxThreadsPerMultiProcessor
(
int
id
)
{
PADDLE_ENFORCE_LT
(
id
,
GetCUDADeviceCount
(),
"id must less than GPU count"
);
int
count
;
PADDLE_ENFORCE
(
cudaDeviceGetAttribute
(
&
count
,
cudaDevAttrMaxThreadsPerMultiProcessor
,
id
),
"cudaDeviceGetAttribute failed in "
"paddle::platform::GetCUDAMaxThreadsPerMultiProcessor"
);
return
count
;
}
int
GetCurrentDeviceId
()
{
int
GetCurrentDeviceId
()
{
int
device_id
;
int
device_id
;
PADDLE_ENFORCE
(
PADDLE_ENFORCE
(
...
...
paddle/fluid/platform/gpu_info.h
浏览文件 @
00e596ed
...
@@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse =
...
@@ -30,6 +30,12 @@ const std::string kEnvFractionGpuMemoryToUse =
//! Get the total number of GPU devices in system.
//! Get the total number of GPU devices in system.
int
GetCUDADeviceCount
();
int
GetCUDADeviceCount
();
//! Get the MultiProcessors of the ith GPU.
int
GetCUDAMultiProcessors
(
int
i
);
//! Get the MaxThreads of each MultiProcessor of the ith GPU.
int
GetCUDAMaxThreadsPerMultiProcessor
(
int
i
);
//! Get the current GPU device id in system.
//! Get the current GPU device id in system.
int
GetCurrentDeviceId
();
int
GetCurrentDeviceId
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录