Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
06803c29
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
06803c29
编写于
1月 21, 2022
作者:
C
chentianyu03
提交者:
GitHub
1月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[pten] add concat pten kernel (#38955)
上级
814e5ab4
变更
39
展开全部
隐藏空白更改
内联
并排
Showing
39 changed file
with
1552 addition
and
671 deletion
+1552
-671
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+3
-39
paddle/fluid/framework/lod_tensor.h
paddle/fluid/framework/lod_tensor.h
+0
-14
paddle/fluid/framework/lod_tensor_test.cc
paddle/fluid/framework/lod_tensor_test.cc
+3
-2
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+4
-0
paddle/fluid/imperative/prepared_operator.cc
paddle/fluid/imperative/prepared_operator.cc
+4
-0
paddle/fluid/operators/array_to_lod_tensor_op.cc
paddle/fluid/operators/array_to_lod_tensor_op.cc
+2
-1
paddle/fluid/operators/concat_op.cc
paddle/fluid/operators/concat_op.cc
+13
-2
paddle/fluid/operators/concat_op.h
paddle/fluid/operators/concat_op.h
+11
-100
paddle/fluid/operators/concat_op_xpu.cc
paddle/fluid/operators/concat_op_xpu.cc
+4
-2
paddle/fluid/operators/lod_tensor_to_array_op.cc
paddle/fluid/operators/lod_tensor_to_array_op.cc
+2
-1
paddle/fluid/operators/math/concat_and_split.cc
paddle/fluid/operators/math/concat_and_split.cc
+11
-70
paddle/fluid/operators/math/concat_and_split.cu
paddle/fluid/operators/math/concat_and_split.cu
+10
-425
paddle/fluid/operators/merge_lod_tensor_op.cc
paddle/fluid/operators/merge_lod_tensor_op.cc
+3
-1
paddle/fluid/operators/shrink_rnn_memory_op.cc
paddle/fluid/operators/shrink_rnn_memory_op.cc
+3
-1
paddle/fluid/operators/split_lod_tensor_op.cc
paddle/fluid/operators/split_lod_tensor_op.cc
+2
-1
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+2
-2
paddle/pten/CMakeLists.txt
paddle/pten/CMakeLists.txt
+1
-1
paddle/pten/api/include/kernel_signature.h
paddle/pten/api/include/kernel_signature.h
+5
-0
paddle/pten/api/lib/utils/tensor_utils.cc
paddle/pten/api/lib/utils/tensor_utils.cc
+8
-1
paddle/pten/core/CMakeLists.txt
paddle/pten/core/CMakeLists.txt
+1
-0
paddle/pten/core/kernel_context.h
paddle/pten/core/kernel_context.h
+1
-1
paddle/pten/core/lod_utils.cc
paddle/pten/core/lod_utils.cc
+59
-0
paddle/pten/core/lod_utils.h
paddle/pten/core/lod_utils.h
+37
-0
paddle/pten/infermeta/multiary.cc
paddle/pten/infermeta/multiary.cc
+40
-1
paddle/pten/infermeta/multiary.h
paddle/pten/infermeta/multiary.h
+10
-1
paddle/pten/kernels/CMakeLists.txt
paddle/pten/kernels/CMakeLists.txt
+1
-1
paddle/pten/kernels/concat_kernel.h
paddle/pten/kernels/concat_kernel.h
+43
-0
paddle/pten/kernels/cpu/concat_and_split.h
paddle/pten/kernels/cpu/concat_and_split.h
+138
-0
paddle/pten/kernels/cpu/concat_kernel.cc
paddle/pten/kernels/cpu/concat_kernel.cc
+125
-0
paddle/pten/kernels/funcs/concat_funcs.h
paddle/pten/kernels/funcs/concat_funcs.h
+95
-0
paddle/pten/kernels/gpu/concat_and_split.h
paddle/pten/kernels/gpu/concat_and_split.h
+569
-0
paddle/pten/kernels/gpu/concat_kernel.cu
paddle/pten/kernels/gpu/concat_kernel.cu
+125
-0
paddle/pten/tests/api/CMakeLists.txt
paddle/pten/tests/api/CMakeLists.txt
+1
-0
paddle/pten/tests/api/test_concat_api.cc
paddle/pten/tests/api/test_concat_api.cc
+86
-0
paddle/pten/tests/kernels/CMakeLists.txt
paddle/pten/tests/kernels/CMakeLists.txt
+1
-0
paddle/pten/tests/kernels/test_concat_dev_api.cc
paddle/pten/tests/kernels/test_concat_dev_api.cc
+82
-0
python/paddle/utils/code_gen/api.yaml
python/paddle/utils/code_gen/api.yaml
+10
-0
python/paddle/utils/code_gen/api_gen.py
python/paddle/utils/code_gen/api_gen.py
+36
-3
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
06803c29
...
...
@@ -94,7 +94,7 @@ else()
endif
()
cc_library
(
lod_tensor SRCS lod_tensor.cc DEPS ddim mixed_vector place tensor framework_proto version
)
cc_test
(
lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory
)
cc_test
(
lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_
utils lod_
tensor memory
)
if
(
WITH_GPU
)
nv_test
(
lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor
)
...
...
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
06803c29
...
...
@@ -117,7 +117,8 @@ bool CheckLoD(const LoD &in, int tensor_height) {
}
// check: the lowest level's last offset should equals `tensor_height` if
// tensor_height>0.
if
(
tensor_height
>
0
&&
(
size_t
)
tensor_height
!=
in
.
back
().
back
())
if
(
tensor_height
>
0
&&
static_cast
<
size_t
>
(
tensor_height
)
!=
in
.
back
().
back
())
return
false
;
// check: the higher level's last offset should equals the lower level's
...
...
@@ -150,7 +151,7 @@ bool CheckAbsLoD(const LoD &in, int tensor_height) {
if
(
level
.
front
()
!=
0
)
return
false
;
if
(
tensor_height
<
0
)
{
tensor_height
=
level
.
back
();
}
else
if
(
(
size_t
)
tensor_height
!=
level
.
back
())
{
}
else
if
(
static_cast
<
size_t
>
(
tensor_height
)
!=
level
.
back
())
{
return
false
;
}
}
...
...
@@ -186,27 +187,6 @@ LoDAndOffset GetSubLoDAndAbsoluteOffset(const LoD &lod, size_t start_idx,
return
LoDAndOffset
{
sub_lod
,
{
start_idx
,
end_idx
}};
}
void
AppendLoD
(
LoD
*
lod
,
const
LoD
&
lod_length
)
{
PADDLE_ENFORCE
(
lod
->
empty
()
||
lod
->
size
()
==
lod_length
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The input LoD length should be equal to the appended LoD size, but "
"received input LoD length is %d, actual LoD size is %d."
,
lod_length
,
lod
->
size
()));
if
(
lod
->
empty
())
{
for
(
size_t
i
=
0
;
i
<
lod_length
.
size
();
++
i
)
{
lod
->
emplace_back
(
1
,
0
);
// size = 1, value = 0;
}
*
lod
=
LoD
(
lod_length
.
size
(),
std
::
vector
<
size_t
>
({
0
}));
}
for
(
size_t
i
=
0
;
i
<
lod
->
size
();
++
i
)
{
auto
&
level
=
(
*
lod
)[
i
];
for
(
size_t
len
:
lod_length
[
i
])
{
level
.
push_back
(
level
.
back
()
+
len
);
}
}
}
void
SerializeToStream
(
std
::
ostream
&
os
,
const
LoDTensor
&
tensor
,
const
platform
::
DeviceContext
&
dev_ctx
)
{
{
// the 1st field, uint32_t version for LoDTensor
...
...
@@ -313,22 +293,6 @@ void DeserializeFromStream(std::istream &is, LoDTensor *tensor,
TensorFromStream
(
is
,
static_cast
<
Tensor
*>
(
tensor
),
dev_ctx
);
}
LoD
ConvertToLengthBasedLoD
(
const
LoD
&
offset_lod
)
{
LoD
length_lod
;
length_lod
.
reserve
(
offset_lod
.
size
());
for
(
size_t
lvl
=
0
;
lvl
<
offset_lod
.
size
();
++
lvl
)
{
std
::
vector
<
size_t
>
level
;
if
(
offset_lod
[
lvl
].
size
()
>
0
)
{
level
.
reserve
(
offset_lod
[
lvl
].
size
()
-
1
);
}
for
(
size_t
idx
=
0
;
idx
<
offset_lod
[
lvl
].
size
()
-
1
;
++
idx
)
{
level
.
push_back
(
offset_lod
[
lvl
][
idx
+
1
]
-
offset_lod
[
lvl
][
idx
]);
}
length_lod
.
push_back
(
level
);
}
return
length_lod
;
}
LoD
ConvertToOffsetBasedLoD
(
const
LoD
&
length_lod
)
{
LoD
offset_lod
;
offset_lod
.
reserve
(
length_lod
.
size
());
...
...
paddle/fluid/framework/lod_tensor.h
浏览文件 @
06803c29
...
...
@@ -157,8 +157,6 @@ LoDTensor LodExpand(const LoDTensor& source, const LoD& lod, size_t level,
std
::
pair
<
LoD
,
std
::
pair
<
size_t
,
size_t
>>
GetSubLoDAndAbsoluteOffset
(
const
LoD
&
lod
,
size_t
start_idx
,
size_t
end_idx
,
size_t
start_level
);
void
AppendLoD
(
LoD
*
lod
,
const
LoD
&
lod_length
);
/*
* Serialize/Desiralize LoDTensor to std::ostream
* You can pass ofstream or ostringstream to serilize to file
...
...
@@ -173,18 +171,6 @@ void DeserializeFromStream(std::istream& is, LoDTensor* tensor,
const
size_t
&
seek
,
const
std
::
vector
<
int64_t
>&
shape
);
/*
* Convert between length-based LoD and offset-based LoD.
* The implementation of LoDTensor class use offset-based LoD.
* However, we want to expose the more user-friendly length-based
* LoD to the Python side instead.
*
* Example:
* If offset_lod = [[0, 2, 3],[0, 3, 5, 9]]
* then length_lod = [[2, 1], [3, 2, 4]]
*/
LoD
ConvertToLengthBasedLoD
(
const
LoD
&
offset_lod
);
LoD
ConvertToOffsetBasedLoD
(
const
LoD
&
length_lod
);
void
SerializeToStream
(
std
::
ostream
&
os
,
const
LoDTensor
&
tensor
);
...
...
paddle/fluid/framework/lod_tensor_test.cc
浏览文件 @
06803c29
...
...
@@ -16,6 +16,7 @@
#include <gtest/gtest.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/pten/core/lod_utils.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -98,7 +99,7 @@ TEST(LoD, AppendLoD) {
origin
.
push_back
(
std
::
vector
<
size_t
>
({
0
,
1
,
6
}));
origin
.
push_back
(
std
::
vector
<
size_t
>
({
0
,
2
,
5
,
7
,
10
,
12
,
15
}));
p
addle
::
framework
::
AppendLoD
(
&
origin
,
lod_lens
);
p
ten
::
AppendLoD
(
&
origin
,
lod_lens
);
LoD
expected
;
expected
.
push_back
(
std
::
vector
<
size_t
>
({
0
,
2
,
4
}));
...
...
@@ -277,7 +278,7 @@ TEST(LoD, ConvertToLengthBasedLoD) {
offset_lod
.
push_back
(
std
::
vector
<
size_t
>
({
0
,
1
,
3
}));
offset_lod
.
push_back
(
std
::
vector
<
size_t
>
({
0
,
2
,
4
,
5
}));
LoD
length_lod
=
ConvertToLengthBasedLoD
(
offset_lod
);
LoD
length_lod
=
pten
::
ConvertToLengthBasedLoD
(
offset_lod
);
LoD
expected
;
expected
.
push_back
(
std
::
vector
<
size_t
>
({
2
}));
...
...
paddle/fluid/framework/operator.cc
浏览文件 @
06803c29
...
...
@@ -1978,6 +1978,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std
::
type_index
(
typeid
(
std
::
string
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
pten
::
Scalar
(
BOOST_GET_CONST
(
std
::
string
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
pt_kernel_context
->
EmplaceBackAttr
(
std
::
move
(
pten
::
Scalar
(
BOOST_GET_CONST
(
int
,
attr
))));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported cast op attribute `%s` to Scalar when construct "
...
...
paddle/fluid/imperative/prepared_operator.cc
浏览文件 @
06803c29
...
...
@@ -438,6 +438,10 @@ static void BuildDygraphPtenKernelContext(
std
::
type_index
(
typeid
(
std
::
string
)))
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
pten
::
Scalar
(
BOOST_GET_CONST
(
std
::
string
,
attr
))));
}
else
if
(
std
::
type_index
(
attr
.
type
())
==
std
::
type_index
(
typeid
(
int
)))
{
kernel_ctx
->
EmplaceBackAttr
(
std
::
move
(
pten
::
Scalar
(
BOOST_GET_CONST
(
int
,
attr
))));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported cast op attribute `%s` to Scalar when construct "
...
...
paddle/fluid/operators/array_to_lod_tensor_op.cc
浏览文件 @
06803c29
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include <paddle/fluid/operators/math/concat_and_split.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/core/lod_utils.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -168,7 +169,7 @@ class ArrayToLoDTensorOp : public framework::OperatorBase {
x
[
x_idx
].
lod
(),
idx
,
idx
+
1
,
0
);
auto
&
lod_length
=
lod_and_offset
.
first
;
framework
::
AppendLoD
(
out_lod
,
lod_length
);
pten
::
AppendLoD
(
out_lod
,
lod_length
);
size_t
start_offset
=
lod_and_offset
.
second
.
first
;
size_t
end_offset
=
lod_and_offset
.
second
.
second
;
...
...
paddle/fluid/operators/concat_op.cc
浏览文件 @
06803c29
...
...
@@ -19,6 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/pten/kernels/funcs/concat_funcs.h"
#ifdef PADDLE_WITH_MKLDNN
#include <paddle/fluid/platform/mkldnn_helper.h>
#endif
...
...
@@ -56,8 +58,8 @@ class ConcatOp : public framework::OperatorWithKernel {
size_t
axis
=
ComputeAxis
(
static_cast
<
int64_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
)),
static_cast
<
int64_t
>
(
inputs_dims
[
0
].
size
()));
framework
::
DDim
out_dims
=
ComputeAndCheckShape
(
ctx
->
IsRuntime
(),
inputs_dims
,
axis
);
framework
::
DDim
out_dims
=
pten
::
funcs
::
ComputeAndCheckShape
(
ctx
->
IsRuntime
(),
inputs_dims
,
axis
);
if
(
out_dims
[
axis
]
<
0
)
{
out_dims
[
axis
]
=
-
1
;
}
...
...
@@ -102,6 +104,15 @@ class ConcatOp : public framework::OperatorWithKernel {
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
framework
::
KernelSignature
GetExpectedPtenKernelArgs
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
ctx
.
HasInput
(
"AxisTensor"
))
{
return
framework
::
KernelSignature
(
"concat"
,
{
"X"
},
{
"AxisTensor"
},
{
"Out"
});
}
return
framework
::
KernelSignature
(
"concat"
,
{
"X"
},
{
"axis"
},
{
"Out"
});
}
};
class
ConcatOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
...
...
paddle/fluid/operators/concat_op.h
浏览文件 @
06803c29
...
...
@@ -22,54 +22,11 @@ limitations under the License. */
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/pten/kernels/concat_kernel.h"
#include "paddle/pten/kernels/funcs/concat_funcs.h"
namespace
paddle
{
namespace
operators
{
static
inline
framework
::
DDim
ComputeAndCheckShape
(
const
bool
is_runtime
,
const
std
::
vector
<
framework
::
DDim
>&
inputs_dims
,
const
size_t
axis
)
{
const
size_t
n
=
inputs_dims
.
size
();
auto
out_dims
=
inputs_dims
[
0
];
size_t
in_zero_dims_size
=
out_dims
.
size
();
for
(
size_t
i
=
1
;
i
<
n
;
i
++
)
{
PADDLE_ENFORCE_EQ
(
inputs_dims
[
i
].
size
(),
out_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The shape of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s]."
,
i
,
inputs_dims
[
0
],
i
,
inputs_dims
[
i
]));
for
(
size_t
j
=
0
;
j
<
in_zero_dims_size
;
j
++
)
{
if
(
j
==
axis
)
{
if
(
is_runtime
)
{
out_dims
[
axis
]
+=
inputs_dims
[
i
][
j
];
}
else
{
if
(
inputs_dims
[
i
][
j
]
==
-
1
||
out_dims
[
j
]
==
-
1
)
{
out_dims
[
axis
]
=
-
1
;
}
else
{
out_dims
[
axis
]
+=
inputs_dims
[
i
][
j
];
}
}
}
else
{
bool
check_shape
=
is_runtime
||
(
inputs_dims
[
0
][
j
]
>
0
&&
inputs_dims
[
i
][
j
]
>
0
);
if
(
check_shape
)
{
// check all shape in run time
PADDLE_ENFORCE_EQ
(
inputs_dims
[
0
][
j
],
inputs_dims
[
i
][
j
],
platform
::
errors
::
InvalidArgument
(
"The %d-th dimension of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s]."
,
j
,
i
,
inputs_dims
[
0
],
i
,
inputs_dims
[
i
]));
}
if
(
!
is_runtime
&&
out_dims
[
j
]
==
-
1
&&
inputs_dims
[
i
][
j
]
>
0
)
{
out_dims
[
j
]
=
inputs_dims
[
i
][
j
];
}
}
}
}
return
out_dims
;
}
static
inline
int64_t
ComputeAxis
(
int64_t
axis
,
int64_t
rank
)
{
PADDLE_ENFORCE_EQ
(
...
...
@@ -109,67 +66,21 @@ class ConcatKernel : public framework::OpKernel<T> {
ins_dims
[
i
]
=
ins
[
i
]
->
dims
();
}
framework
::
DDim
out_dims
=
ComputeAndCheckShape
(
true
,
ins_dims
,
axis
);
framework
::
DDim
out_dims
=
pten
::
funcs
::
ComputeAndCheckShape
(
true
,
ins_dims
,
axis
);
out
->
Resize
(
out_dims
);
}
auto
place
=
ctx
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
// If axis is 0, the lod of the output is not the same as inputs.
if
(
axis
==
0
&&
ins
[
0
]
->
lod
().
size
()
>
0
)
{
size_t
lod_size_0
=
ins
[
0
]
->
lod
().
size
();
size_t
lod_size
=
lod_size_0
;
for
(
size_t
i
=
1
;
i
<
ins
.
size
();
++
i
)
{
if
(
ins
[
i
]
->
lod
().
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
ins
[
i
]
->
lod
().
size
(),
lod_size_0
,
platform
::
errors
::
Unimplemented
(
"The lod level of all input LoDTensors should be same. "
"Maybe different lod level of input LoDTensors can concat,"
"it is not supported currently. The lod level of %dth input "
"is %d and first input is %d."
,
i
,
ins
[
i
]
->
lod
().
size
(),
lod_size_0
));
}
else
{
lod_size
=
0
;
break
;
}
}
if
(
lod_size
)
{
auto
*
out_lod
=
out
->
mutable_lod
();
for
(
size_t
i
=
1
;
i
<
ins
.
size
();
++
i
)
{
auto
in_lod
=
ConvertToLengthBasedLoD
(
ins
[
i
]
->
lod
());
AppendLoD
(
out_lod
,
in_lod
);
}
}
// call new kernel
auto
&
dev_ctx
=
ctx
.
device_context
<
DeviceContext
>
();
std
::
vector
<
pten
::
DenseTensor
>
pt_ins
;
for
(
auto
&
in
:
ins
)
{
pt_ins
.
push_back
(
*
in
);
}
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if
(
axis
==
0
&&
ins
.
size
()
<
10
)
{
size_t
output_offset
=
0
;
for
(
auto
*
in
:
ins
)
{
if
(
!
in
||
in
->
numel
()
==
0UL
)
{
continue
;
}
auto
in_stride
=
framework
::
stride_numel
(
in
->
dims
());
auto
out_stride
=
framework
::
stride_numel
(
out
->
dims
());
StridedNumelCopyWithAxis
<
T
>
(
ctx
.
device_context
(),
axis
,
out
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
->
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
}
else
{
std
::
vector
<
framework
::
Tensor
>
inputs
;
for
(
size_t
j
=
0
;
j
<
ins
.
size
();
++
j
)
{
if
(
ins
[
j
]
&&
ins
[
j
]
->
numel
()
>
0
)
{
inputs
.
push_back
(
*
ins
[
j
]);
}
else
{
continue
;
}
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
paddle
::
operators
::
math
::
ConcatFunctor
<
DeviceContext
,
T
>
concat_functor
;
concat_functor
(
dev_ctx
,
inputs
,
static_cast
<
int
>
(
axis
),
out
);
}
pten
::
ConcatKernel
<
T
>
(
dev_ctx
,
pt_ins
,
axis
,
out
);
}
};
...
...
paddle/fluid/operators/concat_op_xpu.cc
浏览文件 @
06803c29
...
...
@@ -18,6 +18,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/pten/core/lod_utils.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
...
...
@@ -69,8 +71,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
if
(
lod_size
)
{
auto
*
out_lod
=
out
->
mutable_lod
();
for
(
size_t
i
=
1
;
i
<
ins
.
size
();
++
i
)
{
auto
in_lod
=
ConvertToLengthBasedLoD
(
ins
[
i
]
->
lod
());
AppendLoD
(
out_lod
,
in_lod
);
auto
in_lod
=
pten
::
ConvertToLengthBasedLoD
(
ins
[
i
]
->
lod
());
pten
::
AppendLoD
(
out_lod
,
in_lod
);
}
}
}
...
...
paddle/fluid/operators/lod_tensor_to_array_op.cc
浏览文件 @
06803c29
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/core/lod_utils.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -134,7 +135,7 @@ class LoDTensorToArrayOp : public framework::OperatorBase {
auto
lod_and_offset
=
framework
::
GetSubLoDAndAbsoluteOffset
(
x
.
lod
(),
start_idx
,
start_idx
+
1
,
rank_level
+
1
);
auto
&
lod_length
=
lod_and_offset
.
first
;
framework
::
AppendLoD
(
&
lod
,
lod_length
);
pten
::
AppendLoD
(
&
lod
,
lod_length
);
size_t
start_offset
=
lod_and_offset
.
second
.
first
;
size_t
end_offset
=
lod_and_offset
.
second
.
second
;
copy_ranges
[
t
].
emplace_back
(
CopyRange
{
start_offset
,
end_offset
});
...
...
paddle/fluid/operators/math/concat_and_split.cc
浏览文件 @
06803c29
...
...
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/pten/kernels/cpu/concat_and_split.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
...
...
@@ -44,36 +46,9 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
int
axis
,
framework
::
Tensor
*
output
)
{
// TODO(zcd): Add input data validity checking
size_t
num
=
input
.
size
();
int64_t
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
}
int64_t
out_rows
=
rows
,
out_cols
=
0
;
std
::
vector
<
int64_t
>
input_cols
(
input
.
size
());
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
int64_t
t_cols
=
input
[
i
].
numel
()
/
rows
;
out_cols
+=
t_cols
;
input_cols
[
i
]
=
t_cols
;
}
auto
cpu_place
=
context
.
GetPlace
();
// computation
auto
output_data
=
output
->
data
<
T
>
();
int64_t
col_idx
=
0
;
for
(
size_t
j
=
0
;
j
<
num
;
++
j
)
{
int64_t
col_len
=
input_cols
[
j
];
auto
input_data
=
input
[
j
].
data
<
T
>
();
for
(
int64_t
k
=
0
;
k
<
out_rows
;
++
k
)
{
memory
::
Copy
(
cpu_place
,
output_data
+
k
*
out_cols
+
col_idx
,
cpu_place
,
input_data
+
k
*
col_len
,
sizeof
(
T
)
*
col_len
);
}
col_idx
+=
col_len
;
}
std
::
vector
<
pten
::
DenseTensor
>
pt_input
{
input
.
begin
(),
input
.
end
()};
pten
::
ConcatImpl
<
T
,
platform
::
CPUDeviceContext
>
(
context
,
pt_input
,
axis
,
output
);
}
};
...
...
@@ -88,46 +63,12 @@ class SplitFunctor<platform::CPUDeviceContext, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
const
framework
::
Tensor
*>&
ref_inputs
,
const
int
axis
,
std
::
vector
<
framework
::
Tensor
*>*
outputs
)
{
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if
(
input
.
numel
()
==
0
)
{
return
;
}
// TODO(zcd): Add input data validity checking
size_t
num
=
outputs
->
size
();
int
input_rows
=
1
;
auto
dim_0
=
ref_inputs
[
0
]
->
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
input_rows
*=
dim_0
[
i
];
}
int
input_cols
=
0
;
std
::
vector
<
int64_t
>
output_cols
(
outputs
->
size
());
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
ref_inputs
[
i
]
->
numel
()
/
input_rows
;
input_cols
+=
t_cols
;
output_cols
[
i
]
=
t_cols
;
}
auto
cpu_place
=
context
.
GetPlace
();
// computation
for
(
int
k
=
0
;
k
<
input_rows
;
++
k
)
{
const
T
*
src_ptr
=
input
.
data
<
T
>
()
+
k
*
input_cols
;
int
col_idx
=
0
;
for
(
size_t
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
output_cols
[
j
];
auto
*
out_tensor
=
outputs
->
at
(
j
);
if
(
out_tensor
!=
nullptr
)
{
T
*
dst_ptr
=
out_tensor
->
data
<
T
>
()
+
k
*
col_len
;
memory
::
Copy
(
cpu_place
,
dst_ptr
,
cpu_place
,
src_ptr
+
col_idx
,
sizeof
(
T
)
*
col_len
);
}
col_idx
+=
col_len
;
}
}
std
::
vector
<
const
pten
::
DenseTensor
*>
pt_ref_inputs
{
ref_inputs
.
begin
(),
ref_inputs
.
end
()};
std
::
vector
<
pten
::
DenseTensor
*>
pt_outputs
{
outputs
->
begin
(),
outputs
->
end
()};
pten
::
SplitImpl
<
T
,
platform
::
CPUDeviceContext
>
(
context
,
input
,
pt_ref_inputs
,
axis
,
&
pt_outputs
);
}
};
...
...
paddle/fluid/operators/math/concat_and_split.cu
浏览文件 @
06803c29
...
...
@@ -12,218 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <vector>
#include "gflags/gflags.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/kernels/gpu/concat_and_split.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
__global__
void
ConcatKernel
(
const
T
**
inputs
,
const
int64_t
*
input_cols
,
int
col_size
,
const
int64_t
output_rows
,
const
int64_t
output_cols
,
T
*
output
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
curr_segment
=
0
;
int
curr_offset
=
input_cols
[
0
];
for
(;
tid_x
<
output_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
curr_col_offset
=
input_cols
[
curr_segment
+
1
];
while
(
curr_col_offset
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
curr_col_offset
=
input_cols
[
curr_segment
+
1
];
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
const
T
*
input_ptr
=
inputs
[
curr_segment
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
output_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output
[
tid_y
*
output_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
segment_width
+
local_col
];
}
}
template
<
typename
T
>
__device__
void
ConcatKernelDetail
(
const
T
**
inputs_data
,
const
int
fixed_in_col
,
const
int
out_rows
,
const
int
out_cols
,
T
*
output_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
tid_x
<
out_cols
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
*
1.0
/
fixed_in_col
;
int
in_offset
=
tid_x
-
split
*
fixed_in_col
;
const
T
*
input_ptr
=
inputs_data
[
split
];
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
out_rows
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
{
output_data
[
tid_y
*
out_cols
+
tid_x
]
=
input_ptr
[
tid_y
*
fixed_in_col
+
in_offset
];
}
}
}
template
<
typename
T
>
__global__
void
ConcatKernel
(
const
T
*
input_addr0
,
const
T
*
input_addr1
,
const
int64_t
fixed_in_col
,
const
int64_t
out_rows
,
const
int64_t
out_cols
,
T
*
output_data
)
{
const
T
*
inputs_data
[
2
];
inputs_data
[
0
]
=
input_addr0
;
inputs_data
[
1
]
=
input_addr1
;
ConcatKernelDetail
<
T
>
(
inputs_data
,
fixed_in_col
,
out_rows
,
out_cols
,
output_data
);
}
template
<
typename
T
>
__global__
void
ConcatKernel
(
const
T
*
input_addr0
,
const
T
*
input_addr1
,
const
T
*
input_addr2
,
const
int64_t
fixed_in_col
,
const
int64_t
out_rows
,
const
int64_t
out_cols
,
T
*
output_data
)
{
const
T
*
inputs_data
[
3
];
inputs_data
[
0
]
=
input_addr0
;
inputs_data
[
1
]
=
input_addr1
;
inputs_data
[
2
]
=
input_addr2
;
ConcatKernelDetail
<
T
>
(
inputs_data
,
fixed_in_col
,
out_rows
,
out_cols
,
output_data
);
}
template
<
typename
T
>
__global__
void
ConcatKernel
(
const
T
*
input_addr0
,
const
T
*
input_addr1
,
const
T
*
input_addr2
,
const
T
*
input_addr3
,
const
int64_t
fixed_in_col
,
const
int64_t
out_rows
,
const
int64_t
out_cols
,
T
*
output_data
)
{
const
T
*
inputs_data
[
4
];
inputs_data
[
0
]
=
input_addr0
;
inputs_data
[
1
]
=
input_addr1
;
inputs_data
[
2
]
=
input_addr2
;
inputs_data
[
3
]
=
input_addr3
;
ConcatKernelDetail
<
T
>
(
inputs_data
,
fixed_in_col
,
out_rows
,
out_cols
,
output_data
);
}
template
<
typename
T
>
__global__
void
ConcatKernel
(
const
T
**
inputs_data
,
const
int
in_num
,
const
int64_t
fixed_in_col
,
const
int64_t
out_rows
,
const
int64_t
out_cols
,
T
*
output_data
)
{
ConcatKernelDetail
<
T
>
(
inputs_data
,
fixed_in_col
,
out_rows
,
out_cols
,
output_data
);
}
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int64_t
in_row
,
const
int64_t
in_col
,
const
int64_t
*
out_cols
,
int
out_cols_size
,
T
**
outputs_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
curr_segment
=
0
;
int
curr_offset
=
out_cols
[
0
];
for
(;
tid_x
<
in_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
curr_col_offset
=
out_cols
[
curr_segment
+
1
];
while
(
curr_col_offset
<=
tid_x
)
{
curr_offset
=
curr_col_offset
;
++
curr_segment
;
curr_col_offset
=
out_cols
[
curr_segment
+
1
];
}
int
local_col
=
tid_x
-
curr_offset
;
int
segment_width
=
curr_col_offset
-
curr_offset
;
T
*
output_ptr
=
outputs_data
[
curr_segment
];
if
(
output_ptr
!=
nullptr
)
{
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
in_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
segment_width
+
local_col
]
=
input_data
[
tid_y
*
in_col
+
tid_x
];
}
}
}
template
<
typename
T
>
__device__
void
SplitKernelDetail
(
const
T
*
input_data
,
const
int
in_row
,
const
int
in_col
,
const
int
fixed_out_col
,
T
**
outputs_data
)
{
int
tid_x
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
for
(;
tid_x
<
in_col
;
tid_x
+=
blockDim
.
x
*
gridDim
.
x
)
{
int
split
=
tid_x
/
fixed_out_col
;
int
in_offset
=
tid_x
-
split
*
fixed_out_col
;
T
*
output_ptr
=
outputs_data
[
split
];
if
(
output_ptr
!=
nullptr
)
{
int
tid_y
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
for
(;
tid_y
<
in_row
;
tid_y
+=
blockDim
.
y
*
gridDim
.
y
)
output_ptr
[
tid_y
*
fixed_out_col
+
in_offset
]
=
input_data
[
tid_y
*
in_col
+
tid_x
];
}
}
}
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int64_t
in_row
,
const
int64_t
in_col
,
const
int64_t
fixed_out_col
,
T
**
outputs_data
)
{
SplitKernelDetail
<
T
>
(
input_data
,
in_row
,
in_col
,
fixed_out_col
,
outputs_data
);
}
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int64_t
in_row
,
const
int64_t
in_col
,
const
int64_t
fixed_out_col
,
T
*
outputs_addr0
,
T
*
outputs_addr1
)
{
T
*
outputs_data
[
2
];
outputs_data
[
0
]
=
outputs_addr0
;
outputs_data
[
1
]
=
outputs_addr1
;
SplitKernelDetail
<
T
>
(
input_data
,
in_row
,
in_col
,
fixed_out_col
,
outputs_data
);
}
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int64_t
in_row
,
const
int64_t
in_col
,
const
int64_t
fixed_out_col
,
T
*
outputs_addr0
,
T
*
outputs_addr1
,
T
*
outputs_addr2
)
{
T
*
outputs_data
[
3
];
outputs_data
[
0
]
=
outputs_addr0
;
outputs_data
[
1
]
=
outputs_addr1
;
outputs_data
[
2
]
=
outputs_addr2
;
SplitKernelDetail
<
T
>
(
input_data
,
in_row
,
in_col
,
fixed_out_col
,
outputs_data
);
}
template
<
typename
T
>
__global__
void
SplitKernel
(
const
T
*
input_data
,
const
int64_t
in_row
,
const
int64_t
in_col
,
const
int64_t
fixed_out_col
,
T
*
outputs_addr0
,
T
*
outputs_addr1
,
T
*
outputs_addr2
,
T
*
outputs_addr3
)
{
T
*
outputs_data
[
4
];
outputs_data
[
0
]
=
outputs_addr0
;
outputs_data
[
1
]
=
outputs_addr1
;
outputs_data
[
2
]
=
outputs_addr2
;
outputs_data
[
3
]
=
outputs_addr3
;
SplitKernelDetail
<
T
>
(
input_data
,
in_row
,
in_col
,
fixed_out_col
,
outputs_data
);
}
static
inline
void
GetBlockDims
(
const
platform
::
CUDADeviceContext
&
context
,
int64_t
num_rows
,
int64_t
num_cols
,
dim3
*
block_dims
,
dim3
*
grid_dims
)
{
// Set the thread block and grid according to CurrentDeviceId
const
int
kThreadsPerBlock
=
1024
;
int
block_cols
=
kThreadsPerBlock
;
if
(
num_cols
<
kThreadsPerBlock
)
{
// block_cols is aligned by 32.
block_cols
=
((
num_cols
+
31
)
>>
5
)
<<
5
;
}
int
block_rows
=
kThreadsPerBlock
/
block_cols
;
*
block_dims
=
dim3
(
block_cols
,
block_rows
,
1
);
int
max_threads
=
context
.
GetMaxPhysicalThreadCount
();
int64_t
max_blocks
=
std
::
max
(
max_threads
/
kThreadsPerBlock
,
1
);
int
grid_cols
=
std
::
min
((
num_cols
+
block_cols
-
1
)
/
block_cols
,
max_blocks
);
int
grid_rows
=
std
::
min
(
max_blocks
/
grid_cols
,
std
::
max
(
num_rows
/
block_rows
,
(
int64_t
)
1
));
*
grid_dims
=
dim3
(
grid_cols
,
grid_rows
,
1
);
}
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
...
...
@@ -234,112 +29,10 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
void
operator
()(
const
platform
::
CUDADeviceContext
&
context
,
const
std
::
vector
<
framework
::
Tensor
>&
input
,
int
axis
,
framework
::
Tensor
*
output
)
{
// TODO(zcd): Add input data validity checking
int
in_num
=
input
.
size
();
int64_t
in_row
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
in_row
*=
dim_0
[
i
];
}
int64_t
in_col
=
input
[
0
].
numel
()
/
in_row
;
int64_t
out_row
=
in_row
,
out_col
=
0
;
int
inputs_col_num
=
in_num
+
1
;
std
::
vector
<
const
T
*>
inputs_data_vec
(
in_num
);
std
::
vector
<
int64_t
>
inputs_col_vec
(
inputs_col_num
);
const
T
**
inputs_data
=
inputs_data_vec
.
data
();
int64_t
*
inputs_col
=
inputs_col_vec
.
data
();
// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// hosttodevice, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
memory
::
AllocationPtr
data_alloc
,
col_alloc
;
data_alloc
=
memory
::
Alloc
(
platform
::
CUDAPinnedPlace
(),
in_num
*
sizeof
(
T
*
));
inputs_data
=
reinterpret_cast
<
const
T
**>
(
data_alloc
->
ptr
());
col_alloc
=
memory
::
Alloc
(
platform
::
CUDAPinnedPlace
(),
inputs_col_num
*
sizeof
(
int
));
inputs_col
=
reinterpret_cast
<
int64_t
*>
(
col_alloc
->
ptr
());
#endif
inputs_col
[
0
]
=
0
;
bool
has_same_shape
=
true
;
for
(
int
i
=
0
;
i
<
in_num
;
++
i
)
{
int64_t
t_cols
=
input
[
i
].
numel
()
/
in_row
;
if
(
has_same_shape
)
{
if
(
t_cols
!=
in_col
)
has_same_shape
=
false
;
}
out_col
+=
t_cols
;
inputs_col
[
i
+
1
]
=
out_col
;
inputs_data
[
i
]
=
input
[
i
].
data
<
T
>
();
}
dim3
block_dims
;
dim3
grid_dims
;
GetBlockDims
(
context
,
out_row
,
out_col
,
&
block_dims
,
&
grid_dims
);
std
::
vector
<
pten
::
DenseTensor
>
pt_input
{
input
.
begin
(),
input
.
end
()};
memory
::
allocation
::
AllocationPtr
tmp_dev_ins_data
;
const
T
**
dev_ins_data
=
nullptr
;
if
(
!
has_same_shape
||
in_num
<
2
||
in_num
>
4
)
{
tmp_dev_ins_data
=
memory
::
Alloc
(
context
,
in_num
*
sizeof
(
T
*
));
auto
*
restored
=
platform
::
RestoreHostMemIfCapturingCUDAGraph
(
inputs_data
,
in_num
);
memory
::
Copy
(
context
.
GetPlace
(),
tmp_dev_ins_data
->
ptr
(),
platform
::
CPUPlace
(),
restored
,
in_num
*
sizeof
(
T
*
),
context
.
stream
());
dev_ins_data
=
reinterpret_cast
<
const
T
**>
(
tmp_dev_ins_data
->
ptr
());
}
if
(
has_same_shape
)
{
if
(
in_num
==
2
)
{
ConcatKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
inputs_data
[
0
],
inputs_data
[
1
],
in_col
,
out_row
,
out_col
,
output
->
data
<
T
>
());
}
else
if
(
in_num
==
3
)
{
ConcatKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
inputs_data
[
0
],
inputs_data
[
1
],
inputs_data
[
2
],
in_col
,
out_row
,
out_col
,
output
->
data
<
T
>
());
}
else
if
(
in_num
==
4
)
{
ConcatKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
inputs_data
[
0
],
inputs_data
[
1
],
inputs_data
[
2
],
inputs_data
[
3
],
in_col
,
out_row
,
out_col
,
output
->
data
<
T
>
());
}
else
{
ConcatKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
dev_ins_data
,
in_num
,
in_col
,
out_row
,
out_col
,
output
->
data
<
T
>
());
}
}
else
{
auto
tmp_dev_ins_col_data
=
memory
::
Alloc
(
context
,
inputs_col_num
*
sizeof
(
int64_t
));
auto
*
restored
=
platform
::
RestoreHostMemIfCapturingCUDAGraph
(
inputs_col
,
inputs_col_num
);
memory
::
Copy
(
context
.
GetPlace
(),
tmp_dev_ins_col_data
->
ptr
(),
platform
::
CPUPlace
(),
restored
,
inputs_col_num
*
sizeof
(
int64_t
),
context
.
stream
());
int64_t
*
dev_ins_col_data
=
static_cast
<
int64_t
*>
(
tmp_dev_ins_col_data
->
ptr
());
ConcatKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
dev_ins_data
,
dev_ins_col_data
,
static_cast
<
int
>
(
inputs_col_num
),
out_row
,
out_col
,
output
->
data
<
T
>
());
}
#ifdef PADDLE_WITH_HIP
// Prevent the pinned memory value from being covered and release the memory
// after the launch kernel of the stream is executed (reapply pinned memory
// next time)
auto
*
data_alloc_released
=
data_alloc
.
release
();
auto
*
col_alloc_released
=
col_alloc
.
release
();
context
.
AddStreamCallback
([
data_alloc_released
,
col_alloc_released
]
{
memory
::
allocation
::
Allocator
::
AllocationDeleter
(
data_alloc_released
);
memory
::
allocation
::
Allocator
::
AllocationDeleter
(
col_alloc_released
);
});
#endif
pten
::
ConcatImpl
<
T
,
platform
::
CUDADeviceContext
>
(
context
,
pt_input
,
axis
,
output
);
}
};
...
...
@@ -355,120 +48,12 @@ class SplitFunctor<platform::CUDADeviceContext, T> {
const
framework
::
Tensor
&
input
,
const
std
::
vector
<
const
framework
::
Tensor
*>&
ref_inputs
,
int
axis
,
std
::
vector
<
framework
::
Tensor
*>*
outputs
)
{
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if
(
input
.
numel
()
==
0
)
{
return
;
}
// TODO(zcd): Add input data validity checking
int
o_num
=
outputs
->
size
();
int64_t
out_row
=
1
;
auto
dim_0
=
ref_inputs
[
0
]
->
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
out_row
*=
dim_0
[
i
];
}
int64_t
out0_col
=
ref_inputs
[
0
]
->
numel
()
/
out_row
;
int64_t
in_col
=
0
,
in_row
=
out_row
;
bool
has_same_shape
=
true
;
int
outputs_cols_num
=
o_num
+
1
;
std
::
vector
<
T
*>
outputs_data_vec
(
o_num
);
std
::
vector
<
int64_t
>
outputs_cols_vec
(
outputs_cols_num
);
T
**
outputs_data
=
outputs_data_vec
.
data
();
int64_t
*
outputs_cols
=
outputs_cols_vec
.
data
();
// There are some differences between hip runtime and NV runtime.
// In NV, when the pageable memory data less than 64K is transferred from
// hosttodevice, it will be automatically asynchronous.
// However, only pinned memory in hip can copy asynchronously
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device
// 3.2.6.1. Concurrent Execution between Host and Device
// Memory copies from host to device of a memory block of 64 KB or less
#ifdef PADDLE_WITH_HIP
memory
::
AllocationPtr
data_alloc
,
cols_alloc
;
data_alloc
=
memory
::
Alloc
(
platform
::
CUDAPinnedPlace
(),
o_num
*
sizeof
(
T
*
));
outputs_data
=
reinterpret_cast
<
T
**>
(
data_alloc
->
ptr
());
cols_alloc
=
memory
::
Alloc
(
platform
::
CUDAPinnedPlace
(),
(
outputs_cols_num
)
*
sizeof
(
int64_t
));
outputs_cols
=
reinterpret_cast
<
int64_t
*>
(
cols_alloc
->
ptr
());
#endif
outputs_cols
[
0
]
=
0
;
for
(
int
i
=
0
;
i
<
o_num
;
++
i
)
{
int64_t
t_col
=
ref_inputs
.
at
(
i
)
->
numel
()
/
out_row
;
if
(
has_same_shape
)
{
if
(
t_col
!=
out0_col
)
has_same_shape
=
false
;
}
in_col
+=
t_col
;
outputs_cols
[
i
+
1
]
=
in_col
;
if
(
outputs
->
at
(
i
)
!=
nullptr
)
{
outputs_data
[
i
]
=
outputs
->
at
(
i
)
->
data
<
T
>
();
}
else
{
outputs_data
[
i
]
=
nullptr
;
}
}
dim3
block_dims
;
dim3
grid_dims
;
GetBlockDims
(
context
,
out_row
,
in_col
,
&
block_dims
,
&
grid_dims
);
memory
::
allocation
::
AllocationPtr
tmp_dev_outs_data
;
T
**
dev_out_gpu_data
=
nullptr
;
if
(
!
has_same_shape
||
o_num
<
2
||
o_num
>
4
)
{
tmp_dev_outs_data
=
memory
::
Alloc
(
context
,
o_num
*
sizeof
(
T
*
));
auto
*
restored
=
platform
::
RestoreHostMemIfCapturingCUDAGraph
(
outputs_data
,
o_num
);
memory
::
Copy
(
context
.
GetPlace
(),
tmp_dev_outs_data
->
ptr
(),
platform
::
CPUPlace
(),
restored
,
o_num
*
sizeof
(
T
*
),
context
.
stream
());
dev_out_gpu_data
=
reinterpret_cast
<
T
**>
(
tmp_dev_outs_data
->
ptr
());
}
if
(
has_same_shape
)
{
if
(
o_num
==
2
)
{
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
in_row
,
in_col
,
out0_col
,
outputs_data
[
0
],
outputs_data
[
1
]);
}
else
if
(
o_num
==
3
)
{
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
in_row
,
in_col
,
out0_col
,
outputs_data
[
0
],
outputs_data
[
1
],
outputs_data
[
2
]);
}
else
if
(
o_num
==
4
)
{
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
in_row
,
in_col
,
out0_col
,
outputs_data
[
0
],
outputs_data
[
1
],
outputs_data
[
2
],
outputs_data
[
3
]);
}
else
{
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
in_row
,
in_col
,
out0_col
,
dev_out_gpu_data
);
}
}
else
{
auto
tmp_dev_ins_col_data
=
memory
::
Alloc
(
context
,
outputs_cols_num
*
sizeof
(
int64_t
));
auto
*
restored
=
platform
::
RestoreHostMemIfCapturingCUDAGraph
(
outputs_cols
,
outputs_cols_num
);
memory
::
Copy
(
context
.
GetPlace
(),
tmp_dev_ins_col_data
->
ptr
(),
platform
::
CPUPlace
(),
restored
,
outputs_cols_num
*
sizeof
(
int64_t
),
context
.
stream
());
int64_t
*
dev_outs_col_data
=
reinterpret_cast
<
int64_t
*>
(
tmp_dev_ins_col_data
->
ptr
());
SplitKernel
<<<
grid_dims
,
block_dims
,
0
,
context
.
stream
()
>>>
(
input
.
data
<
T
>
(),
in_row
,
in_col
,
dev_outs_col_data
,
static_cast
<
int
>
(
outputs_cols_num
),
dev_out_gpu_data
);
}
#ifdef PADDLE_WITH_HIP
// Prevent the pinned memory value from being covered and release the memory
// after the launch kernel of the stream is executed (reapply pinned memory
// next time)
auto
*
data_alloc_released
=
data_alloc
.
release
();
auto
*
cols_alloc_released
=
cols_alloc
.
release
();
context
.
AddStreamCallback
([
data_alloc_released
,
cols_alloc_released
]
{
memory
::
allocation
::
Allocator
::
AllocationDeleter
(
data_alloc_released
);
memory
::
allocation
::
Allocator
::
AllocationDeleter
(
cols_alloc_released
);
});
#endif
std
::
vector
<
const
pten
::
DenseTensor
*>
pt_ref_inputs
{
ref_inputs
.
begin
(),
ref_inputs
.
end
()};
std
::
vector
<
pten
::
DenseTensor
*>
pt_outputs
{
outputs
->
begin
(),
outputs
->
end
()};
pten
::
SplitImpl
<
T
,
platform
::
CUDADeviceContext
>
(
context
,
input
,
pt_ref_inputs
,
axis
,
&
pt_outputs
);
}
};
...
...
paddle/fluid/operators/merge_lod_tensor_op.cc
浏览文件 @
06803c29
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/pten/core/lod_utils.h"
namespace
pten
{
class
DenseTensor
;
}
// namespace pten
...
...
@@ -122,7 +124,7 @@ class MergeLoDTensorOp : public framework::OperatorBase {
input
->
lod
(),
*
in_idx
,
(
*
in_idx
)
+
1
,
0
);
auto
&
lod_length
=
lod_and_offset
.
first
;
framework
::
AppendLoD
(
out_lod
,
lod_length
);
pten
::
AppendLoD
(
out_lod
,
lod_length
);
size_t
start_offset
=
lod_and_offset
.
second
.
first
;
size_t
end_offset
=
lod_and_offset
.
second
.
second
;
...
...
paddle/fluid/operators/shrink_rnn_memory_op.cc
浏览文件 @
06803c29
...
...
@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/array_operator.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/pten/core/lod_utils.h"
namespace
paddle
{
namespace
framework
{
class
OpDesc
;
...
...
@@ -73,7 +75,7 @@ class ShrinkRNNMemoryOp : public ArrayOp {
dst_num_rows
,
0
);
height
=
lod_offset
.
second
.
second
;
auto
out_lod
=
out_tensor
.
mutable_lod
();
framework
::
AppendLoD
(
out_lod
,
lod_offset
.
first
);
pten
::
AppendLoD
(
out_lod
,
lod_offset
.
first
);
}
if
(
dst_num_rows
!=
0
)
{
...
...
paddle/fluid/operators/split_lod_tensor_op.cc
浏览文件 @
06803c29
...
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/core/lod_utils.h"
namespace
pten
{
class
DenseTensor
;
...
...
@@ -96,7 +97,7 @@ class SplitLoDTensorOp : public framework::OperatorBase {
x_lod
,
start_idx
,
start_idx
+
1
,
level
);
auto
&
lod_length
=
lod_and_offset
.
first
;
framework
::
AppendLoD
(
lod
,
lod_length
);
pten
::
AppendLoD
(
lod
,
lod_length
);
size_t
start_offset
=
lod_and_offset
.
second
.
first
;
size_t
end_offset
=
lod_and_offset
.
second
.
second
;
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
06803c29
...
...
@@ -43,7 +43,6 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/generate_pass.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/fluid/framework/op_info.h"
...
...
@@ -75,6 +74,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/pten/core/lod_utils.h"
#ifndef PADDLE_ON_INFERENCE
#include "paddle/fluid/pybind/eager.h"
#endif
...
...
@@ -1093,7 +1093,7 @@ PYBIND11_MODULE(core_noavx, m) {
.
def
(
"recursive_sequence_lengths"
,
[](
framework
::
Tensor
&
self
)
->
std
::
vector
<
std
::
vector
<
size_t
>>
{
// output the length-based lod info
LoD
lod
=
ConvertToLengthBasedLoD
(
self
.
lod
());
LoD
lod
=
pten
::
ConvertToLengthBasedLoD
(
self
.
lod
());
std
::
vector
<
std
::
vector
<
size_t
>>
new_lod
;
new_lod
.
reserve
(
lod
.
size
());
std
::
copy
(
lod
.
begin
(),
lod
.
end
(),
std
::
back_inserter
(
new_lod
));
...
...
paddle/pten/CMakeLists.txt
浏览文件 @
06803c29
...
...
@@ -18,7 +18,7 @@ add_subdirectory(ops)
add_subdirectory
(
tests
)
# make an unity target for compile deps
set
(
PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta
)
set
(
PTEN_DEPS convert_utils dense_tensor pten_context kernel_factory kernel_context arg_map_context infermeta
lod_utils
)
get_property
(
pten_kernels GLOBAL PROPERTY PTEN_KERNELS
)
# keep this message for debug, remove it later if needless
message
(
STATUS
"All standard pten kernels:
${
pten_kernels
}
"
)
...
...
paddle/pten/api/include/kernel_signature.h
浏览文件 @
06803c29
...
...
@@ -38,6 +38,11 @@ using cast_kernel = void (*)(const DeviceContext&,
DataType
,
DenseTensor
*
);
using
concat_kernel
=
void
(
*
)(
const
DeviceContext
&
,
const
std
::
vector
<
DenseTensor
>&
,
const
Scalar
&
,
DenseTensor
*
);
using
divide_kernel
=
void
(
*
)(
const
DeviceContext
&
,
const
DenseTensor
&
,
const
DenseTensor
&
,
...
...
paddle/pten/api/lib/utils/tensor_utils.cc
浏览文件 @
06803c29
...
...
@@ -38,6 +38,11 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensorBase(
src
.
dims
(),
src
.
layout
(),
src
.
offset
()};
if
(
!
src
.
IsInitialized
())
{
return
std
::
make_unique
<
pten
::
DenseTensor
>
(
std
::
move
(
pten
::
make_intrusive
<
SharedStorage
>
(
src
.
place
())),
std
::
move
(
meta
));
}
auto
shared_storage
=
pten
::
make_intrusive
<
SharedStorage
>
(
src
.
Holder
());
return
std
::
make_unique
<
pten
::
DenseTensor
>
(
std
::
move
(
shared_storage
),
std
::
move
(
meta
));
...
...
@@ -247,7 +252,9 @@ std::unique_ptr<pten::TensorBase> MakePtenTensorBaseFromVar(
if
(
variable
.
IsType
<
framework
::
LoDTensor
>
())
{
const
auto
&
tensor
=
variable
.
Get
<
framework
::
LoDTensor
>
();
if
(
!
platform
::
is_same_place
(
tensor
.
place
(),
expected_place
))
{
if
(
tensor
.
IsInitialized
()
&&
!
platform
::
is_same_place
(
tensor
.
place
(),
expected_place
))
{
framework
::
LoDTensor
tmp_tensor
;
framework
::
TensorCopySync
(
tensor
,
expected_place
,
&
tmp_tensor
);
return
MakePtenDenseTensor
(
tmp_tensor
);
...
...
paddle/pten/core/CMakeLists.txt
浏览文件 @
06803c29
...
...
@@ -12,6 +12,7 @@ cc_library(arg_map_context SRCS arg_map_context.cc DEPS enforce)
cc_library
(
tensor_base SRCS tensor_base.cc allocator.cc storage.cc DEPS enforce
)
cc_library
(
tensor_meta SRCS tensor_meta.cc DEPS enforce mixed_vector
)
cc_library
(
lod_utils SRCS lod_utils.cc DEPS enforce mixed_vector
)
cc_library
(
dense_tensor SRCS dense_tensor.cc DEPS convert_utils tensor_meta tensor_base
)
cc_library
(
pten_device_context SRCS device_context.cc DEPS tensor_base
)
...
...
paddle/pten/core/kernel_context.h
浏览文件 @
06803c29
...
...
@@ -92,7 +92,7 @@ class KernelContext {
std
::
vector
<
TensorType
>
MoveInputsBetween
(
size_t
start
,
size_t
end
)
{
std
::
vector
<
TensorType
>
v
;
for
(
size_t
i
=
start
;
i
<
end
;
++
i
)
{
auto
t
=
st
d
::
dynamic_pointer_cast
<
TensorType
>
(
inputs_
.
at
(
i
));
auto
t
=
st
atic_cast
<
const
TensorType
*
>
(
inputs_
.
at
(
i
));
v
.
emplace_back
(
*
t
);
inputs_
.
at
(
i
)
=
nullptr
;
}
...
...
paddle/pten/core/lod_utils.cc
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/lod_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace
pten
{
void
AppendLoD
(
LoD
*
lod
,
const
LoD
&
lod_length
)
{
PADDLE_ENFORCE
(
lod
->
empty
()
||
lod
->
size
()
==
lod_length
.
size
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"The input LoD length should be equal to the appended LoD size, but "
"received input LoD length is %d, actual LoD size is %d."
,
lod_length
.
size
(),
lod
->
size
()));
if
(
lod
->
empty
())
{
for
(
size_t
i
=
0
;
i
<
lod_length
.
size
();
++
i
)
{
lod
->
emplace_back
(
1
,
0
);
// size = 1, value = 0;
}
*
lod
=
LoD
(
lod_length
.
size
(),
std
::
vector
<
size_t
>
({
0
}));
}
for
(
size_t
i
=
0
;
i
<
lod
->
size
();
++
i
)
{
auto
&
level
=
(
*
lod
)[
i
];
for
(
size_t
len
:
lod_length
[
i
])
{
level
.
push_back
(
level
.
back
()
+
len
);
}
}
}
LoD
ConvertToLengthBasedLoD
(
const
LoD
&
offset_lod
)
{
LoD
length_lod
;
length_lod
.
reserve
(
offset_lod
.
size
());
for
(
size_t
lvl
=
0
;
lvl
<
offset_lod
.
size
();
++
lvl
)
{
std
::
vector
<
size_t
>
level
;
if
(
offset_lod
[
lvl
].
size
()
>
0
)
{
level
.
reserve
(
offset_lod
[
lvl
].
size
()
-
1
);
}
for
(
size_t
idx
=
0
;
idx
<
offset_lod
[
lvl
].
size
()
-
1
;
++
idx
)
{
level
.
push_back
(
offset_lod
[
lvl
][
idx
+
1
]
-
offset_lod
[
lvl
][
idx
]);
}
length_lod
.
push_back
(
level
);
}
return
length_lod
;
}
}
// namespace pten
paddle/pten/core/lod_utils.h
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/framework/mixed_vector.h"
namespace
pten
{
using
LoD
=
std
::
vector
<
paddle
::
framework
::
Vector
<
size_t
>>
;
void
AppendLoD
(
LoD
*
lod
,
const
LoD
&
lod_length
);
/*
* Convert between length-based LoD and offset-based LoD.
* The implementation of LoDTensor class use offset-based LoD.
* However, we want to expose the more user-friendly length-based
* LoD to the Python side instead.
*
* Example:
* If offset_lod = [[0, 2, 3],[0, 3, 5, 9]]
* then length_lod = [[2, 1], [3, 2, 4]]
*/
LoD
ConvertToLengthBasedLoD
(
const
LoD
&
offset_lod
);
}
// namespace pten
paddle/pten/infermeta/multiary.cc
浏览文件 @
06803c29
...
...
@@ -14,4 +14,43 @@ limitations under the License. */
#include "paddle/pten/infermeta/multiary.h"
namespace
pten
{}
// namespace pten
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/kernels/funcs/concat_funcs.h"
namespace
pten
{
DenseTensorMeta
ConcatInferMeta
(
const
std
::
vector
<
DenseTensorMeta
>&
x_meta
,
const
Scalar
&
axis_scalar
,
bool
is_runtime
)
{
PADDLE_ENFORCE_GE
(
x_meta
.
size
(),
0
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The size of input meta vector should be greater"
"than 0."
));
int
axis
=
axis_scalar
.
to
<
int
>
();
// 1. calculate axis
int
rank
=
x_meta
[
0
].
dims
.
size
();
PADDLE_ENFORCE_EQ
(
axis
>=
-
rank
&&
axis
<
rank
,
true
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The axis is expected to be in range of [%d, %d), but got %d"
,
-
rank
,
rank
,
axis
));
if
(
axis
<
0
)
{
axis
=
axis
+
rank
;
}
// 2. calculate out dims
std
::
vector
<
pten
::
DDim
>
x_dims
;
for
(
auto
meta
:
x_meta
)
{
x_dims
.
push_back
(
meta
.
dims
);
}
pten
::
DDim
out_dim
=
pten
::
funcs
::
ComputeAndCheckShape
(
is_runtime
,
x_dims
,
axis
);
return
{
x_meta
[
0
].
dtype
,
out_dim
,
x_meta
[
0
].
layout
};
}
}
// namespace pten
paddle/pten/infermeta/multiary.h
浏览文件 @
06803c29
...
...
@@ -14,4 +14,13 @@ limitations under the License. */
#pragma once
namespace
pten
{}
// namespace pten
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/tensor_meta.h"
namespace
pten
{
// TODO(chentianyu03) use std::vector<DenseTensor> as InferMeta inputs
DenseTensorMeta
ConcatInferMeta
(
const
std
::
vector
<
DenseTensorMeta
>&
x_meta
,
const
Scalar
&
axis_scalar
,
bool
is_runtime
);
}
// namespace pten
paddle/pten/kernels/CMakeLists.txt
浏览文件 @
06803c29
...
...
@@ -24,7 +24,7 @@ endif()
# pten depends all pten kernel targets
set_property
(
GLOBAL PROPERTY PTEN_KERNELS
""
)
set
(
COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils
)
set
(
COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory arg_map_context convert_utils
lod_utils
)
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
eigen_function blas
)
# remove this dep after removing fluid deps on tensor creation
set
(
COMMON_KERNEL_DEPS
${
COMMON_KERNEL_DEPS
}
pten_api_utils
)
...
...
paddle/pten/kernels/concat_kernel.h
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/infermeta/multiary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace
pten
{
template
<
typename
T
,
typename
Context
>
void
ConcatKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
DenseTensor
>&
x
,
const
Scalar
&
axis
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
DenseTensor
Concat
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
DenseTensor
>&
x
,
const
Scalar
&
axis
)
{
std
::
vector
<
DenseTensorMeta
>
x_meta
;
for
(
auto
t
:
x
)
{
x_meta
.
push_back
(
t
.
meta
());
}
auto
out_meta
=
ConcatInferMeta
(
x_meta
,
axis
.
to
<
int
>
(),
true
);
auto
dense_out
=
pten
::
Empty
<
T
,
Context
>
(
dev_ctx
,
std
::
move
(
out_meta
));
ConcatKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
axis
,
&
dense_out
);
return
dense_out
;
}
}
// namespace pten
paddle/pten/kernels/cpu/concat_and_split.h
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
namespace
pten
{
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template
<
typename
T
,
typename
Context
>
void
ConcatImpl
(
const
Context
&
context
,
const
std
::
vector
<
DenseTensor
>&
input
,
int
axis
,
DenseTensor
*
output
)
{
// TODO(zcd): Add input data validity checking
size_t
num
=
input
.
size
();
int64_t
rows
=
1
;
auto
dim_0
=
input
[
0
].
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
rows
*=
dim_0
[
i
];
}
int64_t
out_rows
=
rows
,
out_cols
=
0
;
std
::
vector
<
int64_t
>
input_cols
(
input
.
size
());
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
int64_t
t_cols
=
input
[
i
].
numel
()
/
rows
;
out_cols
+=
t_cols
;
input_cols
[
i
]
=
t_cols
;
}
auto
cpu_place
=
context
.
GetPlace
();
// computation
auto
output_data
=
output
->
data
<
T
>
();
int64_t
col_idx
=
0
;
for
(
size_t
j
=
0
;
j
<
num
;
++
j
)
{
int64_t
col_len
=
input_cols
[
j
];
auto
input_data
=
input
[
j
].
data
<
T
>
();
for
(
int64_t
k
=
0
;
k
<
out_rows
;
++
k
)
{
paddle
::
memory
::
Copy
(
cpu_place
,
output_data
+
k
*
out_cols
+
col_idx
,
cpu_place
,
input_data
+
k
*
col_len
,
sizeof
(
T
)
*
col_len
);
}
col_idx
+=
col_len
;
}
}
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template
<
typename
T
,
typename
Context
>
void
SplitImpl
(
const
Context
&
context
,
const
DenseTensor
&
input
,
const
std
::
vector
<
const
DenseTensor
*>&
ref_inputs
,
const
int
axis
,
std
::
vector
<
DenseTensor
*>*
outputs
)
{
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if
(
input
.
numel
()
==
0
)
{
return
;
}
// TODO(zcd): Add input data validity checking
size_t
num
=
outputs
->
size
();
int
input_rows
=
1
;
auto
dim_0
=
ref_inputs
[
0
]
->
dims
();
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
input_rows
*=
dim_0
[
i
];
}
int
input_cols
=
0
;
std
::
vector
<
int64_t
>
output_cols
(
outputs
->
size
());
for
(
size_t
i
=
0
;
i
<
num
;
++
i
)
{
int
t_cols
=
ref_inputs
[
i
]
->
numel
()
/
input_rows
;
input_cols
+=
t_cols
;
output_cols
[
i
]
=
t_cols
;
}
auto
cpu_place
=
context
.
GetPlace
();
// computation
for
(
int
k
=
0
;
k
<
input_rows
;
++
k
)
{
const
T
*
src_ptr
=
input
.
data
<
T
>
()
+
k
*
input_cols
;
int
col_idx
=
0
;
for
(
size_t
j
=
0
;
j
<
num
;
++
j
)
{
int
col_len
=
output_cols
[
j
];
auto
*
out_tensor
=
outputs
->
at
(
j
);
if
(
out_tensor
!=
nullptr
)
{
T
*
dst_ptr
=
out_tensor
->
data
<
T
>
()
+
k
*
col_len
;
paddle
::
memory
::
Copy
(
cpu_place
,
dst_ptr
,
cpu_place
,
src_ptr
+
col_idx
,
sizeof
(
T
)
*
col_len
);
}
col_idx
+=
col_len
;
}
}
}
}
// namespace pten
paddle/pten/kernels/cpu/concat_kernel.cc
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/concat_kernel.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/lod_utils.h"
#include "paddle/pten/kernels/cpu/concat_and_split.h"
#include "paddle/pten/kernels/funcs/concat_funcs.h"
namespace
pten
{
template
<
typename
T
,
typename
Context
>
void
ConcatKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
DenseTensor
>&
x
,
const
Scalar
&
axis_scalar
,
DenseTensor
*
out
)
{
int64_t
axis
=
axis_scalar
.
to
<
int64_t
>
();
axis
=
pten
::
funcs
::
ComputeAxis
(
axis
,
x
[
0
].
dims
().
size
());
std
::
vector
<
pten
::
DDim
>
x_dims
;
for
(
size_t
i
=
0
;
i
<
x
.
size
();
++
i
)
{
x_dims
.
push_back
(
x
[
i
].
dims
());
}
pten
::
DDim
out_dims
=
pten
::
funcs
::
ComputeAndCheckShape
(
true
,
x_dims
,
axis
);
out
->
Resize
(
out_dims
);
out
->
mutable_data
<
T
>
();
// If axis is 0, the lod of the output is not the same as inputs.
if
(
axis
==
0
&&
x
[
0
].
lod
().
size
()
>
0
)
{
size_t
lod_size_0
=
x
[
0
].
lod
().
size
();
size_t
lod_size
=
lod_size_0
;
for
(
size_t
i
=
1
;
i
<
x
.
size
();
++
i
)
{
if
(
x
[
i
].
lod
().
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
x
[
i
].
lod
().
size
(),
lod_size_0
,
paddle
::
platform
::
errors
::
Unimplemented
(
"The lod level of all input LoDTensors should be same. "
"Maybe different lod level of input LoDTensors can concat,"
"it is not supported currently. The lod level of %dth input "
"is %d and first input is %d."
,
i
,
x
[
i
].
lod
().
size
(),
lod_size_0
));
}
else
{
lod_size
=
0
;
break
;
}
}
if
(
lod_size
)
{
auto
*
out_lod
=
out
->
mutable_lod
();
for
(
size_t
i
=
1
;
i
<
x
.
size
();
++
i
)
{
auto
in_lod
=
pten
::
ConvertToLengthBasedLoD
(
x
[
i
].
lod
());
pten
::
AppendLoD
(
out_lod
,
in_lod
);
}
}
}
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if
(
axis
==
0
&&
x
.
size
()
<
10
)
{
size_t
output_offset
=
0
;
for
(
auto
&
in
:
x
)
{
if
(
in
.
numel
()
==
0UL
)
{
continue
;
}
auto
in_stride
=
paddle
::
framework
::
stride_numel
(
in
.
dims
());
auto
out_stride
=
paddle
::
framework
::
stride_numel
(
out
->
dims
());
paddle
::
operators
::
StridedNumelCopyWithAxis
<
T
>
(
dev_ctx
,
axis
,
out
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
.
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
}
else
{
std
::
vector
<
pten
::
DenseTensor
>
inputs
;
for
(
size_t
j
=
0
;
j
<
x
.
size
();
++
j
)
{
if
(
x
[
j
].
numel
()
>
0
)
{
inputs
.
push_back
(
x
[
j
]);
}
else
{
continue
;
}
}
ConcatImpl
<
T
,
Context
>
(
dev_ctx
,
inputs
,
axis
,
out
);
}
}
}
// namespace pten
PT_REGISTER_KERNEL
(
concat
,
CPU
,
ALL_LAYOUT
,
pten
::
ConcatKernel
,
float
,
double
,
bool
,
int64_t
,
int
,
uint8_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/kernels/funcs/concat_funcs.h
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/errors.h"
namespace
pten
{
namespace
funcs
{
static
inline
int64_t
ComputeAxis
(
int64_t
axis
,
int64_t
rank
)
{
PADDLE_ENFORCE_EQ
(
axis
>=
-
rank
&&
axis
<
rank
,
true
,
paddle
::
platform
::
errors
::
InvalidArgument
(
"The axis is expected to be in range of [%d, %d), but got %d"
,
-
rank
,
rank
,
axis
));
if
(
axis
<
0
)
{
axis
=
axis
+
rank
;
}
return
axis
>
0
?
axis
:
0
;
}
static
inline
pten
::
DDim
ComputeAndCheckShape
(
const
bool
is_runtime
,
const
std
::
vector
<
pten
::
DDim
>&
inputs_dims
,
const
size_t
axis
)
{
const
size_t
n
=
inputs_dims
.
size
();
auto
out_dims
=
inputs_dims
[
0
];
size_t
in_zero_dims_size
=
out_dims
.
size
();
for
(
size_t
i
=
1
;
i
<
n
;
i
++
)
{
PADDLE_ENFORCE_EQ
(
inputs_dims
[
i
].
size
(),
out_dims
.
size
(),
paddle
::
platform
::
errors
::
InvalidArgument
(
"The shape of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s]."
,
i
,
inputs_dims
[
0
],
i
,
inputs_dims
[
i
]));
for
(
size_t
j
=
0
;
j
<
in_zero_dims_size
;
j
++
)
{
if
(
j
==
axis
)
{
if
(
is_runtime
)
{
out_dims
[
axis
]
+=
inputs_dims
[
i
][
j
];
}
else
{
if
(
inputs_dims
[
i
][
j
]
==
-
1
||
out_dims
[
j
]
==
-
1
)
{
out_dims
[
axis
]
=
-
1
;
}
else
{
out_dims
[
axis
]
+=
inputs_dims
[
i
][
j
];
}
}
}
else
{
bool
check_shape
=
is_runtime
||
(
inputs_dims
[
0
][
j
]
>
0
&&
inputs_dims
[
i
][
j
]
>
0
);
if
(
check_shape
)
{
// check all shape in run time
PADDLE_ENFORCE_EQ
(
inputs_dims
[
0
][
j
],
inputs_dims
[
i
][
j
],
paddle
::
platform
::
errors
::
InvalidArgument
(
"The %d-th dimension of input[0] and input[%d] "
"is expected to be equal."
"But received input[0]'s shape = "
"[%s], input[%d]'s shape = [%s]."
,
j
,
i
,
inputs_dims
[
0
],
i
,
inputs_dims
[
i
]));
}
if
(
!
is_runtime
&&
out_dims
[
j
]
==
-
1
&&
inputs_dims
[
i
][
j
]
>
0
)
{
out_dims
[
j
]
=
inputs_dims
[
i
][
j
];
}
}
}
}
return
out_dims
;
}
}
// namespace funcs
}
// namespace pten
paddle/pten/kernels/gpu/concat_and_split.h
0 → 100644
浏览文件 @
06803c29
此差异已折叠。
点击以展开。
paddle/pten/kernels/gpu/concat_kernel.cu
0 → 100644
浏览文件 @
06803c29
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/concat_kernel.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/lod_utils.h"
#include "paddle/pten/kernels/funcs/concat_funcs.h"
#include "paddle/pten/kernels/gpu/concat_and_split.h"
namespace
pten
{
template
<
typename
T
,
typename
Context
>
void
ConcatKernel
(
const
Context
&
dev_ctx
,
const
std
::
vector
<
DenseTensor
>&
x
,
const
Scalar
&
axis_scalar
,
DenseTensor
*
out
)
{
int64_t
axis
=
axis_scalar
.
to
<
int64_t
>
();
axis
=
pten
::
funcs
::
ComputeAxis
(
axis
,
x
[
0
].
dims
().
size
());
std
::
vector
<
pten
::
DDim
>
x_dims
;
for
(
size_t
i
=
0
;
i
<
x
.
size
();
++
i
)
{
x_dims
.
push_back
(
x
[
i
].
dims
());
}
pten
::
DDim
out_dims
=
pten
::
funcs
::
ComputeAndCheckShape
(
true
,
x_dims
,
axis
);
out
->
Resize
(
out_dims
);
out
->
mutable_data
<
T
>
();
// If axis is 0, the lod of the output is not the same as inputs.
if
(
axis
==
0
&&
x
[
0
].
lod
().
size
()
>
0
)
{
size_t
lod_size_0
=
x
[
0
].
lod
().
size
();
size_t
lod_size
=
lod_size_0
;
for
(
size_t
i
=
1
;
i
<
x
.
size
();
++
i
)
{
if
(
x
[
i
].
lod
().
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
x
[
i
].
lod
().
size
(),
lod_size_0
,
paddle
::
platform
::
errors
::
Unimplemented
(
"The lod level of all input LoDTensors should be same. "
"Maybe different lod level of input LoDTensors can concat,"
"it is not supported currently. The lod level of %dth input "
"is %d and first input is %d."
,
i
,
x
[
i
].
lod
().
size
(),
lod_size_0
));
}
else
{
lod_size
=
0
;
break
;
}
}
if
(
lod_size
)
{
auto
*
out_lod
=
out
->
mutable_lod
();
for
(
size_t
i
=
1
;
i
<
x
.
size
();
++
i
)
{
auto
in_lod
=
pten
::
ConvertToLengthBasedLoD
(
x
[
i
].
lod
());
pten
::
AppendLoD
(
out_lod
,
in_lod
);
}
}
}
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if
(
axis
==
0
&&
x
.
size
()
<
10
)
{
size_t
output_offset
=
0
;
for
(
auto
&
in
:
x
)
{
if
(
in
.
numel
()
==
0UL
)
{
continue
;
}
auto
in_stride
=
paddle
::
framework
::
stride_numel
(
in
.
dims
());
auto
out_stride
=
paddle
::
framework
::
stride_numel
(
out
->
dims
());
paddle
::
operators
::
StridedNumelCopyWithAxis
<
T
>
(
dev_ctx
,
axis
,
out
->
data
<
T
>
()
+
output_offset
,
out_stride
,
in
.
data
<
T
>
(),
in_stride
,
in_stride
[
axis
]);
output_offset
+=
in_stride
[
axis
];
}
}
else
{
std
::
vector
<
pten
::
DenseTensor
>
inputs
;
for
(
size_t
j
=
0
;
j
<
x
.
size
();
++
j
)
{
if
(
x
[
j
].
numel
()
>
0
)
{
inputs
.
push_back
(
x
[
j
]);
}
else
{
continue
;
}
}
ConcatImpl
<
T
,
Context
>
(
dev_ctx
,
inputs
,
axis
,
out
);
}
}
}
// namespace pten
PT_REGISTER_KERNEL
(
concat
,
GPU
,
ALL_LAYOUT
,
pten
::
ConcatKernel
,
float
,
double
,
bool
,
int64_t
,
int
,
uint8_t
,
paddle
::
platform
::
float16
,
paddle
::
platform
::
complex
<
float
>
,
paddle
::
platform
::
complex
<
double
>
)
{}
paddle/pten/tests/api/CMakeLists.txt
浏览文件 @
06803c29
...
...
@@ -21,3 +21,4 @@ cc_test(test_sum_api SRCS test_sum_api.cc DEPS pten_tensor pten_api pten_api_uti
cc_test
(
test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api_utils
)
cc_test
(
test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils
)
cc_test
(
test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils
)
cc_test
(
test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils
)
paddle/pten/tests/api/test_concat_api.cc
0 → 100644
浏览文件 @
06803c29
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace
paddle
{
namespace
tests
{
namespace
framework
=
paddle
::
framework
;
using
DDim
=
paddle
::
framework
::
DDim
;
// TODO(chentianyu03): Remove this test after the API is used in the dygraph
TEST
(
API
,
concat
)
{
// 1. create tensor
const
auto
alloc
=
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
paddle
::
platform
::
CPUPlace
());
auto
dense_x
=
std
::
make_shared
<
pten
::
DenseTensor
>
(
alloc
.
get
(),
pten
::
DenseTensorMeta
(
pten
::
DataType
::
FLOAT32
,
framework
::
make_ddim
({
3
,
10
}),
pten
::
DataLayout
::
NCHW
));
auto
*
dense_x_data
=
dense_x
->
mutable_data
<
float
>
();
auto
dense_y
=
std
::
make_shared
<
pten
::
DenseTensor
>
(
alloc
.
get
(),
pten
::
DenseTensorMeta
(
pten
::
DataType
::
FLOAT32
,
framework
::
make_ddim
({
3
,
10
}),
pten
::
DataLayout
::
NCHW
));
auto
*
dense_y_data
=
dense_y
->
mutable_data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
10
;
++
j
)
{
dense_x_data
[
i
*
10
+
j
]
=
(
i
*
10
+
j
)
*
1.0
;
dense_y_data
[
i
*
10
+
j
]
=
(
i
*
10
+
j
)
*
1.0
;
}
}
paddle
::
experimental
::
Tensor
x
(
dense_x
);
paddle
::
experimental
::
Tensor
y
(
dense_y
);
std
::
vector
<
paddle
::
experimental
::
Tensor
>
inputs
{
x
,
y
};
// 2. test API
auto
out
=
paddle
::
experimental
::
concat
(
inputs
,
0
);
// 3. check result
ASSERT_EQ
(
out
.
dims
().
size
(),
2
);
ASSERT_EQ
(
out
.
dims
()[
0
],
6
);
ASSERT_EQ
(
out
.
dims
()[
1
],
10
);
ASSERT_EQ
(
out
.
numel
(),
60
);
ASSERT_EQ
(
out
.
is_cpu
(),
true
);
ASSERT_EQ
(
out
.
type
(),
pten
::
DataType
::
FLOAT32
);
ASSERT_EQ
(
out
.
layout
(),
pten
::
DataLayout
::
NCHW
);
ASSERT_EQ
(
out
.
initialized
(),
true
);
auto
dense_out
=
std
::
dynamic_pointer_cast
<
pten
::
DenseTensor
>
(
out
.
impl
());
auto
out_data
=
dense_out
->
data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
60
;
++
i
)
{
if
(
i
<
30
)
{
ASSERT_NEAR
(
dense_x_data
[
i
],
out_data
[
i
],
1e-6
f
);
}
else
{
ASSERT_NEAR
(
dense_y_data
[
i
-
30
],
out_data
[
i
],
1e-6
f
);
}
}
}
}
// namespace tests
}
// namespace paddle
paddle/pten/tests/kernels/CMakeLists.txt
浏览文件 @
06803c29
...
...
@@ -10,3 +10,4 @@ cc_test(test_elementwise_dev_api SRCS test_elementwise_dev_api.cc DEPS pten pten
cc_test
(
test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_utils
)
cc_test
(
test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils
)
cc_test
(
test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils
)
cc_test
(
test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils
)
paddle/pten/tests/kernels/test_concat_dev_api.cc
0 → 100644
浏览文件 @
06803c29
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/kernels/concat_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace
pten
{
namespace
tests
{
namespace
framework
=
paddle
::
framework
;
using
DDim
=
paddle
::
framework
::
DDim
;
TEST
(
DEV_API
,
concat
)
{
// 1. create tensor
const
auto
alloc
=
std
::
make_unique
<
paddle
::
experimental
::
DefaultAllocator
>
(
paddle
::
platform
::
CPUPlace
());
pten
::
DenseTensor
dense_x
(
alloc
.
get
(),
pten
::
DenseTensorMeta
(
pten
::
DataType
::
FLOAT32
,
framework
::
make_ddim
({
3
,
10
}),
pten
::
DataLayout
::
NCHW
));
auto
*
dense_x_data
=
dense_x
.
mutable_data
<
float
>
();
pten
::
DenseTensor
dense_y
(
alloc
.
get
(),
pten
::
DenseTensorMeta
(
pten
::
DataType
::
FLOAT32
,
framework
::
make_ddim
({
3
,
10
}),
pten
::
DataLayout
::
NCHW
));
auto
*
dense_y_data
=
dense_y
.
mutable_data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
3
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
10
;
++
j
)
{
dense_x_data
[
i
*
10
+
j
]
=
(
i
*
10
+
j
)
*
1.0
;
dense_y_data
[
i
*
10
+
j
]
=
(
i
*
10
+
j
)
*
1.0
;
}
}
paddle
::
platform
::
DeviceContextPool
&
pool
=
paddle
::
platform
::
DeviceContextPool
::
Instance
();
auto
*
dev_ctx
=
pool
.
Get
(
paddle
::
platform
::
CPUPlace
());
std
::
vector
<
pten
::
DenseTensor
>
inputs
=
{
dense_x
,
dense_y
};
// 2. test API
auto
out
=
pten
::
Concat
<
float
>
(
*
(
static_cast
<
paddle
::
platform
::
CPUDeviceContext
*>
(
dev_ctx
)),
inputs
,
0
);
// 3. check result
ASSERT_EQ
(
out
.
dims
().
size
(),
2
);
ASSERT_EQ
(
out
.
dims
()[
0
],
6
);
ASSERT_EQ
(
out
.
dims
()[
1
],
10
);
ASSERT_EQ
(
out
.
meta
().
dtype
,
pten
::
DataType
::
FLOAT32
);
ASSERT_EQ
(
out
.
meta
().
layout
,
pten
::
DataLayout
::
NCHW
);
auto
out_data
=
out
.
data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
60
;
++
i
)
{
if
(
i
<
30
)
{
ASSERT_NEAR
(
dense_x_data
[
i
],
out_data
[
i
],
1e-6
f
);
}
else
{
ASSERT_NEAR
(
dense_y_data
[
i
-
30
],
out_data
[
i
],
1e-6
f
);
}
}
}
}
// namespace tests
}
// namespace pten
python/paddle/utils/code_gen/api.yaml
浏览文件 @
06803c29
...
...
@@ -18,6 +18,16 @@
param
:
[
x
,
out_dtype
]
data_type
:
x
-
api
:
concat
args
:
(const std::vector<Tensor>& x, const Scalar& axis)
output
:
Tensor
infer_meta
:
func
:
ConcatInferMeta
param
:
[
x
,
axis
,
true
]
kernel
:
func
:
concat
-
api
:
conj
args
:
(const Tensor& x)
output
:
Tensor
...
...
python/paddle/utils/code_gen/api_gen.py
浏览文件 @
06803c29
...
...
@@ -58,7 +58,10 @@ class API:
f
"Args declaration should start with '(' and end with ')', please check the args of
{
self
.
api
}
in api.yaml."
args_str
=
args_str
[
1
:
-
1
]
args_list
=
args_str
.
split
(
','
)
input_types
=
[
'const Tensor&'
,
'const Tensor &'
]
input_types
=
[
'const Tensor&'
,
'const Tensor &'
,
'const std::vector<Tensor>&'
,
'const std::vector<Tensor> &'
]
attr_types
=
[
'const Scalar&'
,
'const Scalar &'
,
'const ScalarArray&'
,
'const ScalarArray &'
,
\
'int'
,
'int32_t'
,
'int64_t'
,
'size_t'
,
'float'
,
'double'
,
'bool'
,
\
'const std::vector<int64_t>&'
,
'Backend'
,
'DataLayout'
,
'DataType'
]
...
...
@@ -247,7 +250,7 @@ PADDLE_API {self.output} {self.api}({self.args['args_declare']});
param_code
=
""
for
param
in
infer_meta_params
:
if
param
in
input_names
:
param_code
=
param_code
+
self
.
prefix_tensor_name
+
param
+
"->meta(
), "
param_code
=
param_code
+
"GetDenseTensorMeta("
+
self
.
prefix_tensor_name
+
param
+
"
), "
elif
param
in
attr_names
:
param_code
=
param_code
+
param
+
", "
elif
isinstance
(
param
,
str
):
...
...
@@ -267,7 +270,7 @@ PADDLE_API {self.output} {self.api}({self.args['args_declare']});
for
input_name
in
input_names
:
# set input code
input_tensor_code
=
input_tensor_code
+
f
"""
auto
{
self
.
prefix_tensor_name
}{
input_name
}
=
std::dynamic_pointer_cast<pten::DenseTensor>(
{
input_name
}
.impl()
);"""
auto
{
self
.
prefix_tensor_name
}{
input_name
}
=
TensorToDenseTensor(
{
input_name
}
);"""
attr_names
=
attrs
[
'names'
]
if
kernel_param
is
None
:
...
...
@@ -374,6 +377,35 @@ namespace experimental {
"""
)
def
tensor_to_densetensor
():
return
"""
std::shared_ptr<pten::DenseTensor> TensorToDenseTensor(const Tensor& tensor) {
return std::dynamic_pointer_cast<pten::DenseTensor>(tensor.impl());
}
std::shared_ptr<std::vector<pten::DenseTensor>> TensorToDenseTensor(const std::vector<Tensor>& tensors) {
std::vector<pten::DenseTensor> pt_tensors;
for(auto & t : tensors) {
pt_tensors.push_back(*std::dynamic_pointer_cast<pten::DenseTensor>(t.impl()));
}
return std::make_shared<std::vector<pten::DenseTensor>>(pt_tensors);
}
const pten::DenseTensorMeta GetDenseTensorMeta(const std::shared_ptr<pten::DenseTensor> & x) {
return x->meta();
}
const std::vector<pten::DenseTensorMeta> GetDenseTensorMeta(const std::shared_ptr<std::vector<pten::DenseTensor>>& x) {
std::vector<pten::DenseTensorMeta> metas;
for(auto& t : *x) {
metas.push_back(t.meta());
}
return metas;
}
"""
def
generate_api
(
api_yaml_path
,
header_file_path
,
source_file_path
):
with
open
(
api_yaml_path
,
'r'
)
as
f
:
...
...
@@ -390,6 +422,7 @@ def generate_api(api_yaml_path, header_file_path, source_file_path):
include_header_file
=
"paddle/pten/api/include/api.h"
source_file
.
write
(
source_include
(
include_header_file
))
source_file
.
write
(
namespace
[
0
])
source_file
.
write
(
tensor_to_densetensor
())
for
api
in
apis
:
api_code
=
API
(
api
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录