Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
b432d024
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
b432d024
编写于
6月 02, 2021
作者:
L
limingshu
提交者:
GitHub
6月 02, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Support Add Sub Mul Max Min Pow binary functors in elementwise system (#33050)
上级
9c52adef
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
231 addition
and
112 deletion
+231
-112
paddle/fluid/operators/controlflow/compare_op.cu
paddle/fluid/operators/controlflow/compare_op.cu
+23
-24
paddle/fluid/operators/elementwise/elementwise_add_op.cu
paddle/fluid/operators/elementwise/elementwise_add_op.cu
+7
-4
paddle/fluid/operators/elementwise/elementwise_add_op.h
paddle/fluid/operators/elementwise/elementwise_add_op.h
+2
-4
paddle/fluid/operators/elementwise/elementwise_max_op.cu
paddle/fluid/operators/elementwise/elementwise_max_op.cu
+31
-0
paddle/fluid/operators/elementwise/elementwise_min_op.cu
paddle/fluid/operators/elementwise/elementwise_min_op.cu
+31
-0
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
+57
-28
paddle/fluid/operators/elementwise/elementwise_mul_op.h
paddle/fluid/operators/elementwise/elementwise_mul_op.h
+0
-1
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
...fluid/operators/elementwise/elementwise_op_broadcast.cu.h
+10
-14
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+10
-6
paddle/fluid/operators/elementwise/elementwise_pow_op.cu
paddle/fluid/operators/elementwise/elementwise_pow_op.cu
+42
-0
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
+17
-30
paddle/fluid/operators/elementwise/elementwise_sub_op.h
paddle/fluid/operators/elementwise/elementwise_sub_op.h
+1
-1
未找到文件。
paddle/fluid/operators/controlflow/compare_op.cu
浏览文件 @
b432d024
...
...
@@ -21,21 +21,21 @@ namespace plat = paddle::platform;
namespace
paddle
{
namespace
operators
{
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(
F
unc, op) \
#define DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT(
f
unc, op) \
template <typename T, typename Enable = void> \
struct
Func##Functor {
\
struct
func {
\
using ELEMENT_TYPE = T; \
inline HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
};
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaLessThan
,
<
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaLessEqual
,
<=
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaGreaterThan
,
>
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaGreaterEqual
,
>=
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaEqual
,
==
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaNotEqual
,
!=
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaLessThan
Functor
,
<
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaLessEqual
Functor
,
<=
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaGreaterThan
Functor
,
>
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaGreaterEqual
Functor
,
>=
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaEqual
Functor
,
==
)
DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
(
CudaNotEqual
Functor
,
!=
)
#undef DEFINE_CMP_BINARY_FUNCTOR_WITH_PONTER_INPUT
template
<
typename
T
>
...
...
@@ -67,10 +67,12 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
auto
functor
=
Functor
();
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
PackTensorsIntoVector
<
OutT
>
(
ctx
,
&
ins
,
&
outs
);
int
axis
=
PackTensorsIntoVector
<
OutT
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
InT
,
OutT
>
(
c
tx
,
ins
,
&
out
s
,
functor
);
c
uda_ctx
,
ins
,
&
outs
,
axi
s
,
functor
);
}
};
...
...
@@ -79,19 +81,16 @@ class CompareOpKernel<platform::CUDADeviceContext, Functor, InverseFunctor>
#define REGISTER_CUDA_COMPARE_KERNEL(op_type, func) \
REGISTER_OP_CUDA_KERNEL( \
op_type, ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func##Functor<float>, \
void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, \
ops::func##Functor<double>, void>);
op_type, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<int64_t>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<float>, void>, \
ops::CompareOpKernel<plat::CUDADeviceContext, ops::func<double>, void>);
REGISTER_CUDA_COMPARE_KERNEL
(
equal
,
CudaEqual
)
REGISTER_CUDA_COMPARE_KERNEL
(
not_equal
,
CudaNotEqual
)
REGISTER_CUDA_COMPARE_KERNEL
(
less_than
,
CudaLessThan
)
REGISTER_CUDA_COMPARE_KERNEL
(
less_equal
,
CudaLessEqual
)
REGISTER_CUDA_COMPARE_KERNEL
(
greater_than
,
CudaGreaterThan
)
REGISTER_CUDA_COMPARE_KERNEL
(
greater_equal
,
CudaGreaterEqual
)
REGISTER_CUDA_COMPARE_KERNEL
(
equal
,
CudaEqual
Functor
)
REGISTER_CUDA_COMPARE_KERNEL
(
not_equal
,
CudaNotEqual
Functor
)
REGISTER_CUDA_COMPARE_KERNEL
(
less_than
,
CudaLessThan
Functor
)
REGISTER_CUDA_COMPARE_KERNEL
(
less_equal
,
CudaLessEqual
Functor
)
REGISTER_CUDA_COMPARE_KERNEL
(
greater_than
,
CudaGreaterThan
Functor
)
REGISTER_CUDA_COMPARE_KERNEL
(
greater_equal
,
CudaGreaterEqual
Functor
)
#undef REGISTER_CUDA_COMPARE_KERNEL
paddle/fluid/operators/elementwise/elementwise_add_op.cu
浏览文件 @
b432d024
...
...
@@ -28,11 +28,11 @@ namespace operators {
1. For Unary Op, the length of input array is 1,
e.g. Relu: return args[0] > 0 ? args[0] : 0;
2. For Binary Op, the length of input array is 2,
e.g. Add: return args[0]
+
args[1];
e.g. Add: return args[0]
expr
args[1];
*/
template
<
typename
T
>
struct
CudaAddFunctor
{
__device__
__forceinline__
T
operator
()(
const
T
*
args
)
const
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
+
args
[
1
];
}
};
...
...
@@ -44,9 +44,12 @@ class ElementwiseAddKernel<platform::CUDADeviceContext, T>
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
c
tx
,
ins
,
&
out
s
,
CudaAddFunctor
<
T
>
());
c
uda_ctx
,
ins
,
&
outs
,
axi
s
,
CudaAddFunctor
<
T
>
());
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_add_op.h
浏览文件 @
b432d024
...
...
@@ -72,12 +72,10 @@ class ElementwiseAddKernel : public framework::OpKernel<T> {
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
x
->
dims
()
==
y
->
dims
())
{
SameDimsElemwiseAdd
<
platform
::
CPUDeviceContext
,
T
>
LaunchElementwiseCpuKernel
;
SameDimsElemwiseAdd
<
DeviceContext
,
T
>
LaunchElementwiseCpuKernel
;
LaunchElementwiseCpuKernel
(
ctx
,
x
,
y
,
z
);
}
else
{
LaunchBroadcastElementwiseCpuKernel
<
platform
::
CPUDeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
LaunchBroadcastElementwiseCpuKernel
<
DeviceContext
,
T
>
(
ctx
,
x
,
y
,
z
);
}
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_max_op.cu
浏览文件 @
b432d024
...
...
@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_max_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
namespace
ops
=
paddle
::
operators
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
CudaMaxFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
(
args
[
0
]
>
args
[
1
]
?
args
[
0
]
:
args
[
1
]);
}
};
template
<
typename
T
>
class
ElementwiseMaxKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
CudaMaxFunctor
<
T
>
());
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
elementwise_max
,
ops
::
ElementwiseMaxKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_min_op.cu
浏览文件 @
b432d024
...
...
@@ -12,9 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_min_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
namespace
ops
=
paddle
::
operators
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
struct
CudaMinFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
(
args
[
0
]
>
args
[
1
]
?
args
[
1
]
:
args
[
0
]);
}
};
template
<
typename
T
>
class
ElementwiseMinKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
CudaMinFunctor
<
T
>
());
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
elementwise_min
,
ops
::
ElementwiseMinKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.cu
浏览文件 @
b432d024
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -24,37 +25,65 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseMul
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
MulRangeFunctor
<
T
>
functor
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
x
->
numel
());
for_range
(
functor
);
struct
CudaMulFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
*
args
[
1
];
}
};
template
<
>
struct
SameDimsElemwiseMul
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
(((
size
+
7
)
/
8
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
reinterpret_cast
<
const
half
*>
(
y
->
data
<
platform
::
float16
>
());
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseMulCUDAKernel
<<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
template
<
typename
T
>
class
ElementwiseMulKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
int
axis
=
-
1
;
auto
x_var
=
ctx
.
InputVar
(
"X"
);
PADDLE_ENFORCE_NOT_NULL
(
x_var
,
platform
::
errors
::
InvalidArgument
(
"Cannot get input Variable X, Variable name = %s."
,
ctx
.
InputName
(
"X"
)));
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
framework
::
Tensor
x
,
*
z
;
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
if
(
x_var
->
IsType
<
framework
::
LoDTensor
>
())
{
x
=
x_var
->
Get
<
framework
::
LoDTensor
>
();
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
}
else
if
(
x_var
->
IsType
<
framework
::
SelectedRows
>
())
{
PADDLE_ENFORCE_EQ
(
y
->
dims
().
size
()
==
1
&&
y
->
dims
()[
0
]
==
1
,
true
,
platform
::
errors
::
InvalidArgument
(
"For elementwise_op, if X is Sparse, Y must be "
"scalar. But reveived the size of Y = %s."
,
y
->
dims
().
size
()));
auto
&
x_sele
=
x_var
->
Get
<
framework
::
SelectedRows
>
();
auto
out_sele
=
ctx
.
Output
<
framework
::
SelectedRows
>
(
"Out"
);
x
=
x_sele
.
value
();
out_sele
->
set_rows
(
x_sele
.
rows
());
out_sele
->
set_height
(
x_sele
.
height
());
out_sele
->
mutable_value
()
->
Resize
(
x_sele
.
value
().
dims
());
out_sele
->
mutable_value
()
->
mutable_data
(
ctx
.
GetPlace
(),
x
.
type
());
z
=
ctx
.
Output
<
framework
::
SelectedRows
>
(
"Out"
)
->
mutable_value
();
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
outs
.
emplace_back
(
z
);
ins
.
emplace_back
(
&
x
);
ins
.
emplace_back
(
y
);
axis
=
ctx
.
HasAttr
(
"axis"
)
?
ctx
.
Attr
<
int
>
(
"axis"
)
:
-
1
;
axis
=
axis
==
-
1
?
std
::
abs
(
y
->
dims
().
size
()
-
x
.
dims
().
size
())
:
axis
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"X's type[%s] is not supported by elementwise_op. X's type should be "
"LoDTensor or SelectedRows."
,
framework
::
ToTypeName
(
x_var
->
Type
())));
}
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
CudaMulFunctor
<
T
>
());
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_mul_op.h
浏览文件 @
b432d024
...
...
@@ -126,7 +126,6 @@ class ElementwiseMulKernel : public framework::OpKernel<T> {
}
}
};
template
<
typename
T
>
struct
MulGradDX
{
HOSTDEVICE
T
operator
()(
T
x
,
T
y
,
T
out
,
T
dout
)
const
{
return
dout
*
y
;
}
...
...
paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h
浏览文件 @
b432d024
...
...
@@ -465,7 +465,11 @@ void LaunchBroadcastElementwiseCudaKernel(
const
platform
::
CUDADeviceContext
&
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
static_assert
(
ET
==
(
ElementwiseType
)
2
,
"Only Support binary calculation."
);
PADDLE_ENFORCE_EQ
(
ET
,
ElementwiseType
::
kBinary
,
platform
::
errors
::
InvalidArgument
(
"Currently, only Support binary calculation, "
"but received %d input tensors.
\n
"
,
static_cast
<
int
>
(
ET
)));
int
in_vec_size
=
4
;
framework
::
Tensor
*
out
=
(
*
outs
)[
0
];
for
(
auto
*
in
:
ins
)
{
...
...
@@ -502,26 +506,18 @@ void LaunchBroadcastElementwiseCudaKernel(
template
<
ElementwiseType
ET
,
typename
InT
,
typename
OutT
,
typename
Functor
>
void
LaunchElementwiseCudaKernel
(
const
framework
::
ExecutionContext
&
ctx
,
const
platform
::
CUDADeviceContext
&
cuda_
ctx
,
const
std
::
vector
<
const
framework
::
Tensor
*>
&
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
Functor
func
)
{
std
::
vector
<
int
>
dims_size
;
std
::
vector
<
framework
::
Tensor
*>
*
outs
,
int
axis
,
Functor
func
)
{
bool
no_broadcast_flag
=
true
;
for
(
auto
*
in
:
ins
)
{
no_broadcast_flag
=
ins
[
0
]
->
dims
()
==
in
->
dims
();
dims_size
.
emplace_back
(
in
->
dims
().
size
());
}
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
if
(
no_broadcast_flag
)
{
LaunchSameDimsElementwiseCudaKernel
<
E
lementwiseType
::
kBinary
,
InT
,
OutT
>
(
cuda_ctx
,
ins
,
outs
,
func
);
LaunchSameDimsElementwiseCudaKernel
<
E
T
,
InT
,
OutT
>
(
cuda_ctx
,
ins
,
outs
,
func
);
}
else
{
int
axis
=
ctx
.
HasAttr
(
"axis"
)
?
ctx
.
Attr
<
int
>
(
"axis"
)
:
-
1
;
axis
=
axis
==
-
1
?
*
std
::
max_element
(
dims_size
.
begin
(),
dims_size
.
end
())
-
*
std
::
min_element
(
dims_size
.
begin
(),
dims_size
.
end
())
:
axis
;
LaunchBroadcastElementwiseCudaKernel
<
ET
,
InT
,
OutT
>
(
cuda_ctx
,
ins
,
outs
,
axis
,
func
);
}
...
...
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
b432d024
...
...
@@ -64,20 +64,24 @@ namespace operators {
* To pack the input and output tnesors into vector for
* LaunchElementwiseCudaKernel
*/
template
<
typename
T
>
void
PackTensorsIntoVector
(
const
framework
::
ExecutionContext
&
ctx
,
template
<
typename
Out
T
>
int
PackTensorsIntoVector
(
const
framework
::
ExecutionContext
&
ctx
,
std
::
vector
<
const
framework
::
Tensor
*>
*
ins
,
std
::
vector
<
framework
::
Tensor
*>
*
outs
)
{
int
axis
=
-
1
;
auto
*
x
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"X"
);
auto
*
y
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"Y"
);
auto
*
z
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
z
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
ins
->
emplace_back
(
x
);
z
->
mutable_data
<
OutT
>
(
ctx
.
GetPlace
());
outs
->
emplace_back
(
z
);
ins
->
emplace_back
(
x
);
if
(
y
!=
nullptr
)
{
ins
->
emplace_back
(
y
);
axis
=
ctx
.
HasAttr
(
"axis"
)
?
ctx
.
Attr
<
int
>
(
"axis"
)
:
-
1
;
axis
=
axis
==
-
1
?
std
::
abs
(
y
->
dims
().
size
()
-
x
->
dims
().
size
())
:
axis
;
}
return
axis
;
}
/*
...
...
paddle/fluid/operators/elementwise/elementwise_pow_op.cu
浏览文件 @
b432d024
...
...
@@ -8,10 +8,52 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_pow_op.h"
namespace
ops
=
paddle
::
operators
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
Enable
=
void
>
struct
CudaPowFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
args
[])
const
{
return
std
::
pow
(
args
[
0
],
args
[
1
]);
}
};
template
<
typename
T
>
struct
CudaPowFunctor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
// On CUDAPlace, std::pow(3, 1) calls pow(float, float), and
// it will return a float number like 2.99... , which floor to 2
// when cast to int by default and it is wrong.
// Use llrint to cast it to the nearest integer, which is 3.
inline
HOSTDEVICE
T
operator
()(
const
T
args
[])
const
{
return
std
::
llrint
(
std
::
pow
(
args
[
0
],
args
[
1
]));
}
};
template
<
typename
T
>
class
ElementwisePowKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
CudaPowFunctor
<
T
>
());
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
elementwise_pow
,
ops
::
ElementwisePowKernel
<
paddle
::
platform
::
CUDADeviceContext
,
float
>
,
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.cu
浏览文件 @
b432d024
...
...
@@ -11,8 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex.h"
#include "paddle/fluid/platform/float16.h"
...
...
@@ -24,37 +23,25 @@ namespace paddle {
namespace
operators
{
template
<
typename
T
>
struct
SameDimsElemwiseSub
<
platform
::
CUDADeviceContext
,
T
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
SubRangeFunctor
<
T
>
functor
(
x
->
data
<
T
>
(),
y
->
data
<
T
>
(),
z
->
data
<
T
>
());
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
platform
::
ForRange
<
platform
::
CUDADeviceContext
>
for_range
(
dev_ctx
,
x
->
numel
());
for_range
(
functor
);
struct
CudaSubFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
*
args
)
const
{
return
args
[
0
]
-
args
[
1
];
}
};
template
<
>
struct
SameDimsElemwiseSub
<
platform
::
CUDADeviceContext
,
platform
::
float16
>
{
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
Tensor
*
x
,
const
framework
::
Tensor
*
y
,
framework
::
Tensor
*
z
)
{
auto
size
=
x
->
numel
();
dim3
grid_size
=
dim3
(((
size
+
7
)
/
8
+
PADDLE_CUDA_THREAD_SIZE
-
1
)
/
PADDLE_CUDA_THREAD_SIZE
,
1
);
dim3
block_size
=
dim3
(
PADDLE_CUDA_THREAD_SIZE
,
1
);
const
half
*
x2
=
reinterpret_cast
<
const
half
*>
(
x
->
data
<
platform
::
float16
>
());
const
half
*
y2
=
reinterpret_cast
<
const
half
*>
(
y
->
data
<
platform
::
float16
>
());
half
*
z2
=
reinterpret_cast
<
half
*>
(
z
->
data
<
platform
::
float16
>
());
SameDimsElemwiseSubCUDAKernel
<<<
grid_size
,
block_size
,
0
,
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>().
stream
()
>>>
(
x2
,
y2
,
z2
,
size
);
template
<
typename
T
>
class
ElementwiseSubKernel
<
platform
::
CUDADeviceContext
,
T
>
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
std
::
vector
<
const
framework
::
Tensor
*>
ins
;
std
::
vector
<
framework
::
Tensor
*>
outs
;
const
auto
&
cuda_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
int
axis
=
PackTensorsIntoVector
<
T
>
(
ctx
,
&
ins
,
&
outs
);
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
cuda_ctx
,
ins
,
&
outs
,
axis
,
CudaSubFunctor
<
T
>
());
}
};
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.h
浏览文件 @
b432d024
...
...
@@ -11,8 +11,8 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录