Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
14949521
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
14949521
编写于
5月 20, 2021
作者:
L
limingshu
提交者:
GitHub
5月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Binary functor envoking of elementwise broadcast (#32928)
上级
6f8de31d
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
193 addition
and
170 deletion
+193
-170
paddle/fluid/operators/abs_op.cu
paddle/fluid/operators/abs_op.cu
+3
-2
paddle/fluid/operators/activation_op.cu
paddle/fluid/operators/activation_op.cu
+5
-5
paddle/fluid/operators/elementwise/elementwise_add_op.cc
paddle/fluid/operators/elementwise/elementwise_add_op.cc
+0
-9
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+15
-21
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+13
-19
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
...fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+148
-105
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+2
-8
paddle/fluid/platform/fast_divmod.h
paddle/fluid/platform/fast_divmod.h
+7
-1
未找到文件。
paddle/fluid/operators/abs_op.cu
浏览文件 @
14949521
...
@@ -52,8 +52,9 @@ class AbsKernel<platform::CUDADeviceContext, T>
...
@@ -52,8 +52,9 @@ class AbsKernel<platform::CUDADeviceContext, T>
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
};
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
out
};
auto
functor
=
CudaAbsFunctor
<
T
>
();
auto
functor
=
CudaAbsFunctor
<
T
>
();
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
math
::
Real
<
T
>>
(
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
dev_ctx
,
ins
,
&
outs
,
functor
);
math
::
Real
<
T
>>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
}
}
};
};
...
...
paddle/fluid/operators/activation_op.cu
浏览文件 @
14949521
...
@@ -1316,8 +1316,8 @@ class ActivationCudaKernel
...
@@ -1316,8 +1316,8 @@ class ActivationCudaKernel
for
(
auto
&
attr
:
attrs
)
{
for
(
auto
&
attr
:
attrs
)
{
*
attr
.
second
=
ctx
.
Attr
<
float
>
(
attr
.
first
);
*
attr
.
second
=
ctx
.
Attr
<
float
>
(
attr
.
first
);
}
}
Launch
ElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
>
(
dev_ctx
,
ins
,
Launch
SameDimsElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
>
(
&
outs
,
functor
);
dev_ctx
,
ins
,
&
outs
,
functor
);
}
}
};
};
...
@@ -1346,16 +1346,16 @@ class ActivationGradCudaKernel
...
@@ -1346,16 +1346,16 @@ class ActivationGradCudaKernel
if
(
static_cast
<
int
>
(
Functor
::
FwdDeps
())
==
static_cast
<
int
>
(
kDepOut
))
{
if
(
static_cast
<
int
>
(
Functor
::
FwdDeps
())
==
static_cast
<
int
>
(
kDepOut
))
{
// Only need forward output Out
// Only need forward output Out
ins
.
push_back
(
out
);
ins
.
push_back
(
out
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
Launch
SameDims
ElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
dev_ctx
,
ins
,
&
outs
,
functor
);
}
else
if
(
static_cast
<
int
>
(
Functor
::
FwdDeps
())
==
}
else
if
(
static_cast
<
int
>
(
Functor
::
FwdDeps
())
==
static_cast
<
int
>
(
kDepX
))
{
static_cast
<
int
>
(
kDepX
))
{
// Only need forward input X
// Only need forward input X
ins
.
push_back
(
x
);
ins
.
push_back
(
x
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
Launch
SameDims
ElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
dev_ctx
,
ins
,
&
outs
,
functor
);
}
else
{
}
else
{
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
>
(
Launch
SameDims
ElementwiseCudaKernel
<
ElementwiseType
::
kUnary
,
T
,
T
>
(
dev_ctx
,
ins
,
&
outs
,
functor
);
dev_ctx
,
ins
,
&
outs
,
functor
);
}
}
}
}
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.cc
浏览文件 @
14949521
...
@@ -69,15 +69,6 @@ struct SameDimsElemwiseAdd<
...
@@ -69,15 +69,6 @@ struct SameDimsElemwiseAdd<
}
}
};
};
template
<
typename
T
>
struct
BroadcastElemwiseAdd
<
platform
::
CPUDeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
default_elementwise_add
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
};
class
ElementwiseAddOpMaker
:
public
ElementwiseOpMaker
{
class
ElementwiseAddOpMaker
:
public
ElementwiseOpMaker
{
protected:
protected:
std
::
string
GetName
()
const
override
{
return
"Add"
;
}
std
::
string
GetName
()
const
override
{
return
"Add"
;
}
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
14949521
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
...
@@ -13,7 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/float16.h"
...
@@ -40,29 +39,24 @@ struct CudaAddFunctor {
...
@@ -40,29 +39,24 @@ struct CudaAddFunctor {
};
};
template
<
typename
T
>
template
<
typename
T
>
struct
SameDimsElemwiseAdd
<
platform
::
CUDADeviceContext
,
T
>
{
class
ElementwiseAddKernel
<
platform
::
CUDADeviceContext
,
T
>
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
:
public
framework
::
OpKernel
<
T
>
{
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
public:
framework
::
Tensor
*
z
)
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
axis
=
axis
==
-
1
?
std
::
abs
(
x
->
dims
().
size
()
-
y
->
dims
().
size
())
:
axis
;
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
,
y
};
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
,
y
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
z
};
std
::
vector
<
framework
::
Tensor
*>
outs
=
{
z
};
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
ins
,
&
outs
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
CudaAddFunctor
<
T
>
());
}
};
template
<
typename
T
>
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
struct
BroadcastElemwiseAdd
<
platform
::
CUDADeviceContext
,
T
>
{
cuda_ctx
,
ins
,
&
outs
,
axis
,
CudaAddFunctor
<
T
>
());
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
out
)
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
=
{
x
,
y
};
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
axis
=
axis
==
-
1
?
std
::
abs
(
x
->
dims
().
size
()
-
y
->
dims
().
size
())
:
axis
;
LaunchBroadcastElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
>
(
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>(),
ins
,
out
,
CudaAddFunctor
<
T
>
(),
axis
);
}
}
};
};
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
14949521
...
@@ -26,7 +26,7 @@ limitations under the License. */
...
@@ -26,7 +26,7 @@ limitations under the License. */
#include <cuda.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_fp16.h>
#include "cub/cub.cuh"
#include "cub/cub.cuh"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#endif
#endif
#ifdef __HIPCC__
#ifdef __HIPCC__
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
...
@@ -40,9 +40,10 @@ namespace paddle {
...
@@ -40,9 +40,10 @@ namespace paddle {
namespace
operators
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
void
default_elementwise_add
(
const
framework
::
ExecutionContext
&
ctx
,
void
LaunchBroadcastElementwiseCpuKernel
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
x_dims
=
x
->
dims
();
auto
x_dims
=
x
->
dims
();
auto
y_dims
=
y
->
dims
();
auto
y_dims
=
y
->
dims
();
...
@@ -62,13 +63,6 @@ struct SameDimsElemwiseAdd {
...
@@ -62,13 +63,6 @@ struct SameDimsElemwiseAdd {
framework
::
Tensor
*
z
);
framework
::
Tensor
*
z
);
};
};
template
<
typename
DeviceContext
,
typename
T
,
class
Enable
=
void
>
struct
BroadcastElemwiseAdd
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
);
};
template
<
typename
DeviceContext
,
typename
T
>
template
<
typename
DeviceContext
,
typename
T
>
class
ElementwiseAddKernel
:
public
framework
::
OpKernel
<
T
>
{
class
ElementwiseAddKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
...
@@ -77,13 +71,13 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
...
@@ -77,13 +71,13 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dims_equal
=
x
->
dims
()
==
y
->
dims
();
if
(
x
->
dims
()
==
y
->
dims
())
{
if
(
dims_equal
)
{
SameDimsElemwiseAdd
<
platform
::
CPUDeviceContext
,
T
>
SameDimsElemwiseAdd
<
DeviceContext
,
T
>
same_dims_add
;
LaunchElementwiseCpuKernel
;
same_dims_add
(
ctx
,
x
,
y
,
z
);
LaunchElementwiseCpuKernel
(
ctx
,
x
,
y
,
z
);
}
else
{
}
else
{
BroadcastElemwiseAdd
<
DeviceContext
,
T
>
broadcast_add
;
LaunchBroadcastElementwiseCpuKernel
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
,
x
,
broadcast_add
(
ctx
,
x
,
y
,
z
);
y
,
z
);
}
}
}
}
};
};
...
@@ -469,8 +463,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
...
@@ -469,8 +463,8 @@ class ElementwiseAddDoubleGradKernel : public framework::OpKernel<T> {
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
GetDoubleGradSafeTensor
<
DeviceContext
,
T
>
(
ctx
,
y
,
ddy
,
&
ddy_safe
);
ddout
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ddout
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
default_elementwise_add
<
DeviceContext
,
T
>
(
ctx
,
&
ddx_safe
,
&
ddy
_safe
,
LaunchBroadcastElementwiseCpuKernel
<
DeviceContext
,
T
>
(
ctx
,
&
ddx
_safe
,
ddout
);
&
ddy_safe
,
ddout
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
浏览文件 @
14949521
...
@@ -14,7 +14,7 @@
...
@@ -14,7 +14,7 @@
#pragma once
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_
broadcast_
impl.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
...
@@ -28,7 +28,8 @@ struct DimensionsTransform {
...
@@ -28,7 +28,8 @@ struct DimensionsTransform {
std
::
vector
<
DimVector
>
in_dims
;
std
::
vector
<
DimVector
>
in_dims
;
private:
private:
// 1. To compensate the lackage of input_tensors` dimension;
// To compensate the lackage of input_tensors` dimension with input variable
// 'axis'
void
InputDimensionsExtend
(
int
N
,
int
axis
)
{
void
InputDimensionsExtend
(
int
N
,
int
axis
)
{
for
(
auto
&
in_dim
:
in_dims
)
{
for
(
auto
&
in_dim
:
in_dims
)
{
int64_t
in_idx
=
0
;
int64_t
in_idx
=
0
;
...
@@ -70,7 +71,7 @@ struct DimensionsTransform {
...
@@ -70,7 +71,7 @@ struct DimensionsTransform {
}
}
template
<
typename
MergeFunctor
>
template
<
typename
MergeFunctor
>
__inline__
void
DimensionsReorganise
(
MergeFunctor
merge_func
,
int
N
)
{
__inline__
void
MergeDimensions
(
MergeFunctor
merge_func
,
int
N
)
{
auto
VectorReorganise
=
[](
DimVector
*
vec
,
int
l_idx
,
int
m_idx
)
{
auto
VectorReorganise
=
[](
DimVector
*
vec
,
int
l_idx
,
int
m_idx
)
{
(
*
vec
)[
m_idx
-
1
]
=
(
*
vec
)[
m_idx
-
1
]
=
std
::
accumulate
(
vec
->
begin
()
+
l_idx
,
vec
->
begin
()
+
m_idx
,
1
,
std
::
accumulate
(
vec
->
begin
()
+
l_idx
,
vec
->
begin
()
+
m_idx
,
1
,
...
@@ -139,7 +140,7 @@ struct DimensionsTransform {
...
@@ -139,7 +140,7 @@ struct DimensionsTransform {
// To Merge the dimensions of input_tensors while the consequtive
// To Merge the dimensions of input_tensors while the consequtive
// equal-dimensions appears.
// equal-dimensions appears.
MergeFunctor
merge_ptr
=
merge_sequential_dims
;
MergeFunctor
merge_ptr
=
merge_sequential_dims
;
DimensionsReorganise
<
MergeFunctor
>
(
merge_ptr
,
N
);
MergeDimensions
<
MergeFunctor
>
(
merge_ptr
,
N
);
int
min_idx
=
0
;
int
min_idx
=
0
;
int
min_val
=
std
::
accumulate
(
in_dims
[
0
].
begin
(),
in_dims
[
0
].
end
(),
1
,
int
min_val
=
std
::
accumulate
(
in_dims
[
0
].
begin
(),
in_dims
[
0
].
end
(),
1
,
...
@@ -155,12 +156,12 @@ struct DimensionsTransform {
...
@@ -155,12 +156,12 @@ struct DimensionsTransform {
// To Merge the dimension of input_tensors while the consequtive
// To Merge the dimension of input_tensors while the consequtive
// 1-value-dimensions appears.
// 1-value-dimensions appears.
merge_ptr
=
merge_sequential_one_dims
;
merge_ptr
=
merge_sequential_one_dims
;
DimensionsReorganise
<
MergeFunctor
>
(
merge_ptr
,
N
);
MergeDimensions
<
MergeFunctor
>
(
merge_ptr
,
N
);
std
::
swap
(
in_dims
[
min_idx
],
in_dims
[
0
]);
std
::
swap
(
in_dims
[
min_idx
],
in_dims
[
0
]);
}
}
};
};
struct
CalculateInputStrides
{
struct
StridesCalculation
{
std
::
vector
<
std
::
vector
<
uint32_t
>>
strides
;
std
::
vector
<
std
::
vector
<
uint32_t
>>
strides
;
std
::
vector
<
FastDivMod
>
divmoders
;
std
::
vector
<
FastDivMod
>
divmoders
;
...
@@ -181,9 +182,9 @@ struct CalculateInputStrides {
...
@@ -181,9 +182,9 @@ struct CalculateInputStrides {
}
}
public:
public:
explicit
CalculateInputStrides
(
explicit
StridesCalculation
(
const
int64_t
&
dim_size
,
const
int64_t
&
dim_size
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
in_dims
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
in_dims
,
const
std
::
vector
<
int64_t
>
&
out_dims
)
{
const
std
::
vector
<
int64_t
>
&
out_dims
)
{
const
auto
N
=
in_dims
.
size
();
const
auto
N
=
in_dims
.
size
();
divmoders
.
resize
(
dim_size
);
divmoders
.
resize
(
dim_size
);
strides
.
resize
(
N
,
std
::
vector
<
uint32_t
>
(
dim_size
,
1
));
strides
.
resize
(
N
,
std
::
vector
<
uint32_t
>
(
dim_size
,
1
));
...
@@ -195,34 +196,40 @@ struct CalculateInputStrides {
...
@@ -195,34 +196,40 @@ struct CalculateInputStrides {
}
}
};
};
template
<
typename
T
,
ElementwiseType
ET
,
int
VecSize
,
int
kDims
>
template
<
typename
T
,
typename
Functor
,
ElementwiseType
ET
,
int
VecSize
,
int
kDims
>
struct
BroadcastArgsWarpper
{
struct
BroadcastArgsWarpper
{
using
DimsVec
=
CudaAlignedVector
<
T
,
VecSize
>
;
using
VecType
=
CudaAlignedVector
<
T
,
VecSize
>
;
T
*
out_data
;
T
*
out_data
;
VecType
*
vec_out_data
;
const
T
*
__restrict__
in_data
[
ET
];
const
T
*
__restrict__
in_data
[
ET
];
uint32_t
strides
[
ET
][
framework
::
DDim
::
kMaxRank
];
const
VecType
*
__restrict__
vec_in_data
[
ET
];
bool
no_broadcast
[
ET
];
bool
no_broadcast
[
ET
];
FastDivMod
divmoders
[
kDims
];
FastDivMod
divmoders
[
kDims
];
uint32_t
scalar_offset
;
uint32_t
strides
[
ET
][
framework
::
DDim
::
kMaxRank
];
uint32_t
scalar_cal_offset
;
Functor
func
;
HOSTDEVICE
BroadcastArgsWarpper
(
HOSTDEVICE
BroadcastArgsWarpper
(
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
framework
::
Tensor
*
out
,
const
CalculateInputStrides
&
offset_calculator
,
framework
::
Tensor
*
out
,
int
scalar_cal_offset
,
Functor
func
,
int
scalar_offset
)
const
StridesCalculation
&
offset_calculator
)
:
scalar_
offset
(
scalar_offset
)
{
:
scalar_
cal_offset
(
scalar_cal_offset
),
func
(
func
)
{
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
in_data
[
j
]
=
ins
[
j
]
->
data
<
T
>
();
in_data
[
j
]
=
ins
[
j
]
->
data
<
T
>
();
vec_in_data
[
j
]
=
reinterpret_cast
<
const
VecType
*>
(
in_data
[
j
]);
no_broadcast
[
j
]
=
ins
[
j
]
->
dims
()
==
out
->
dims
()
?
true
:
false
;
no_broadcast
[
j
]
=
ins
[
j
]
->
dims
()
==
out
->
dims
()
?
true
:
false
;
memcpy
(
strides
[
j
],
offset_calculator
.
strides
[
j
].
data
(),
memcpy
(
strides
[
j
],
offset_calculator
.
strides
[
j
].
data
(),
kDims
*
sizeof
(
uint32_t
));
kDims
*
sizeof
(
uint32_t
));
}
}
out_data
=
out
->
data
<
T
>
();
out_data
=
out
->
data
<
T
>
();
vec_out_data
=
reinterpret_cast
<
VecType
*>
(
out_data
);
memcpy
(
divmoders
,
offset_calculator
.
divmoders
.
data
(),
memcpy
(
divmoders
,
offset_calculator
.
divmoders
.
data
(),
kDims
*
sizeof
(
FastDivMod
));
kDims
*
sizeof
(
FastDivMod
));
}
}
__device__
__forceinline__
uint32_t
Get
DivmodOffset
(
int
idx
,
int
in_idx
)
{
__device__
__forceinline__
uint32_t
Get
OffsetByDivmod
(
int
idx
,
int
in_idx
)
{
uint32_t
offset
=
0
;
uint32_t
offset
=
0
;
#pragma unroll(kDims)
#pragma unroll(kDims)
...
@@ -234,120 +241,127 @@ struct BroadcastArgsWarpper {
...
@@ -234,120 +241,127 @@ struct BroadcastArgsWarpper {
return
offset
;
return
offset
;
}
}
__device__
__forceinline__
void
CommonVector
(
DimsVec
args
[],
int
tid
,
__device__
__forceinline__
void
LoadVectorizedDataCommon
(
VecType
*
vector_args
,
int
idx
)
{
int
tid
,
int
idx
)
{
const
DimsVec
*
__restrict__
vec_data
=
*
vector_args
=
vec_in_data
[
idx
][
tid
];
reinterpret_cast
<
const
DimsVec
*
__restrict__
>
(
in_data
[
idx
]);
args
[
idx
]
=
vec_data
[
tid
];
}
}
__device__
__forceinline__
void
DivmodVector
(
DimsVec
args
[],
int
tid
,
__device__
__forceinline__
void
LoadVectorizedDataByDivmod
(
T
*
scalar_args
,
int
idx
)
{
int
tid
,
int
idx
)
{
int
index
=
tid
*
VecSize
;
int
index
=
tid
*
VecSize
;
#pragma unroll(VecSize)
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
uint32_t
offset
=
Get
DivmodOffset
(
index
+
i
,
idx
);
uint32_t
offset
=
Get
OffsetByDivmod
(
index
+
i
,
idx
);
args
[
idx
].
val
[
i
]
=
in_data
[
idx
][
offset
];
scalar_args
[
i
]
=
in_data
[
idx
][
offset
];
}
}
}
}
__device__
__forceinline__
void
CommonScalar
(
T
args
[],
int
tid
,
int
idx
)
{
__device__
__forceinline__
void
LoadScalarizedDataCommon
(
T
args
[],
int
tid
,
args
[
idx
]
=
in_data
[
idx
][
tid
+
scalar_offset
];
int
idx
)
{
args
[
idx
]
=
in_data
[
idx
][
tid
+
scalar_cal_offset
];
}
}
__device__
__forceinline__
void
DivmodScalar
(
T
args
[],
int
tid
,
int
idx
)
{
__device__
__forceinline__
void
LoadScalarizedDataByDivmod
(
T
args
[],
int
tid
,
auto
offset
=
GetDivmodOffset
(
tid
+
scalar_offset
,
idx
);
int
idx
)
{
auto
offset
=
GetOffsetByDivmod
(
tid
+
scalar_cal_offset
,
idx
);
args
[
idx
]
=
in_data
[
idx
][
offset
];
args
[
idx
]
=
in_data
[
idx
][
offset
];
}
}
__device__
__forceinline__
void
LoadVector
(
DimsVec
args
[],
int
tid
)
{
__device__
__forceinline__
void
LoadVectorizedData
(
T
(
*
args
)[
VecSize
],
int
tid
)
{
#pragma unroll(ET)
#pragma unroll(ET)
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
if
(
no_broadcast
[
j
])
{
if
(
no_broadcast
[
j
])
{
CommonVector
(
args
,
tid
,
j
);
VecType
*
vector_args
=
reinterpret_cast
<
VecType
*>
(
args
[
j
]);
LoadVectorizedDataCommon
(
vector_args
,
tid
,
j
);
}
else
{
}
else
{
DivmodVector
(
args
,
tid
,
j
);
LoadVectorizedDataByDivmod
(
args
[
j
]
,
tid
,
j
);
}
}
}
}
}
}
__device__
__forceinline__
void
LoadScalar
(
T
args
[],
int
tid
)
{
__device__
__forceinline__
void
LoadScalar
izedData
(
T
args
[],
int
tid
)
{
#pragma unroll(ET)
#pragma unroll(ET)
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
if
(
no_broadcast
[
j
])
{
if
(
no_broadcast
[
j
])
{
CommonScalar
(
args
,
tid
,
j
);
LoadScalarizedDataCommon
(
args
,
tid
,
j
);
}
else
{
}
else
{
DivmodScalar
(
args
,
tid
,
j
);
LoadScalarizedDataByDivmod
(
args
,
tid
,
j
);
}
}
}
}
}
}
__device__
__forceinline__
void
StoreVector
(
DimsVec
args
[],
int
tid
)
{
__device__
__forceinline__
void
StoreVectorizedData
(
T
(
*
args
)[
VecSize
],
DimsVec
*
vec_out
=
reinterpret_cast
<
DimsVec
*>
(
out_data
);
int
tid
)
{
vec_out
[
tid
]
=
args
[
0
];
VecType
*
args_out
=
reinterpret_cast
<
VecType
*>
(
args
[
0
]);
vec_out_data
[
tid
]
=
*
args_out
;
}
}
__device__
__forceinline__
void
StoreScalar
(
T
args
[],
int
tid
)
{
__device__
__forceinline__
void
StoreScalar
izedData
(
T
args
[],
int
tid
)
{
out_data
[
scalar_offset
+
tid
]
=
args
[
0
];
out_data
[
scalar_
cal_
offset
+
tid
]
=
args
[
0
];
}
}
};
};
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
>
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
>
__device__
inline
void
ScalarizedBroadcastKernelImpl
(
__device__
inline
void
ScalarizedBroadcastKernelImpl
(
BroadcastArgsWarpper
data_transf
er
,
int
tid
)
{
BroadcastArgsWarpper
broadcast_warpp
er
,
int
tid
)
{
T
args
[
ET
];
T
args
[
ET
];
data_transfer
.
LoadScalar
(
args
,
tid
);
broadcast_warpper
.
LoadScalarizedData
(
args
,
tid
);
#pragma unroll(ET)
#pragma unroll(ET)
for
(
int
j
=
1
;
j
<
ET
;
++
j
)
{
for
(
int
j
=
1
;
j
<
ET
;
++
j
)
{
args
[
0
]
+=
args
[
j
]
;
args
[
0
]
=
broadcast_warpper
.
func
(
args
)
;
}
}
data_transfer
.
StoreScalar
(
args
,
tid
);
broadcast_warpper
.
StoreScalarizedData
(
args
,
tid
);
}
}
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
int
VecSize
>
int
VecSize
>
__device__
inline
void
VectorizedBroadcastKernelImpl
(
__device__
inline
void
VectorizedBroadcastKernelImpl
(
BroadcastArgsWarpper
data_transf
er
,
int
tid
)
{
BroadcastArgsWarpper
broadcast_warpp
er
,
int
tid
)
{
using
VecT
=
CudaAlignedVector
<
T
,
VecSize
>
;
T
ins
[
ET
]
;
VecT
args
[
ET
];
T
args
[
ET
][
VecSize
];
data_transfer
.
LoadVector
(
args
,
tid
);
broadcast_warpper
.
LoadVectorizedData
(
args
,
tid
);
#pragma unroll(ET)
for
(
int
j
=
1
;
j
<
ET
;
++
j
)
{
#pragma unroll(VecSize)
#pragma unroll(VecSize)
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
for
(
int
i
=
0
;
i
<
VecSize
;
++
i
)
{
args
[
0
].
val
[
i
]
+=
args
[
j
].
val
[
i
];
#pragma unroll(ET)
for
(
int
j
=
0
;
j
<
ET
;
++
j
)
{
ins
[
j
]
=
args
[
j
][
i
];
}
}
args
[
0
][
i
]
=
broadcast_warpper
.
func
(
ins
);
}
}
data_transfer
.
StoreVector
(
args
,
tid
);
broadcast_warpper
.
StoreVectorizedData
(
args
,
tid
);
}
}
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
template
<
typename
T
,
typename
BroadcastArgsWarpper
,
ElementwiseType
ET
,
int
VecSize
>
int
VecSize
>
__global__
void
ElementwiseBroadcastKernel
(
BroadcastArgsWarpper
data_transfer
,
__global__
void
ElementwiseBroadcastKernel
(
int
main_tid
,
int
tail_tid
)
{
BroadcastArgsWarpper
broadcast_warpper
,
int
main_tid
,
int
tail_tid
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
// Aimming at vectorized calculation of major data whose length is max
// Vectorized calculation of major data whose length is the max multipler of
// multipler of VecSize.
// VecSize,
// eg: Calcualting the front 1024-length data in total 1027 data once VecSize
// is 4.
if
(
tid
<
main_tid
)
{
if
(
tid
<
main_tid
)
{
VectorizedBroadcastKernelImpl
<
T
,
BroadcastArgsWarpper
,
ET
,
VecSize
>
(
VectorizedBroadcastKernelImpl
<
T
,
BroadcastArgsWarpper
,
ET
,
VecSize
>
(
data_transf
er
,
tid
);
broadcast_warpp
er
,
tid
);
}
}
// Aimming at scalar calculation of rest data whose lenght cannot fulfill
// Scalarzed calculation of rest data whose lenght cannot fulfill VecSize.
// VecSize.
// eg: Calcualting the rest 3-length data in total 1027 data once VecSize is
// 4.
if
(
tid
<
tail_tid
)
{
if
(
tid
<
tail_tid
)
{
ScalarizedBroadcastKernelImpl
<
T
,
BroadcastArgsWarpper
,
ET
>
(
data_transfer
,
ScalarizedBroadcastKernelImpl
<
T
,
BroadcastArgsWarpper
,
ET
>
(
tid
);
broadcast_warpper
,
tid
);
}
}
}
}
template
<
typename
T
,
ElementwiseType
ET
,
int
VecSize
=
1
>
template
<
typename
T
,
ElementwiseType
ET
,
int
VecSize
,
typename
Functor
>
void
LaunchBroadcastKernelForDifferentDimSize
(
void
LaunchBroadcastKernelForDifferentDimSize
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
framework
::
Tensor
*
out
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
framework
::
Tensor
*
out
,
int
axis
)
{
int
axis
,
Functor
func
)
{
int
numel
=
out
->
numel
();
int
numel
=
out
->
numel
();
const
int
threads
=
256
;
const
int
threads
=
256
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
int
blocks
=
((
numel
+
VecSize
-
1
)
/
VecSize
+
threads
-
1
)
/
threads
;
...
@@ -357,72 +371,72 @@ void LaunchBroadcastKernelForDifferentDimSize(
...
@@ -357,72 +371,72 @@ void LaunchBroadcastKernelForDifferentDimSize(
auto
stream
=
ctx
.
stream
();
auto
stream
=
ctx
.
stream
();
const
auto
merge_dims
=
DimensionsTransform
(
ins
,
out
->
dims
(),
axis
);
const
auto
merge_dims
=
DimensionsTransform
(
ins
,
out
->
dims
(),
axis
);
const
auto
offset_calculator
=
CalculateInputStrides
(
const
auto
offset_calculator
=
StridesCalculation
(
merge_dims
.
dim_size
,
merge_dims
.
in_dims
,
merge_dims
.
out_dims
);
merge_dims
.
dim_size
,
merge_dims
.
in_dims
,
merge_dims
.
out_dims
);
switch
(
merge_dims
.
dim_size
)
{
switch
(
merge_dims
.
dim_size
)
{
case
1
:
{
case
1
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
1
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
1
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
2
:
{
case
2
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
2
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
2
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
3
:
{
case
3
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
3
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
3
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
4
:
{
case
4
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
4
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
4
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
5
:
{
case
5
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
5
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
5
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
6
:
{
case
6
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
6
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
6
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
7
:
{
case
7
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
7
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
7
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
case
8
:
{
case
8
:
{
auto
data_transfer
=
BroadcastArgsWarpper
<
T
,
ET
,
VecSize
,
8
>
(
auto
broadcast_warpper
=
BroadcastArgsWarpper
<
T
,
Functor
,
ET
,
VecSize
,
8
>
(
ins
,
o
ffset_calculator
,
out
,
vec_len
);
ins
,
o
ut
,
vec_len
,
func
,
offset_calculator
);
ElementwiseBroadcastKernel
<
T
,
decltype
(
data_transf
er
),
ET
,
ElementwiseBroadcastKernel
<
T
,
decltype
(
broadcast_warpp
er
),
ET
,
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
VecSize
><<<
blocks
,
threads
,
0
,
stream
>>>
(
data_transf
er
,
main_tid
,
tail_tid
);
broadcast_warpp
er
,
main_tid
,
tail_tid
);
break
;
break
;
}
}
default:
{
default:
{
...
@@ -437,9 +451,11 @@ void LaunchBroadcastKernelForDifferentDimSize(
...
@@ -437,9 +451,11 @@ void LaunchBroadcastKernelForDifferentDimSize(
template
<
ElementwiseType
ET
,
typename
T
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
T
,
typename
Functor
>
void
LaunchBroadcastElementwiseCudaKernel
(
void
LaunchBroadcastElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
framework
::
Tensor
*
out
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
Functor
func
,
int
axis
)
{
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
static_assert
(
ET
==
(
ElementwiseType
)
2
,
"Only Support binary calculation."
);
int
in_vec_size
=
4
;
int
in_vec_size
=
4
;
framework
::
Tensor
*
out
=
(
*
outs
)[
0
];
for
(
auto
*
in
:
ins
)
{
for
(
auto
*
in
:
ins
)
{
auto
temp_size
=
GetVectorizedSizeImpl
<
T
>
(
in
->
data
<
T
>
());
auto
temp_size
=
GetVectorizedSizeImpl
<
T
>
(
in
->
data
<
T
>
());
in_vec_size
=
in
->
dims
()
==
out
->
dims
()
?
std
::
min
(
temp_size
,
in_vec_size
)
in_vec_size
=
in
->
dims
()
==
out
->
dims
()
?
std
::
min
(
temp_size
,
in_vec_size
)
...
@@ -450,19 +466,46 @@ void LaunchBroadcastElementwiseCudaKernel(
...
@@ -450,19 +466,46 @@ void LaunchBroadcastElementwiseCudaKernel(
switch
(
vec_size
)
{
switch
(
vec_size
)
{
case
4
:
{
case
4
:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
4
>
(
ctx
,
ins
,
out
,
axis
);
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
4
>
(
ctx
,
ins
,
out
,
axis
,
func
);
break
;
break
;
}
}
case
2
:
{
case
2
:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
2
>
(
ctx
,
ins
,
out
,
axis
);
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
2
>
(
ctx
,
ins
,
out
,
axis
,
func
);
break
;
}
case
1
:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
1
>
(
ctx
,
ins
,
out
,
axis
,
func
);
break
;
break
;
}
}
default:
{
default:
{
LaunchBroadcastKernelForDifferentDimSize
<
T
,
ET
,
1
>
(
ctx
,
ins
,
out
,
axis
);
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Unsupported vectorized size: %d !"
,
vec_size
));
break
;
break
;
}
}
}
}
}
}
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutType
,
typename
Functor
>
void
LaunchElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
cuda_ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
bool
no_broadcast_flag
=
true
;
for
(
auto
*
in
:
ins
)
{
no_broadcast_flag
=
ins
[
0
]
->
dims
()
==
in
->
dims
();
}
if
(
no_broadcast_flag
)
{
LaunchSameDimsElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
InT
,
OutType
>
(
cuda_ctx
,
ins
,
outs
,
func
);
}
else
{
LaunchBroadcastElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
InT
>
(
cuda_ctx
,
ins
,
outs
,
axis
,
func
);
}
}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
浏览文件 @
14949521
...
@@ -15,8 +15,7 @@ limitations under the License. */
...
@@ -15,8 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/fast_divmod.h"
#include "paddle/fluid/platform/float16.h"
#ifdef __HIPCC__
#ifdef __HIPCC__
#define ELEMENTWISE_BLOCK_SIZE 256
#define ELEMENTWISE_BLOCK_SIZE 256
...
@@ -29,11 +28,6 @@ namespace operators {
...
@@ -29,11 +28,6 @@ namespace operators {
enum
ElementwiseType
{
kUnary
=
1
,
kBinary
=
2
};
enum
ElementwiseType
{
kUnary
=
1
,
kBinary
=
2
};
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
CudaAlignedVector
{
T
val
[
Size
];
};
template
<
typename
T
>
template
<
typename
T
>
int
GetVectorizedSizeImpl
(
const
T
*
pointer
)
{
int
GetVectorizedSizeImpl
(
const
T
*
pointer
)
{
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
uint64_t
address
=
reinterpret_cast
<
uint64_t
>
(
pointer
);
...
@@ -181,7 +175,7 @@ __global__ void ScalarKernel(const InT *__restrict__ in0,
...
@@ -181,7 +175,7 @@ __global__ void ScalarKernel(const InT *__restrict__ in0,
}
}
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
void
LaunchElementwiseCudaKernel
(
void
Launch
SameDims
ElementwiseCudaKernel
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
Functor
func
)
{
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
Functor
func
)
{
...
...
paddle/fluid/
operators/elementwise/elementwise_op_broadcast_impl.cu
.h
→
paddle/fluid/
platform/fast_divmod
.h
浏览文件 @
14949521
...
@@ -14,13 +14,19 @@ limitations under the License. */
...
@@ -14,13 +14,19 @@ limitations under the License. */
#pragma once
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include <cstdint>
#include "paddle/fluid/platform/hostdevice.h"
#define INT_BITS 32
#define INT_BITS 32
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
template
<
typename
T
,
int
Size
>
struct
alignas
(
sizeof
(
T
)
*
Size
)
CudaAlignedVector
{
T
val
[
Size
];
};
struct
FastDivMod
{
struct
FastDivMod
{
// 1st value represents the result of input number divides by recorded divisor
// 1st value represents the result of input number divides by recorded divisor
// 2nd value represents the result of input number modulo by recorded divisor
// 2nd value represents the result of input number modulo by recorded divisor
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录