Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
60e7ee06
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
60e7ee06
编写于
2月 28, 2018
作者:
C
chengduoZH
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine concat_op
上级
cf883d9c
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
559 addition
and
49 deletion
+559
-49
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-0
paddle/fluid/operators/concat_op.cc
paddle/fluid/operators/concat_op.cc
+5
-4
paddle/fluid/operators/concat_op.h
paddle/fluid/operators/concat_op.h
+8
-45
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+3
-0
paddle/fluid/operators/math/concat.cc
paddle/fluid/operators/math/concat.cc
+89
-0
paddle/fluid/operators/math/concat.cu
paddle/fluid/operators/math/concat.cu
+154
-0
paddle/fluid/operators/math/concat.h
paddle/fluid/operators/math/concat.h
+37
-0
paddle/fluid/operators/math/concat_test.cc
paddle/fluid/operators/math/concat_test.cc
+262
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
60e7ee06
...
@@ -184,6 +184,7 @@ op_library(save_op DEPS lod_tensor)
...
@@ -184,6 +184,7 @@ op_library(save_op DEPS lod_tensor)
op_library
(
load_op DEPS lod_tensor
)
op_library
(
load_op DEPS lod_tensor
)
op_library
(
save_combine_op DEPS lod_tensor
)
op_library
(
save_combine_op DEPS lod_tensor
)
op_library
(
load_combine_op DEPS lod_tensor
)
op_library
(
load_combine_op DEPS lod_tensor
)
op_library
(
concat_op DEPS concat_functor
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
paddle/fluid/operators/concat_op.cc
浏览文件 @
60e7ee06
...
@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
...
@@ -100,7 +100,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
namespace
ops
=
paddle
::
operators
;
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_EX
(
concat
,
ops
::
ConcatOp
,
ops
::
ConcatOpMaker
,
concat_grad
,
REGISTER_OP_EX
(
concat
,
ops
::
ConcatOp
,
ops
::
ConcatOpMaker
,
concat_grad
,
ops
::
ConcatOpGrad
,
false
)
ops
::
ConcatOpGrad
,
false
)
REGISTER_OP_CPU_KERNEL
(
concat
,
REGISTER_OP_CPU_KERNEL
(
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
concat
,
ops
::
ConcatKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
)
REGISTER_OP_CPU_KERNEL
(
concat_grad
,
REGISTER_OP_CPU_KERNEL
(
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
)
concat_grad
,
ops
::
ConcatGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
)
paddle/fluid/operators/concat_op.h
浏览文件 @
60e7ee06
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/strided_memcpy.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -27,55 +28,17 @@ class ConcatKernel : public framework::OpKernel<T> {
...
@@ -27,55 +28,17 @@ class ConcatKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
ins
=
ctx
.
MultiInput
<
framework
::
Tensor
>
(
"X"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
framework
::
Tensor
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
int64_t
axis
=
static_cast
<
int64_t
>
(
ctx
.
Attr
<
int
>
(
"axis"
));
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
out
->
mutable_data
<
T
>
(
place
);
std
::
vector
<
framework
::
Tensor
>
inputs
(
ins
.
size
());
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
for
(
size_t
j
=
0
;
j
<
ins
.
size
();
++
j
)
{
inputs
[
j
]
=
*
ins
[
j
];
size_t
output_offset
=
0
;
// If axis >=1, copy to out immediately need to call many times
// of cuda memcpy. Copy the input to cpu and do the stride copy,
// then copy to gpu output.
if
(
platform
::
is_gpu_place
(
place
)
&&
axis
>=
1
)
{
platform
::
CPUPlace
copy_place
;
auto
&
cpu_ctx
=
*
platform
::
DeviceContextPool
::
Instance
().
Get
(
copy_place
);
framework
::
Tensor
cpu_out
;
cpu_out
.
Resize
(
out
->
dims
());
cpu_out
.
mutable_data
<
T
>
(
copy_place
);
auto
&
dev_ctx
=
ctx
.
device_context
();
std
::
vector
<
std
::
unique_ptr
<
framework
::
Tensor
>>
cpu_ins
;
for
(
auto
*
in
:
ins
)
{
std
::
unique_ptr
<
framework
::
Tensor
>
cpu_in
(
new
framework
::
Tensor
);
framework
::
TensorCopy
(
*
in
,
copy_place
,
dev_ctx
,
cpu_in
.
get
());
cpu_ins
.
emplace_back
(
std
::
move
(
cpu_in
));
}
// TODO(dzhwinter): overlap copy and compute stream
// https://devblogs.nvidia.com/how-overlap-data-transfers-cuda-cc/
dev_ctx
.
Wait
();
for
(
auto
&
in
:
cpu_ins
)
{
auto
&
cpu_in
=
*
in
.
get
();
auto
in_stride
=
framework
::
stride_numel
(
cpu_in
.
dims
());
StridedNumelCopyWithAxis
<
T
>
(
cpu_ctx
,
axis
,
cpu_out
.
data
<
T
>
()
+
output_offset
,
out_stride
,
cpu_in
.
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
framework
::
TensorCopy
(
cpu_out
,
place
,
dev_ctx
,
out
);
}
else
{
for
(
auto
*
in
:
ins
)
{
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
ctx
.
device_context
(),
axis
,
out
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
->
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
paddle
::
operators
::
math
::
ConcatFunctor
<
DeviceContext
,
T
>
concat_functor
;
concat_functor
(
dev_ctx
,
inputs
,
static_cast
<
int
>
(
axis
),
out
);
}
}
};
};
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
60e7ee06
...
@@ -20,6 +20,7 @@ if(WITH_GPU)
...
@@ -20,6 +20,7 @@ if(WITH_GPU)
nv_library
(
unpooling SRCS unpooling.cc unpooling.cu DEPS device_context
)
nv_library
(
unpooling SRCS unpooling.cc unpooling.cu DEPS device_context
)
nv_library
(
gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function
)
nv_library
(
gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function
)
nv_library
(
cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context
)
nv_library
(
cos_sim_functor SRCS cos_sim_functor.cc cos_sim_functor.cu DEPS device_context
)
nv_library
(
concat_functor SRCS concat.cc concat.cu DEPS device_context tensor
)
else
()
else
()
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto
)
cc_library
(
math_function SRCS math_function.cc im2col.cc DEPS cblas device_context framework_proto
)
cc_library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
cc_library
(
selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function
)
...
@@ -37,6 +38,7 @@ else()
...
@@ -37,6 +38,7 @@ else()
cc_library
(
unpooling SRCS unpooling.cc DEPS device_context
)
cc_library
(
unpooling SRCS unpooling.cc DEPS device_context
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function
)
cc_library
(
cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context
)
cc_library
(
cos_sim_functor SRCS cos_sim_functor.cc DEPS device_context
)
cc_library
(
concat_functor SRCS concat.cc DEPS device_context tensor
)
endif
()
endif
()
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function tensor
)
...
@@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
...
@@ -44,3 +46,4 @@ cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selec
cc_test
(
im2col_test SRCS im2col_test.cc DEPS math_function tensor
)
cc_test
(
im2col_test SRCS im2col_test.cc DEPS math_function tensor
)
cc_test
(
vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor
)
cc_test
(
vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor
)
cc_test
(
sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding
)
cc_test
(
sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding
)
cc_test
(
concat_test SRCS concat_test.cc DEPS concat_functor tensor
)
paddle/fluid/operators/math/concat.cc
0 → 100644
浏览文件 @
60e7ee06
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/*
* All tensors' dimension should be the same.
*/
template
<
typename
T
>
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int
num
=
input
.
size
();
std
::
vector
<
paddle
::
framework
::
DDim
>
origin_dim
(
num
);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto
out_dim
=
output
->
dims
();
// get the matrix size
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
}
int
cols
=
input
[
0
].
numel
()
/
rows
;
int
out_rows
=
rows
,
out_cols
=
0
;
bool
sameShape
=
true
;
// reshape to matrix
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
if
(
sameShape
)
{
if
(
t_cols
!=
cols
)
sameShape
=
false
;
}
out_cols
+=
t_cols
;
input
[
i
].
Resize
({
rows
,
t_cols
});
}
output
->
Resize
({
out_rows
,
out_cols
});
auto
&
cpu_place
=
boost
::
get
<
platform
::
CPUPlace
>
(
context
.
GetPlace
());
// computation
for
(
int
k
=
0
;
k
<
rows
;
++
k
)
{
// offset k * out_cols
T
*
dst_ptr
=
output
->
data
<
T
>
()
+
k
*
out_cols
;
int
col_idx
=
0
;
for
(
int
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
input
[
j
].
dims
()[
1
];
const
T
*
src_prt
=
input
[
j
].
data
<
T
>
()
+
k
*
col_len
;
memory
::
Copy
(
cpu_place
,
dst_ptr
+
col_idx
,
cpu_place
,
src_prt
,
sizeof
(
T
)
*
col_len
);
col_idx
+=
col_len
;
}
}
// recover origin dim
// for (int j = 0; j < num; ++j) {
// input[j]->Resize(origin_dim[j]);
// }
output
->
Resize
(
out_dim
);
}
};
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
int
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
int64_t
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CPUDeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/concat.cu
0 → 100644
浏览文件 @
60e7ee06
/* Copyright (c) 2018 paddlepaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include "paddle/fluid/platform/cuda_helper.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
// TODO(zcd): This can be replaced by tensor,
// if that, maybe we should add int8 to VarType::Type.
// Or replaced by tensorArray.
static
constexpr
int
MaxSize
=
32
;
template
<
typename
T
>
struct
CUDADeviceArray
{
T
data
[
MaxSize
];
int
size
;
};
template
<
typename
T
>
__device__
T
upper_bound
(
const
T
*
first
,
T
count
,
T
val
)
{
const
T
*
orig
=
first
;
const
T
*
it
=
nullptr
;
T
step
=
0
;
while
(
count
>
0
)
{
it
=
first
;
step
=
count
/
2
;
it
+=
step
;
if
(
!
(
val
<
*
it
))
{
first
=
++
it
;
count
-=
step
+
1
;
}
else
{
count
=
step
;
}
}
return
first
-
orig
;
}
template
<
typename
T
>
__global__
void
KernelConcat
(
const
CUDADeviceArray
<
const
T
*>
inputs
,
const
CUDADeviceArray
<
int
>
input_cols
,
const
int
output_rows
,
const
int
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
int
segment
=
upper_bound
<
int
>
(
input_cols
.
data
,
input_cols
.
size
,
tid_x
)
-
1
;
int
curr_offset
=
input_cols
.
data
[
segment
];
int
curr_segment
=
segment
;
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
curr_col_offset
;
while
((
curr_col_offset
=
input_cols
.
data
[
curr_segment
+
1
])
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
const
T
*
input_ptr
=
inputs
.
data
[
curr_segment
];
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
segment_width
+
local_col
];
}
}
/*
* All tensors' dimension should be the same.
*/
template
<
typename
T
>
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
T
>
{
public:
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
)
{
// assume the the max size of input is less than 8 and see the performance
// save origin dim
int
num
=
input
.
size
();
// std::vector<paddle::framework::DDim> origin_dim(num);
// for (int j = 0; j < num; ++j) {
// origin_dim[j] = input[j].dims();
// }
auto
out_dim
=
output
->
dims
();
// get the matrix size
int
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
}
int
cols
=
input
[
0
].
numel
()
/
rows
;
int
out_rows
=
rows
,
out_cols
=
0
;
bool
sameShape
=
true
;
CUDADeviceArray
<
const
T
*>
inputs_data
;
CUDADeviceArray
<
int
>
inputs_cols
;
inputs_data
.
size
=
num
;
inputs_cols
.
size
=
num
+
1
;
inputs_cols
.
data
[
0
]
=
0
;
// reshape to matrix
// check input shape is valid
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
input
[
i
].
numel
()
/
rows
;
if
(
sameShape
)
{
if
(
t_cols
!=
cols
)
sameShape
=
false
;
}
out_cols
+=
t_cols
;
input
[
i
].
Resize
({
rows
,
t_cols
});
inputs_cols
.
data
[
i
+
1
]
=
out_cols
;
inputs_data
.
data
[
i
]
=
input
[
i
].
data
<
T
>
();
}
output
->
Resize
({
out_rows
,
out_cols
});
// computation
const
int
kThreadsPerBlock
=
256
;
int
block_cols
=
std
::
min
(
out_cols
,
kThreadsPerBlock
);
int
block_rows
=
std
::
max
(
kThreadsPerBlock
/
block_cols
,
1
);
dim3
block_size
=
dim3
(
block_cols
,
block_rows
,
1
);
int
grid_cols
=
(
out_cols
+
block_cols
-
1
)
/
block_cols
;
int
grid_rows
=
(
out_rows
+
block_rows
-
1
)
/
block_rows
;
dim3
grid_size
=
dim3
(
grid_cols
,
grid_rows
,
1
);
KernelConcat
<<<
grid_size
,
block_size
,
0
,
context
.
stream
()
>>>
(
inputs_data
,
inputs_cols
,
out_rows
,
out_cols
,
output
->
data
<
T
>
());
// recover origin dim
// for (int j = 0; j < num; ++j) {
// input[j].Resize(origin_dim[j]);
// }
output
->
Resize
(
out_dim
);
}
};
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
int
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
int64_t
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
float
>;
template
class
ConcatFunctor
<
platform
::
CUDADeviceContext
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/concat.h
0 → 100644
浏览文件 @
60e7ee06
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/tensor.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
/*
* the tensor's shape of input will be changed,
* so the second parameter is not const.
*
*/
template
<
typename
DeviceContext
,
typename
T
>
class
ConcatFunctor
{
public:
void
operator
()(
const
DeviceContext
&
context
,
std
::
vector
<
framework
::
Tensor
>&
input
,
const
int
axis
,
framework
::
Tensor
*
output
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/concat_test.cc
0 → 100644
浏览文件 @
60e7ee06
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
template
<
typename
DeviceContext
,
typename
Place
>
void
testConcat
()
{
Tensor
input_a_cpu
;
Tensor
input_b_cpu
;
Tensor
out_cpu
;
Tensor
input_a
;
Tensor
input_b
;
Tensor
out
;
DeviceContext
*
context
=
new
DeviceContext
(
Place
());
// DeviceContext context(Place());
/**
* cast1:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [3, 3, 4]
* output:
* out.shape: [5, 3, 4]
*/
auto
dim_a
=
make_ddim
({
2
,
3
,
4
});
auto
dim_b
=
make_ddim
({
3
,
3
,
4
});
auto
dim_out
=
make_ddim
({
5
,
3
,
4
});
input_a
.
mutable_data
<
int
>
(
dim_a
,
Place
());
input_b
.
mutable_data
<
int
>
(
dim_b
,
Place
());
out
.
mutable_data
<
int
>
(
dim_out
,
Place
());
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
mutable_data
<
int
>
(
dim_a
,
CPUPlace
());
input_b_cpu
.
mutable_data
<
int
>
(
dim_b
,
CPUPlace
());
out_cpu
.
mutable_data
<
int
>
(
dim_out
,
CPUPlace
());
}
int
*
a_ptr
;
int
*
b_ptr
;
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
3
*
3
*
4
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
std
::
vector
<
Tensor
>
input
;
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
paddle
::
operators
::
math
::
ConcatFunctor
<
DeviceContext
,
int
>
concat_functor
;
concat_functor
(
*
context
,
input
,
0
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
int
*
out_ptr
;
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
int
cols
=
2
*
3
*
4
;
int
idx_a
=
0
,
idx_b
=
0
;
for
(
int
j
=
0
;
j
<
5
*
3
*
4
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
//
/**
* cast2:
* inputs:
* t_a.shape: [2, 3, 4]
* t_b.shape: [2, 4, 4]
* output:
* out.shape: [2, 7, 4]
*/
dim_a
=
make_ddim
({
2
,
3
,
4
});
dim_b
=
make_ddim
({
2
,
4
,
4
});
dim_out
=
make_ddim
({
2
,
7
,
4
});
input_a
.
Resize
(
dim_a
);
input_b
.
Resize
(
dim_b
);
out
.
Resize
(
dim_out
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
Resize
(
dim_a
);
input_b_cpu
.
Resize
(
dim_b
);
out_cpu
.
Resize
(
dim_out
);
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
2
*
4
*
4
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
input
.
clear
();
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
concat_functor
(
*
context
,
input
,
1
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
cols
=
3
*
4
;
idx_a
=
0
,
idx_b
=
0
;
for
(
int
i
=
0
;
i
<
2
;
++
i
)
{
for
(
int
j
=
0
;
j
<
28
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
28
+
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
28
+
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
}
/**
* cast3:
* inputs:
* t_a.shape: [2, 3, 5]
* t_b.shape: [2, 3, 4]
* output:
* out.shape: [2, 3, 9]
*/
dim_a
=
make_ddim
({
2
,
3
,
4
});
dim_b
=
make_ddim
({
2
,
3
,
5
});
dim_out
=
make_ddim
({
2
,
3
,
9
});
input_a
.
Resize
(
dim_a
);
input_b
.
Resize
(
dim_b
);
out
.
Resize
(
dim_out
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
input_a_cpu
.
Resize
(
dim_a
);
input_b_cpu
.
Resize
(
dim_b
);
out_cpu
.
Resize
(
dim_out
);
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
a_ptr
=
input_a_cpu
.
data
<
int
>
();
b_ptr
=
input_b_cpu
.
data
<
int
>
();
}
else
{
a_ptr
=
input_a
.
data
<
int
>
();
b_ptr
=
input_b
.
data
<
int
>
();
}
for
(
int
i
=
0
;
i
<
2
*
3
*
4
;
++
i
)
{
a_ptr
[
i
]
=
i
;
}
for
(
int
i
=
0
;
i
<
2
*
3
*
5
;
++
i
)
{
b_ptr
[
i
]
=
i
;
}
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
input_a_cpu
,
Place
(),
*
context
,
&
input_a
);
TensorCopy
(
input_b_cpu
,
Place
(),
*
context
,
&
input_b
);
}
input
.
clear
();
input
.
push_back
(
input_a
);
input
.
push_back
(
input_b
);
concat_functor
(
*
context
,
input
,
2
,
&
out
);
// check the dim of input_a, input_b
PADDLE_ENFORCE_EQ
(
input_a
.
dims
(),
dim_a
);
PADDLE_ENFORCE_EQ
(
input_b
.
dims
(),
dim_b
);
if
(
paddle
::
platform
::
is_gpu_place
(
Place
()))
{
TensorCopy
(
out
,
CPUPlace
(),
*
context
,
&
out_cpu
);
out_ptr
=
out_cpu
.
data
<
int
>
();
}
else
{
out_ptr
=
out
.
data
<
int
>
();
}
// check the data
cols
=
4
;
idx_a
=
0
,
idx_b
=
0
;
for
(
int
i
=
0
;
i
<
6
;
++
i
)
{
for
(
int
j
=
0
;
j
<
9
;
++
j
)
{
if
(
j
>=
cols
)
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
9
+
j
],
b_ptr
[
idx_b
]);
++
idx_b
;
}
else
{
PADDLE_ENFORCE_EQ
(
out_ptr
[
i
*
9
+
j
],
a_ptr
[
idx_a
]);
++
idx_a
;
}
}
}
}
TEST
(
math
,
concat
)
{
testConcat
<
paddle
::
platform
::
CPUDeviceContext
,
paddle
::
platform
::
CPUPlace
>
();
#ifdef PADDLE_WITH_CUDA
testConcat
<
paddle
::
platform
::
CUDADeviceContext
,
paddle
::
platform
::
CUDAPlace
>
();
#endif
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录