Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
c1187cd6
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c1187cd6
编写于
4月 08, 2020
作者:
W
wangchaochaohu
提交者:
GitHub
4月 08, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fp16 refine for fusion group (#23472)
上级
ce08fdcf
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
267 addition
and
61 deletion
+267
-61
paddle/fluid/framework/ir/fusion_group/code_generator.cc
paddle/fluid/framework/ir/fusion_group/code_generator.cc
+3
-3
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
.../fluid/framework/ir/fusion_group/code_generator_helper.cc
+38
-40
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
...e/fluid/framework/ir/fusion_group/code_generator_helper.h
+0
-1
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
+213
-6
paddle/fluid/framework/ir/fusion_group/operation.cc
paddle/fluid/framework/ir/fusion_group/operation.cc
+13
-11
未找到文件。
paddle/fluid/framework/ir/fusion_group/code_generator.cc
浏览文件 @
c1187cd6
...
@@ -36,7 +36,7 @@ std::string ExtractDataType(const std::vector<Node*>& nodes) {
...
@@ -36,7 +36,7 @@ std::string ExtractDataType(const std::vector<Node*>& nodes) {
}
else
if
(
dtype
==
proto
::
VarType
::
FP64
)
{
}
else
if
(
dtype
==
proto
::
VarType
::
FP64
)
{
dtype_str
=
"double"
;
dtype_str
=
"double"
;
}
else
if
(
dtype
==
proto
::
VarType
::
FP16
)
{
}
else
if
(
dtype
==
proto
::
VarType
::
FP16
)
{
dtype_str
=
"
float16
"
;
dtype_str
=
"
__half
"
;
}
}
break
;
break
;
}
}
...
@@ -147,13 +147,13 @@ std::string CodeGenerator::Generate(
...
@@ -147,13 +147,13 @@ std::string CodeGenerator::Generate(
}
}
std
::
string
predefined_cuda_functions
=
""
;
std
::
string
predefined_cuda_functions
=
""
;
if
(
all_dtype
.
find
(
"float"
)
!=
all_dtype
.
end
()
&&
if
(
all_dtype
.
find
(
"float"
)
!=
all_dtype
.
end
()
&&
all_dtype
.
find
(
"
float16
"
)
==
all_dtype
.
end
())
{
all_dtype
.
find
(
"
__half
"
)
==
all_dtype
.
end
())
{
predefined_cuda_functions
+=
predefined_cuda_functions_fp32
;
predefined_cuda_functions
+=
predefined_cuda_functions_fp32
;
}
}
if
(
all_dtype
.
find
(
"double"
)
!=
all_dtype
.
end
())
{
if
(
all_dtype
.
find
(
"double"
)
!=
all_dtype
.
end
())
{
predefined_cuda_functions
+=
predefined_cuda_functions_fp64
;
predefined_cuda_functions
+=
predefined_cuda_functions_fp64
;
}
}
if
(
all_dtype
.
find
(
"
float16
"
)
!=
all_dtype
.
end
())
{
if
(
all_dtype
.
find
(
"
__half
"
)
!=
all_dtype
.
end
())
{
predefined_cuda_functions
+=
predefined_cuda_functions_fp16
;
predefined_cuda_functions
+=
predefined_cuda_functions_fp16
;
}
}
return
predefined_cuda_functions
+
code_templates_
[
0
].
Format
(
template_var
);
return
predefined_cuda_functions
+
code_templates_
[
0
].
Format
(
template_var
);
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.cc
浏览文件 @
c1187cd6
...
@@ -112,22 +112,7 @@ static std::string RefineTemplateWithAttr(const std::string& op_type,
...
@@ -112,22 +112,7 @@ static std::string RefineTemplateWithAttr(const std::string& op_type,
return
ret
;
return
ret
;
}
}
// In order to avoid multiple __half2float function calls, we do this
// optimization
static
std
::
string
OptimzeFP16RHS
(
std
::
unordered_set
<
int
>*
used
,
const
int
index
,
const
std
::
vector
<
int
>&
input_ids
)
{
std
::
stringstream
ret
;
if
(
used
->
find
(
input_ids
[
index
])
==
used
->
end
())
{
ret
<<
"float half2fp32_"
+
TmpName
(
input_ids
[
index
])
+
" = __half2float("
+
TmpName
(
input_ids
[
index
])
+
");"
;
}
return
ret
.
str
();
}
std
::
string
OperationExpression
::
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
std
::
string
OperationExpression
::
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
std
::
string
*
half2fp32_statement
,
size_t
exprs_index
)
const
{
size_t
exprs_index
)
const
{
auto
rhs
=
OperationMap
::
Instance
().
Get
(
op_type_
).
exprs
[
exprs_index
];
auto
rhs
=
OperationMap
::
Instance
().
Get
(
op_type_
).
exprs
[
exprs_index
];
auto
num_operands
=
OperationMap
::
Instance
().
Get
(
op_type_
).
num_operands
;
auto
num_operands
=
OperationMap
::
Instance
().
Get
(
op_type_
).
num_operands
;
...
@@ -136,16 +121,22 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
...
@@ -136,16 +121,22 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t
input_size
=
input_ids_
.
size
();
size_t
input_size
=
input_ids_
.
size
();
rhs
=
ExpandMultivariateTemplate
(
rhs
,
input_size
);
rhs
=
ExpandMultivariateTemplate
(
rhs
,
input_size
);
}
}
for
(
size_t
i
=
0
;
i
<
rhs
.
size
();
i
++
)
{
size_t
pos
=
i
;
size_t
pos
=
0
;
while
(
pos
<
rhs
.
size
())
{
if
(
rhs
[
pos
]
==
'$'
&&
rhs
[
pos
+
1
]
==
'{'
)
{
if
(
rhs
[
pos
]
==
'$'
&&
rhs
[
pos
+
1
]
==
'{'
)
{
int
length
=
0
;
size_t
length
=
0
;
while
(
rhs
[
pos
+
2
+
length
]
!=
'}'
)
{
int
bracket_number
=
1
;
length
++
;
for
(
length
=
0
;
(
pos
+
2
+
length
)
<
rhs
.
size
();
length
++
)
{
char
ch
=
rhs
[
pos
+
2
+
length
];
if
(
ch
==
'}'
)
bracket_number
--
;
if
(
ch
==
'{'
)
bracket_number
++
;
if
(
bracket_number
==
0
)
break
;
}
}
std
::
string
index_str
=
rhs
.
substr
(
pos
+
2
,
length
);
std
::
string
index_str
=
rhs
.
substr
(
pos
+
2
,
length
);
std
::
string
refine_str
=
std
::
string
refine_str
=
RefineTemplateWithAttr
(
op_type_
,
index_str
,
attr_
);
RefineTemplateWithAttr
(
op_type_
,
index_str
,
attr_
);
std
::
string
var_name
;
if
(
index_str
==
refine_str
)
{
if
(
index_str
==
refine_str
)
{
int
index
=
StringTo
<
int
>
(
index_str
);
int
index
=
StringTo
<
int
>
(
index_str
);
PADDLE_ENFORCE_LT
(
index
,
input_ids_
.
size
(),
PADDLE_ENFORCE_LT
(
index
,
input_ids_
.
size
(),
...
@@ -160,20 +151,31 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
...
@@ -160,20 +151,31 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
index
,
op_type_
,
input_ids_
[
index
]));
index
,
op_type_
,
input_ids_
[
index
]));
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
// need to add general fp16 compute later.
// need to add general fp16 compute later.
std
::
string
var_name
;
var_name
=
TmpName
(
input_ids_
[
index
]);
if
(
rhs_type_
==
"float16"
)
{
half2fp32_statement
->
append
(
OptimzeFP16RHS
(
used
,
index
,
input_ids_
));
var_name
=
"half2fp32_"
+
TmpName
(
input_ids_
[
index
]);
}
else
{
var_name
=
TmpName
(
input_ids_
[
index
]);
}
rhs
.
replace
(
pos
,
length
+
3
,
var_name
);
rhs
.
replace
(
pos
,
length
+
3
,
var_name
);
used
->
insert
(
input_ids_
[
index
]);
used
->
insert
(
input_ids_
[
index
]);
}
else
{
}
else
{
std
::
string
var_name
=
refine_str
;
var_name
=
refine_str
;
rhs
.
replace
(
pos
,
length
+
3
,
var_name
);
rhs
.
replace
(
pos
,
length
+
3
,
var_name
);
}
}
pos
=
pos
+
var_name
.
length
();
}
pos
++
;
}
pos
=
0
;
while
(
pos
<
rhs
.
size
())
{
if
(
rhs
[
pos
]
==
'%'
&&
rhs
[
pos
+
1
]
==
'{'
)
{
int
length
=
0
;
while
(
rhs
[
pos
+
2
+
length
]
!=
'}'
)
{
length
++
;
}
std
::
string
number_str
=
rhs
.
substr
(
pos
+
2
,
length
);
if
(
rhs_type_
==
"__half"
)
number_str
=
"__float2half("
+
number_str
+
")"
;
rhs
.
replace
(
pos
,
length
+
3
,
number_str
);
pos
=
pos
+
number_str
.
length
();
}
}
pos
++
;
}
}
return
rhs
;
return
rhs
;
}
}
...
@@ -192,28 +194,24 @@ bool OperationExpression::IsSupport() const {
...
@@ -192,28 +194,24 @@ bool OperationExpression::IsSupport() const {
// unique for the node which belong the group
// unique for the node which belong the group
std
::
string
OperationExpression
::
GetExpression
(
std
::
string
OperationExpression
::
GetExpression
(
std
::
unordered_set
<
int
>*
used
)
const
{
std
::
unordered_set
<
int
>*
used
)
const
{
std
::
string
half2fp32_statement
;
std
::
stringstream
ret
;
std
::
stringstream
ret
;
if
(
IsSupport
())
{
if
(
IsSupport
())
{
for
(
size_t
i
=
0
;
i
<
output_ids_
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
output_ids_
.
size
();
++
i
)
{
std
::
string
cast_str
=
""
;
std
::
string
cast_str
=
""
;
if
((
lhs_type_
==
rhs_type_
&&
rhs_type_
!=
"float16"
)
||
if
(
lhs_type_
==
rhs_type_
)
{
(
lhs_type_
!=
rhs_type_
&&
rhs_type_
==
"float16"
))
{
ret
<<
GetLHS
(
i
)
<<
" = "
<<
GetRHS
(
used
,
i
)
<<
";"
;
ret
<<
GetLHS
(
i
)
<<
" = "
<<
GetRHS
(
used
,
&
half2fp32_statement
,
i
)
<<
";"
;
}
else
{
}
else
{
if
((
lhs_type_
==
rhs_type_
&&
rhs_type_
==
"float16"
)
||
if
(
lhs_type_
==
"__half"
)
lhs_type_
==
"float16"
)
{
cast_str
=
"__float2half"
;
cast_str
=
"__float2half"
;
}
else
{
else
if
(
rhs_type_
==
"__half"
)
cast_str
=
"__half2float"
;
else
cast_str
=
"static_cast<"
+
lhs_type_
+
">"
;
cast_str
=
"static_cast<"
+
lhs_type_
+
">"
;
}
ret
<<
GetLHS
(
i
)
<<
" = "
<<
cast_str
<<
"("
<<
GetRHS
(
used
,
i
)
<<
");"
;
ret
<<
GetLHS
(
i
)
<<
" = "
<<
cast_str
<<
"("
<<
GetRHS
(
used
,
&
half2fp32_statement
,
i
)
<<
");"
;
}
}
}
}
}
}
return
half2fp32_statement
+
ret
.
str
();
return
ret
.
str
();
}
}
}
// namespace fusion_group
}
// namespace fusion_group
...
...
paddle/fluid/framework/ir/fusion_group/code_generator_helper.h
浏览文件 @
c1187cd6
...
@@ -68,7 +68,6 @@ class OperationExpression {
...
@@ -68,7 +68,6 @@ class OperationExpression {
private:
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset
// TODO(wangchao): make offset more flexible we add stride and basic offset
std
::
string
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
std
::
string
GetRHS
(
std
::
unordered_set
<
int
>*
used
,
std
::
string
*
half2fp32_statement
,
size_t
exprs_index
=
0
)
const
;
size_t
exprs_index
=
0
)
const
;
std
::
string
GetLHS
(
size_t
i
=
0
)
const
;
std
::
string
GetLHS
(
size_t
i
=
0
)
const
;
...
...
paddle/fluid/framework/ir/fusion_group/cuda_resources.h
浏览文件 @
c1187cd6
...
@@ -36,11 +36,6 @@ __device__ inline double Sqrt(double x) { return sqrt(x); }
...
@@ -36,11 +36,6 @@ __device__ inline double Sqrt(double x) { return sqrt(x); }
)"
;
)"
;
static
constexpr
char
predefined_cuda_functions_fp16
[]
=
R"(
static
constexpr
char
predefined_cuda_functions_fp16
[]
=
R"(
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
...
@@ -65,6 +60,219 @@ __device__ float __half2float(const __half h) {
...
@@ -65,6 +60,219 @@ __device__ float __half2float(const __half h) {
return val;
return val;
}
}
#define __CUDA_FP16_DECL__ __host__ __device__
/******************************************************************************
* __half comparison *
******************************************************************************/
#define __COMPARISON_OP_HALF_MACRO(name) do {\
unsigned short val; \
asm( "{ .reg .pred __$temp3;\n" \
" setp."#name".f16 __$temp3, %1, %2;\n" \
" selp.u16 %0, 1, 0, __$temp3;}" \
: "=h"(val) : "h"(__HALF_TO_CUS(a)), "h"(__HALF_TO_CUS(b))); \
return val ? true : false; \
} while(0);
__CUDA_FP16_DECL__ bool __heq(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(eq);
}
__CUDA_FP16_DECL__ bool __hne(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(ne);
}
__CUDA_FP16_DECL__ bool __hle(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(le);
}
__CUDA_FP16_DECL__ bool __hge(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(ge);
}
__CUDA_FP16_DECL__ bool __hlt(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(lt);
}
__CUDA_FP16_DECL__ bool __hgt(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(gt);
}
__CUDA_FP16_DECL__ bool __hequ(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(equ);
}
__CUDA_FP16_DECL__ bool __hneu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(neu);
}
__CUDA_FP16_DECL__ bool __hleu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(leu);
}
__CUDA_FP16_DECL__ bool __hgeu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(geu);
}
__CUDA_FP16_DECL__ bool __hltu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(ltu);
}
__CUDA_FP16_DECL__ bool __hgtu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(gtu);
}
#undef __COMPARISON_OP_HALF_MACRO
/******************************************************************************
* __half arithmetic *
******************************************************************************/
#define __BINARY_OP_HALF_MACRO(name) do {\
__half val; \
asm( "{"#name".f16 %0,%1,%2;\n}" \
:"=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)),"h"(__HALF_TO_CUS(b))); \
return val; \
} while(0);
__CUDA_FP16_DECL__ __half __hadd(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(add);
}
__CUDA_FP16_DECL__ __half __hsub(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(sub);
}
__CUDA_FP16_DECL__ __half __hmul(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(mul);
}
__CUDA_FP16_DECL__ __half __hadd_sat(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(add.sat);
}
__CUDA_FP16_DECL__ __half __hsub_sat(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(sub.sat);
}
__CUDA_FP16_DECL__ __half __hmul_sat(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(mul.sat);
}
#undef __BINARY_OP_HALF_MACRO
#define __TERNARY_OP_HALF_MACRO(name) do {\
__half val; \
asm( "{"#name".f16 %0,%1,%2,%3;\n}" \
:"=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)),"h"(__HALF_TO_CUS(b)),"h"(__HALF_TO_CUS(c))); \
return val; \
} while(0);
__CUDA_FP16_DECL__ __half __hfma(const __half a, const __half b, const __half c)
{
__TERNARY_OP_HALF_MACRO(fma.rn);
}
__CUDA_FP16_DECL__ __half __hfma_sat(const __half a, const __half b, const __half c)
{
__TERNARY_OP_HALF_MACRO(fma.rn.sat);
}
#undef __TERNARY_OP_HALF2_MACRO
__CUDA_FP16_DECL__ __half __hdiv(__half a, __half b) {
__half v, abs, den;
__HALF_TO_US(den) = 0x008F;
float fa, fb, fv, rcp;
fa = __half2float(a);
fb = __half2float(b);
asm("{rcp.approx.f32 %0, %1;\n}" :"=f"(rcp) : "f"(fb));
fv = rcp * fa;
v = __float2half(fv);
__HALF_TO_US(abs) = (unsigned short)(((unsigned int)__HALF_TO_CUS(v)) & 0x00007FFF);
if (__hlt(abs, den) && (!(__HALF_TO_CUS(abs) == 0x0000))) {
float err = __fmaf_rn(-fb, fv, fa);
fv = __fmaf_rn(rcp, err, fv);
v = __float2half(fv);
}
return v;
}
/* Some basic arithmetic operations expected of a builtin */
__device__ __forceinline__ __half operator+(const __half &lh, const __half &rh) { return __hadd(lh, rh); }
__device__ __forceinline__ __half operator-(const __half &lh, const __half &rh) { return __hsub(lh, rh); }
__device__ __forceinline__ __half operator*(const __half &lh, const __half &rh) { return __hmul(lh, rh); }
__device__ __forceinline__ __half operator/(const __half &lh, const __half &rh) { return __hdiv(lh, rh); }
/* Some basic comparison operations to make it look like a builtin */
__device__ __forceinline__ bool operator==(const __half &lh, const __half &rh) { return __heq(lh, rh); }
__device__ __forceinline__ bool operator!=(const __half &lh, const __half &rh) { return __hne(lh, rh); }
__device__ __forceinline__ bool operator> (const __half &lh, const __half &rh) { return __hgt(lh, rh); }
__device__ __forceinline__ bool operator< (const __half &lh, const __half &rh) { return __hlt(lh, rh); }
__device__ __forceinline__ bool operator>=(const __half &lh, const __half &rh) { return __hge(lh, rh); }
__device__ __forceinline__ bool operator<=(const __half &lh, const __half &rh) { return __hle(lh, rh); }
#define __SPEC_CASE(i,r, spc, ulp) \
"{.reg.b16 spc, ulp, p;\n"\
" mov.b16 spc,"#spc";\n"\
" mov.b16 ulp,"#ulp";\n"\
" set.eq.f16.f16 p,"#i", spc;\n"\
" fma.rn.f16 "#r",p,ulp,"#r";\n}\n"
__CUDA_FP16_DECL__ __half hexp(const __half a) {
__half val;
asm("{.reg.b32 f, C; \n"
" .reg.b16 h,r; \n"
" mov.b16 h,%1; \n"
" cvt.f32.f16 f,h; \n"
" mov.b32 C, 0x3fb8aa3b; \n"
" mul.f32 f,f,C; \n"
" ex2.approx.f32 f,f; \n"
" cvt.rn.f16.f32 r,f; \n"
__SPEC_CASE(h, r, 0X1F79, 0x9400)
__SPEC_CASE(h, r, 0X25CF, 0x9400)
__SPEC_CASE(h, r, 0XC13B, 0x0400)
__SPEC_CASE(h, r, 0XC1EF, 0x0200)
" mov.b16 %0,r; \n"
"}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));
return val;
}
__CUDA_FP16_DECL__ __half hlog(const __half a) {
__half val;
asm("{.reg.b32 f, C; \n"
" .reg.b16 r,h; \n"
" mov.b16 h,%1; \n"
" cvt.f32.f16 f,h; \n"
" lg2.approx.f32 f,f; \n"
" mov.b32 C, 0x3f317218; \n"
" mul.f32 f,f,C; \n"
" cvt.rn.f16.f32 r,f; \n"
__SPEC_CASE(h, r, 0X160D, 0x9C00)
__SPEC_CASE(h, r, 0X3BFE, 0x8010)
__SPEC_CASE(h, r, 0X3C0B, 0x8080)
__SPEC_CASE(h, r, 0X6051, 0x1C00)
" mov.b16 %0,r; \n"
"}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));
return val;
}
#define __APPROX_FCAST(fun) do {\
__half val;\
asm("{.reg.b32 f; \n"\
" .reg.b16 r; \n"\
" mov.b16 r,%1; \n"\
" cvt.f32.f16 f,r; \n"\
" "#fun".approx.f32 f,f; \n"\
" cvt.rn.f16.f32 r,f; \n"\
" mov.b16 %0,r; \n"\
"}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));\
return val;\
} while(0);
__CUDA_FP16_DECL__ __half hsqrt(const __half a) {
__APPROX_FCAST(sqrt);
}
__device__ inline __half Exp(const __half x) { return hexp(x); }
__device__ inline __half Log(const __half x) { return hlog(x); }
__device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
#undef __HALF_TO_US
#undef __HALF_TO_US
#undef __HALF_TO_CUS
#undef __HALF_TO_CUS
...
@@ -81,7 +289,6 @@ extern "C" __global__ void $func_name($parameters) {
...
@@ -81,7 +289,6 @@ extern "C" __global__ void $func_name($parameters) {
}
}
}
}
)"
;
)"
;
}
// namespace fusion_group
}
// namespace fusion_group
}
// namespace ir
}
// namespace ir
}
// namespace framework
}
// namespace framework
...
...
paddle/fluid/framework/ir/fusion_group/operation.cc
浏览文件 @
c1187cd6
...
@@ -91,17 +91,18 @@ void OperationMap::InsertUnaryElementwiseOperations() {
...
@@ -91,17 +91,18 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// relu:
// relu:
// out = f(x) = x > 0 ? x : 0
// out = f(x) = x > 0 ? x : 0
// dx = dout * (out > 0 ? 1 : 0)
// dx = dout * (out > 0 ? 1 : 0)
insert_handler
(
"relu"
,
"${0} > 0 ? ${0} : 0"
,
{
"${1} > 0 ? ${2} : 0"
});
insert_handler
(
"relu"
,
"${0} > %{0} ? ${0} : %{0.0}"
,
{
"${1} > %{0.0} ? ${2} : %{0.0}"
});
// sigmoid:
// sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x))
// out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out)
// dx = dout * out * (1 - out)
insert_handler
(
"sigmoid"
,
"
1.0 / (1.0
+ Exp(- ${0}))"
,
insert_handler
(
"sigmoid"
,
"
%{1.0} / (%{1.0}
+ Exp(- ${0}))"
,
{
"${2} * ${1} * (
1.0
- ${1})"
});
{
"${2} * ${1} * (
%{1.0}
- ${1})"
});
// tanh:
// tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out)
// dx = dout * (1 - out * out)
insert_handler
(
"tanh"
,
"
2.0 / (1.0 + Exp(-2.0 * ${0})) - 1.0
"
,
insert_handler
(
"tanh"
,
"
%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}
"
,
{
"${2} * (
1.0
- ${1} * ${1})"
});
{
"${2} * (
%{1.0}
- ${1} * ${1})"
});
// cast:
// cast:
// out = static_cast<T>(x)
// out = static_cast<T>(x)
...
@@ -112,22 +113,23 @@ void OperationMap::InsertUnaryElementwiseOperations() {
...
@@ -112,22 +113,23 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// sqrt:
// sqrt:
// out = x^(1/2)
// out = x^(1/2)
// dx = dout * 0.5 / out
// dx = dout * 0.5 / out
insert_handler
(
"sqrt"
,
"Sqrt(${0})"
,
{
"${2} *
0.5
/ ${1}"
});
insert_handler
(
"sqrt"
,
"Sqrt(${0})"
,
{
"${2} *
%{0.5}
/ ${1}"
});
// square:
// square:
// out = x^2
// out = x^2
// dx = dout * 2.0 * x
// dx = dout * 2.0 * x
insert_handler
(
"square"
,
"${0} * ${0}"
,
{
"${2} *
2.0
* ${0}"
});
insert_handler
(
"square"
,
"${0} * ${0}"
,
{
"${2} *
%{2.0}
* ${0}"
});
// scale
// scale
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// here we use '=' operator to seperate th default value
// here we use '=' operator to seperate th default value
// TODO(wangchaochaohu): Later we need to support Tensor input for scale and
// TODO(wangchaochaohu): Later we need to support Tensor input for scale and
// bias.
// bias.
insert_handler
(
"scale"
,
insert_handler
(
"${bias_after_scale=true} ? (${scale=1.0} * ${0} + "
"scale"
,
"${bias=0.0}) : (${scale=1.0} * (${0} + ${bias=0.0}))"
,
"${bias_after_scale=true} ? (${scale=%{1.0}} * ${0} + "
{});
"${bias=%{0.0}}) : (${scale=%{1.0}} * (${0} + ${bias=%{0.0}}))"
,
{});
}
}
void
OperationMap
::
InsertBinaryElementwiseOperations
()
{
void
OperationMap
::
InsertBinaryElementwiseOperations
()
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录