Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
30909889
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
30909889
编写于
10月 21, 2021
作者:
N
niuliling123
提交者:
GitHub
10月 21, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Cherry-pick] Add functor_primitives.h for kernel primitive api (#36418)
* Add functor_primitives.h for kernel primtive api
上级
a201a691
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
231 addition
and
0 deletion
+231
-0
paddle/fluid/operators/kernel_primitives/functor_primitives.h
...le/fluid/operators/kernel_primitives/functor_primitives.h
+230
-0
paddle/fluid/operators/kernel_primitives/kernel_primitives.h
paddle/fluid/operators/kernel_primitives/kernel_primitives.h
+1
-0
未找到文件。
paddle/fluid/operators/kernel_primitives/functor_primitives.h
0 → 100644
浏览文件 @
30909889
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace
paddle
{
namespace
operators
{
namespace
kernel_primitives
{
namespace
details
{
static
__device__
__forceinline__
platform
::
float16
Exp
(
platform
::
float16
x
)
{
return
::
Eigen
::
numext
::
exp
(
x
);
}
static
__device__
__forceinline__
float
Exp
(
float
x
)
{
return
expf
(
x
);
}
static
__device__
__forceinline__
double
Exp
(
double
x
)
{
return
exp
(
x
);
}
static
__device__
__forceinline__
platform
::
float16
Log
(
platform
::
float16
x
)
{
return
::
Eigen
::
numext
::
log
(
x
);
}
static
__device__
__forceinline__
float
Log
(
float
x
)
{
return
logf
(
x
);
}
static
__device__
__forceinline__
double
Log
(
double
x
)
{
return
log
(
x
);
}
}
// namespace details
/******************************** Unary Functor *******************************/
/**
* @brief Default unary exp functor
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
ExpFunctor
{
HOSTDEVICE
inline
ExpFunctor
()
{}
HOSTDEVICE
explicit
inline
ExpFunctor
(
int
n
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
details
::
Exp
(
x
));
}
};
/**
* @brief Default unary identity functor
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
IdentityFunctor
{
HOSTDEVICE
inline
IdentityFunctor
()
{}
HOSTDEVICE
explicit
inline
IdentityFunctor
(
int
n
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
x
);
}
};
/**
* @brief Default unary div functor. Divide by a constant
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
DivideFunctor
{
HOSTDEVICE
inline
DivideFunctor
()
{
n_inv
=
static_cast
<
Tx
>
(
1.0
f
);
}
HOSTDEVICE
explicit
inline
DivideFunctor
(
int
n
)
:
n_inv
((
Tx
)(
1.0
/
n
))
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
x
*
n_inv
);
}
private:
Tx
n_inv
;
};
/**
* @brief Default unary square functor
*/
template
<
typename
Tx
,
typename
Ty
=
Tx
>
struct
SquareFunctor
{
HOSTDEVICE
inline
SquareFunctor
()
{}
HOSTDEVICE
explicit
inline
SquareFunctor
(
int
n
)
{}
HOSTDEVICE
inline
Ty
operator
()(
const
Tx
&
x
)
const
{
return
static_cast
<
Ty
>
(
x
)
*
static_cast
<
Ty
>
(
x
);
}
};
/****************************** Binary Functor ********************************/
/**
* @brief Default binary min functor
*/
template
<
typename
T
>
struct
MinFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
std
::
numeric_limits
<
T
>::
max
());
}
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
(
b
<
a
)
?
b
:
a
;
}
};
/**
* @brief Default binary max functor
*/
template
<
typename
T
>
struct
MaxFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
std
::
numeric_limits
<
T
>::
lowest
());
}
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
(
b
>
a
)
?
b
:
a
;
}
};
/**
* @brief Default binary add functor
*/
template
<
typename
T
>
struct
AddFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
0.0
f
);
}
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
+
a
;
}
};
/**
* @brief Default binary add functor
*/
template
<
typename
T
>
struct
MulFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
1.0
f
);
}
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
*
a
;
}
};
/**
* @brief Default binary logic or functor
*/
template
<
typename
T
>
struct
LogicalOrFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
false
);
}
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
||
a
;
}
};
/**
* @brief Default binary logic and functor
*/
template
<
typename
T
>
struct
LogicalAndFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
true
);
}
__device__
__forceinline__
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
b
&&
a
;
}
};
/**
* @brief Default binary sub functor
*/
template
<
typename
T
>
struct
SubFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
0.0
f
);
}
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
-
b
;
}
};
/**
* @brief Default binary div functor
*/
template
<
typename
T
,
typename
Enable
=
void
>
struct
DivFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
1.0
f
);
}
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
/
b
;
}
};
template
<
typename
T
>
struct
DivFunctor
<
T
,
typename
std
::
enable_if
<
std
::
is_integral
<
T
>::
value
>::
type
>
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
1.0
f
);
}
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
// For int32/int64, need to check whether the divison is zero.
PADDLE_ENFORCE_NE
(
b
,
0
,
platform
::
errors
::
InvalidArgument
(
"Integer division by zero encountered "
"in (floor) divide. Please check the input value."
));
return
a
/
b
;
}
};
/**
* @brief Default binary floor divide functor
*/
template
<
typename
T
>
struct
FloorDivFunctor
{
inline
T
initial
()
{
return
static_cast
<
T
>
(
1.0
f
);
}
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
PADDLE_ENFORCE_NE
(
b
,
0
,
platform
::
errors
::
InvalidArgument
(
"Integer division by zero encountered "
"in (floor) divide. Please check the input value."
));
return
static_cast
<
T
>
(
std
::
trunc
(
a
/
b
));
}
};
}
// namespace kernel_primitives
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/kernel_primitives/kernel_primitives.h
浏览文件 @
30909889
...
...
@@ -16,6 +16,7 @@
#include "paddle/fluid/operators/kernel_primitives/compute_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/datamover_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/functor_primitives.h"
#include "paddle/fluid/operators/kernel_primitives/helper_primitives.h"
namespace
paddle
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录