未验证 提交 c1187cd6 编写于 作者: W wangchaochaohu 提交者: GitHub

Fp16 refine for fusion group (#23472)

上级 ce08fdcf
......@@ -36,7 +36,7 @@ std::string ExtractDataType(const std::vector<Node*>& nodes) {
} else if (dtype == proto::VarType::FP64) {
dtype_str = "double";
} else if (dtype == proto::VarType::FP16) {
dtype_str = "float16";
dtype_str = "__half";
}
break;
}
......@@ -147,13 +147,13 @@ std::string CodeGenerator::Generate(
}
std::string predefined_cuda_functions = "";
if (all_dtype.find("float") != all_dtype.end() &&
all_dtype.find("float16") == all_dtype.end()) {
all_dtype.find("__half") == all_dtype.end()) {
predefined_cuda_functions += predefined_cuda_functions_fp32;
}
if (all_dtype.find("double") != all_dtype.end()) {
predefined_cuda_functions += predefined_cuda_functions_fp64;
}
if (all_dtype.find("float16") != all_dtype.end()) {
if (all_dtype.find("__half") != all_dtype.end()) {
predefined_cuda_functions += predefined_cuda_functions_fp16;
}
return predefined_cuda_functions + code_templates_[0].Format(template_var);
......
......@@ -112,22 +112,7 @@ static std::string RefineTemplateWithAttr(const std::string& op_type,
return ret;
}
// In order to avoid multiple __half2float function calls, we do this
// optimization
static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
const int index,
const std::vector<int>& input_ids) {
std::stringstream ret;
if (used->find(input_ids[index]) == used->end()) {
ret << "float half2fp32_" + TmpName(input_ids[index]) + " = __half2float(" +
TmpName(input_ids[index]) + ");";
}
return ret.str();
}
std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
std::string* half2fp32_statement,
size_t exprs_index) const {
auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
......@@ -136,16 +121,22 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
size_t input_size = input_ids_.size();
rhs = ExpandMultivariateTemplate(rhs, input_size);
}
for (size_t i = 0; i < rhs.size(); i++) {
size_t pos = i;
size_t pos = 0;
while (pos < rhs.size()) {
if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
int length = 0;
while (rhs[pos + 2 + length] != '}') {
length++;
size_t length = 0;
int bracket_number = 1;
for (length = 0; (pos + 2 + length) < rhs.size(); length++) {
char ch = rhs[pos + 2 + length];
if (ch == '}') bracket_number--;
if (ch == '{') bracket_number++;
if (bracket_number == 0) break;
}
std::string index_str = rhs.substr(pos + 2, length);
std::string refine_str =
RefineTemplateWithAttr(op_type_, index_str, attr_);
std::string var_name;
if (index_str == refine_str) {
int index = StringTo<int>(index_str);
PADDLE_ENFORCE_LT(index, input_ids_.size(),
......@@ -160,20 +151,31 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
index, op_type_, input_ids_[index]));
// TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
// need to add general fp16 compute later.
std::string var_name;
if (rhs_type_ == "float16") {
half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
var_name = "half2fp32_" + TmpName(input_ids_[index]);
} else {
var_name = TmpName(input_ids_[index]);
}
rhs.replace(pos, length + 3, var_name);
used->insert(input_ids_[index]);
} else {
std::string var_name = refine_str;
var_name = refine_str;
rhs.replace(pos, length + 3, var_name);
}
pos = pos + var_name.length();
}
pos++;
}
pos = 0;
while (pos < rhs.size()) {
if (rhs[pos] == '%' && rhs[pos + 1] == '{') {
int length = 0;
while (rhs[pos + 2 + length] != '}') {
length++;
}
std::string number_str = rhs.substr(pos + 2, length);
if (rhs_type_ == "__half")
number_str = "__float2half(" + number_str + ")";
rhs.replace(pos, length + 3, number_str);
pos = pos + number_str.length();
}
pos++;
}
return rhs;
}
......@@ -192,28 +194,24 @@ bool OperationExpression::IsSupport() const {
// unique for the node which belong the group
std::string OperationExpression::GetExpression(
std::unordered_set<int>* used) const {
std::string half2fp32_statement;
std::stringstream ret;
if (IsSupport()) {
for (size_t i = 0; i < output_ids_.size(); ++i) {
std::string cast_str = "";
if ((lhs_type_ == rhs_type_ && rhs_type_ != "float16") ||
(lhs_type_ != rhs_type_ && rhs_type_ == "float16")) {
ret << GetLHS(i) << " = " << GetRHS(used, &half2fp32_statement, i)
<< ";";
if (lhs_type_ == rhs_type_) {
ret << GetLHS(i) << " = " << GetRHS(used, i) << ";";
} else {
if ((lhs_type_ == rhs_type_ && rhs_type_ == "float16") ||
lhs_type_ == "float16") {
if (lhs_type_ == "__half")
cast_str = "__float2half";
} else {
else if (rhs_type_ == "__half")
cast_str = "__half2float";
else
cast_str = "static_cast<" + lhs_type_ + ">";
ret << GetLHS(i) << " = " << cast_str << "(" << GetRHS(used, i) << ");";
}
ret << GetLHS(i) << " = " << cast_str << "("
<< GetRHS(used, &half2fp32_statement, i) << ");";
}
}
}
return half2fp32_statement + ret.str();
return ret.str();
}
} // namespace fusion_group
......
......@@ -68,7 +68,6 @@ class OperationExpression {
private:
// TODO(wangchao): make offset more flexible we add stride and basic offset
std::string GetRHS(std::unordered_set<int>* used,
std::string* half2fp32_statement,
size_t exprs_index = 0) const;
std::string GetLHS(size_t i = 0) const;
......
......@@ -36,11 +36,6 @@ __device__ inline double Sqrt(double x) { return sqrt(x); }
)";
static constexpr char predefined_cuda_functions_fp16[] = R"(
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
......@@ -65,6 +60,219 @@ __device__ float __half2float(const __half h) {
return val;
}
#define __CUDA_FP16_DECL__ __host__ __device__
/******************************************************************************
* __half comparison *
******************************************************************************/
#define __COMPARISON_OP_HALF_MACRO(name) do {\
unsigned short val; \
asm( "{ .reg .pred __$temp3;\n" \
" setp."#name".f16 __$temp3, %1, %2;\n" \
" selp.u16 %0, 1, 0, __$temp3;}" \
: "=h"(val) : "h"(__HALF_TO_CUS(a)), "h"(__HALF_TO_CUS(b))); \
return val ? true : false; \
} while(0);
__CUDA_FP16_DECL__ bool __heq(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(eq);
}
__CUDA_FP16_DECL__ bool __hne(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(ne);
}
__CUDA_FP16_DECL__ bool __hle(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(le);
}
__CUDA_FP16_DECL__ bool __hge(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(ge);
}
__CUDA_FP16_DECL__ bool __hlt(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(lt);
}
__CUDA_FP16_DECL__ bool __hgt(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(gt);
}
__CUDA_FP16_DECL__ bool __hequ(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(equ);
}
__CUDA_FP16_DECL__ bool __hneu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(neu);
}
__CUDA_FP16_DECL__ bool __hleu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(leu);
}
__CUDA_FP16_DECL__ bool __hgeu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(geu);
}
__CUDA_FP16_DECL__ bool __hltu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(ltu);
}
__CUDA_FP16_DECL__ bool __hgtu(const __half a, const __half b)
{
__COMPARISON_OP_HALF_MACRO(gtu);
}
#undef __COMPARISON_OP_HALF_MACRO
/******************************************************************************
* __half arithmetic *
******************************************************************************/
#define __BINARY_OP_HALF_MACRO(name) do {\
__half val; \
asm( "{"#name".f16 %0,%1,%2;\n}" \
:"=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)),"h"(__HALF_TO_CUS(b))); \
return val; \
} while(0);
__CUDA_FP16_DECL__ __half __hadd(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(add);
}
__CUDA_FP16_DECL__ __half __hsub(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(sub);
}
__CUDA_FP16_DECL__ __half __hmul(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(mul);
}
__CUDA_FP16_DECL__ __half __hadd_sat(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(add.sat);
}
__CUDA_FP16_DECL__ __half __hsub_sat(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(sub.sat);
}
__CUDA_FP16_DECL__ __half __hmul_sat(const __half a, const __half b)
{
__BINARY_OP_HALF_MACRO(mul.sat);
}
#undef __BINARY_OP_HALF_MACRO
#define __TERNARY_OP_HALF_MACRO(name) do {\
__half val; \
asm( "{"#name".f16 %0,%1,%2,%3;\n}" \
:"=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)),"h"(__HALF_TO_CUS(b)),"h"(__HALF_TO_CUS(c))); \
return val; \
} while(0);
__CUDA_FP16_DECL__ __half __hfma(const __half a, const __half b, const __half c)
{
__TERNARY_OP_HALF_MACRO(fma.rn);
}
__CUDA_FP16_DECL__ __half __hfma_sat(const __half a, const __half b, const __half c)
{
__TERNARY_OP_HALF_MACRO(fma.rn.sat);
}
#undef __TERNARY_OP_HALF2_MACRO
__CUDA_FP16_DECL__ __half __hdiv(__half a, __half b) {
__half v, abs, den;
__HALF_TO_US(den) = 0x008F;
float fa, fb, fv, rcp;
fa = __half2float(a);
fb = __half2float(b);
asm("{rcp.approx.f32 %0, %1;\n}" :"=f"(rcp) : "f"(fb));
fv = rcp * fa;
v = __float2half(fv);
__HALF_TO_US(abs) = (unsigned short)(((unsigned int)__HALF_TO_CUS(v)) & 0x00007FFF);
if (__hlt(abs, den) && (!(__HALF_TO_CUS(abs) == 0x0000))) {
float err = __fmaf_rn(-fb, fv, fa);
fv = __fmaf_rn(rcp, err, fv);
v = __float2half(fv);
}
return v;
}
/* Some basic arithmetic operations expected of a builtin */
__device__ __forceinline__ __half operator+(const __half &lh, const __half &rh) { return __hadd(lh, rh); }
__device__ __forceinline__ __half operator-(const __half &lh, const __half &rh) { return __hsub(lh, rh); }
__device__ __forceinline__ __half operator*(const __half &lh, const __half &rh) { return __hmul(lh, rh); }
__device__ __forceinline__ __half operator/(const __half &lh, const __half &rh) { return __hdiv(lh, rh); }
/* Some basic comparison operations to make it look like a builtin */
__device__ __forceinline__ bool operator==(const __half &lh, const __half &rh) { return __heq(lh, rh); }
__device__ __forceinline__ bool operator!=(const __half &lh, const __half &rh) { return __hne(lh, rh); }
__device__ __forceinline__ bool operator> (const __half &lh, const __half &rh) { return __hgt(lh, rh); }
__device__ __forceinline__ bool operator< (const __half &lh, const __half &rh) { return __hlt(lh, rh); }
__device__ __forceinline__ bool operator>=(const __half &lh, const __half &rh) { return __hge(lh, rh); }
__device__ __forceinline__ bool operator<=(const __half &lh, const __half &rh) { return __hle(lh, rh); }
#define __SPEC_CASE(i,r, spc, ulp) \
"{.reg.b16 spc, ulp, p;\n"\
" mov.b16 spc,"#spc";\n"\
" mov.b16 ulp,"#ulp";\n"\
" set.eq.f16.f16 p,"#i", spc;\n"\
" fma.rn.f16 "#r",p,ulp,"#r";\n}\n"
__CUDA_FP16_DECL__ __half hexp(const __half a) {
__half val;
asm("{.reg.b32 f, C; \n"
" .reg.b16 h,r; \n"
" mov.b16 h,%1; \n"
" cvt.f32.f16 f,h; \n"
" mov.b32 C, 0x3fb8aa3b; \n"
" mul.f32 f,f,C; \n"
" ex2.approx.f32 f,f; \n"
" cvt.rn.f16.f32 r,f; \n"
__SPEC_CASE(h, r, 0X1F79, 0x9400)
__SPEC_CASE(h, r, 0X25CF, 0x9400)
__SPEC_CASE(h, r, 0XC13B, 0x0400)
__SPEC_CASE(h, r, 0XC1EF, 0x0200)
" mov.b16 %0,r; \n"
"}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));
return val;
}
__CUDA_FP16_DECL__ __half hlog(const __half a) {
__half val;
asm("{.reg.b32 f, C; \n"
" .reg.b16 r,h; \n"
" mov.b16 h,%1; \n"
" cvt.f32.f16 f,h; \n"
" lg2.approx.f32 f,f; \n"
" mov.b32 C, 0x3f317218; \n"
" mul.f32 f,f,C; \n"
" cvt.rn.f16.f32 r,f; \n"
__SPEC_CASE(h, r, 0X160D, 0x9C00)
__SPEC_CASE(h, r, 0X3BFE, 0x8010)
__SPEC_CASE(h, r, 0X3C0B, 0x8080)
__SPEC_CASE(h, r, 0X6051, 0x1C00)
" mov.b16 %0,r; \n"
"}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));
return val;
}
#define __APPROX_FCAST(fun) do {\
__half val;\
asm("{.reg.b32 f; \n"\
" .reg.b16 r; \n"\
" mov.b16 r,%1; \n"\
" cvt.f32.f16 f,r; \n"\
" "#fun".approx.f32 f,f; \n"\
" cvt.rn.f16.f32 r,f; \n"\
" mov.b16 %0,r; \n"\
"}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));\
return val;\
} while(0);
__CUDA_FP16_DECL__ __half hsqrt(const __half a) {
__APPROX_FCAST(sqrt);
}
__device__ inline __half Exp(const __half x) { return hexp(x); }
__device__ inline __half Log(const __half x) { return hlog(x); }
__device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
#undef __HALF_TO_US
#undef __HALF_TO_CUS
......@@ -81,7 +289,6 @@ extern "C" __global__ void $func_name($parameters) {
}
}
)";
} // namespace fusion_group
} // namespace ir
} // namespace framework
......
......@@ -91,17 +91,18 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// relu:
// out = f(x) = x > 0 ? x : 0
// dx = dout * (out > 0 ? 1 : 0)
insert_handler("relu", "${0} > 0 ? ${0} : 0", {"${1} > 0 ? ${2} : 0"});
insert_handler("relu", "${0} > %{0} ? ${0} : %{0.0}",
{"${1} > %{0.0} ? ${2} : %{0.0}"});
// sigmoid:
// out = f(x) = 1.0 / (1.0 + exp(-x))
// dx = dout * out * (1 - out)
insert_handler("sigmoid", "1.0 / (1.0 + Exp(- ${0}))",
{"${2} * ${1} * (1.0 - ${1})"});
insert_handler("sigmoid", "%{1.0} / (%{1.0} + Exp(- ${0}))",
{"${2} * ${1} * (%{1.0} - ${1})"});
// tanh:
// out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
// dx = dout * (1 - out * out)
insert_handler("tanh", "2.0 / (1.0 + Exp(-2.0 * ${0})) - 1.0",
{"${2} * (1.0 - ${1} * ${1})"});
insert_handler("tanh", "%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}",
{"${2} * (%{1.0} - ${1} * ${1})"});
// cast:
// out = static_cast<T>(x)
......@@ -112,21 +113,22 @@ void OperationMap::InsertUnaryElementwiseOperations() {
// sqrt:
// out = x^(1/2)
// dx = dout * 0.5 / out
insert_handler("sqrt", "Sqrt(${0})", {"${2} * 0.5 / ${1}"});
insert_handler("sqrt", "Sqrt(${0})", {"${2} * %{0.5} / ${1}"});
// square:
// out = x^2
// dx = dout * 2.0 * x
insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
insert_handler("square", "${0} * ${0}", {"${2} * %{2.0} * ${0}"});
// scale
// out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
// here we use '=' operator to seperate th default value
// TODO(wangchaochaohu): Later we need to support Tensor input for scale and
// bias.
insert_handler("scale",
"${bias_after_scale=true} ? (${scale=1.0} * ${0} + "
"${bias=0.0}) : (${scale=1.0} * (${0} + ${bias=0.0}))",
insert_handler(
"scale",
"${bias_after_scale=true} ? (${scale=%{1.0}} * ${0} + "
"${bias=%{0.0}}) : (${scale=%{1.0}} * (${0} + ${bias=%{0.0}}))",
{});
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册