Unverified commit c1187cd6, authored by wangchaochaohu, committed by GitHub

Fp16 refine for fusion group (#23472)

Parent ce08fdcf
@@ -36,7 +36,7 @@ std::string ExtractDataType(const std::vector<Node*>& nodes) {
     } else if (dtype == proto::VarType::FP64) {
       dtype_str = "double";
     } else if (dtype == proto::VarType::FP16) {
-      dtype_str = "float16";
+      dtype_str = "__half";
     }
     break;
   }
@@ -147,13 +147,13 @@ std::string CodeGenerator::Generate(
   }
   std::string predefined_cuda_functions = "";
   if (all_dtype.find("float") != all_dtype.end() &&
-      all_dtype.find("float16") == all_dtype.end()) {
+      all_dtype.find("__half") == all_dtype.end()) {
     predefined_cuda_functions += predefined_cuda_functions_fp32;
   }
   if (all_dtype.find("double") != all_dtype.end()) {
     predefined_cuda_functions += predefined_cuda_functions_fp64;
   }
-  if (all_dtype.find("float16") != all_dtype.end()) {
+  if (all_dtype.find("__half") != all_dtype.end()) {
     predefined_cuda_functions += predefined_cuda_functions_fp16;
   }
   return predefined_cuda_functions + code_templates_[0].Format(template_var);
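The string returned by ExtractDataType is pasted verbatim into the CUDA source that fusion_group compiles with NVRTC, so it must name a type the device compiler knows: "float16" is only a Paddle-side name, while `__half` is what the predefined fp16 functions below operate on. A minimal sketch of the shape a generated fp16 kernel takes (the function name and parameter list here are hypothetical; the real ones come from code_templates_[0]):

```cpp
// Hypothetical generated kernel for an fp16 fusion group: inputs, outputs and
// intermediates are all declared as __half, the name this commit switches to.
extern "C" __global__ void fused_kernel_0(int N, __half* tmp_0, __half* tmp_1,
                                          __half* tmp_2) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
       idx += gridDim.x * blockDim.x) {
    tmp_1[idx] = tmp_0[idx] * tmp_0[idx];  // via operator*(__half, __half)
    tmp_2[idx] = Exp(tmp_1[idx]);          // Exp(__half) from the fp16 header
  }
}
```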
......
@@ -112,22 +112,7 @@ static std::string RefineTemplateWithAttr(const std::string& op_type,
   return ret;
 }
 
-// In order to avoid multiple __half2float function calls, we do this
-// optimization
-static std::string OptimzeFP16RHS(std::unordered_set<int>* used,
-                                  const int index,
-                                  const std::vector<int>& input_ids) {
-  std::stringstream ret;
-  if (used->find(input_ids[index]) == used->end()) {
-    ret << "float half2fp32_" + TmpName(input_ids[index]) + " = __half2float(" +
-        TmpName(input_ids[index]) + ");";
-  }
-  return ret.str();
-}
-
 std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
-                                        std::string* half2fp32_statement,
                                         size_t exprs_index) const {
   auto rhs = OperationMap::Instance().Get(op_type_).exprs[exprs_index];
   auto num_operands = OperationMap::Instance().Get(op_type_).num_operands;
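For context, the deleted OptimzeFP16RHS implemented the old strategy: hoist a one-time __half-to-float conversion per operand and compute in fp32. Reading its body, and assuming TmpName(input_ids_[index]) yields a name like tmp_0, it emitted:

```cpp
// Old hoisted conversion (removed by this commit); the RHS then referenced
// half2fp32_tmp_0 instead of tmp_0.
float half2fp32_tmp_0 = __half2float(tmp_0);
```

After this change operands keep their __half type, and the arithmetic resolves to the native-half operators added to the fp16 header further down.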
@@ -136,16 +121,22 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
     size_t input_size = input_ids_.size();
     rhs = ExpandMultivariateTemplate(rhs, input_size);
   }
-  for (size_t i = 0; i < rhs.size(); i++) {
-    size_t pos = i;
+  size_t pos = 0;
+  while (pos < rhs.size()) {
     if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
-      int length = 0;
-      while (rhs[pos + 2 + length] != '}') {
-        length++;
+      size_t length = 0;
+      int bracket_number = 1;
+      for (length = 0; (pos + 2 + length) < rhs.size(); length++) {
+        char ch = rhs[pos + 2 + length];
+        if (ch == '}') bracket_number--;
+        if (ch == '{') bracket_number++;
+        if (bracket_number == 0) break;
       }
       std::string index_str = rhs.substr(pos + 2, length);
       std::string refine_str =
           RefineTemplateWithAttr(op_type_, index_str, attr_);
+      std::string var_name;
       if (index_str == refine_str) {
         int index = StringTo<int>(index_str);
         PADDLE_ENFORCE_LT(index, input_ids_.size(),
@@ -160,20 +151,31 @@ std::string OperationExpression::GetRHS(std::unordered_set<int>* used,
                           index, op_type_, input_ids_[index]));
         // TODO(wangchaochaohu): Here fp16 convert to float to do comupte, we
         // need to add general fp16 compute later.
-        std::string var_name;
-        if (rhs_type_ == "float16") {
-          half2fp32_statement->append(OptimzeFP16RHS(used, index, input_ids_));
-          var_name = "half2fp32_" + TmpName(input_ids_[index]);
-        } else {
-          var_name = TmpName(input_ids_[index]);
-        }
+        var_name = TmpName(input_ids_[index]);
         rhs.replace(pos, length + 3, var_name);
         used->insert(input_ids_[index]);
       } else {
-        std::string var_name = refine_str;
+        var_name = refine_str;
         rhs.replace(pos, length + 3, var_name);
       }
+      pos = pos + var_name.length();
     }
+    pos++;
+  }
+  pos = 0;
+  while (pos < rhs.size()) {
+    if (rhs[pos] == '%' && rhs[pos + 1] == '{') {
+      int length = 0;
+      while (rhs[pos + 2 + length] != '}') {
+        length++;
+      }
+      std::string number_str = rhs.substr(pos + 2, length);
+      if (rhs_type_ == "__half")
+        number_str = "__float2half(" + number_str + ")";
+      rhs.replace(pos, length + 3, number_str);
+      pos = pos + number_str.length();
+    }
+    pos++;
   }
   return rhs;
 }
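To make the two passes concrete, here is a minimal standalone sketch of the expansion (my own illustration, not Paddle code: it locates the closing brace with find and so elides the bracket_number counting the real pass needs for nested defaults such as ${scale=%{1.0}}):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Pass 1 rewrites ${i} placeholders to operand names; pass 2 rewrites %{lit}
// numeric literals, wrapping them in __float2half for __half expressions.
std::string Expand(std::string rhs, const std::vector<std::string>& operands,
                   const std::string& rhs_type) {
  size_t pos = 0;
  while (pos < rhs.size()) {  // pass 1: ${index} -> operand name
    if (rhs[pos] == '$' && rhs[pos + 1] == '{') {
      size_t length = rhs.find('}', pos + 2) - (pos + 2);
      std::string var_name = operands[std::stoi(rhs.substr(pos + 2, length))];
      rhs.replace(pos, length + 3, var_name);
      pos += var_name.length();
    }
    pos++;
  }
  pos = 0;
  while (pos < rhs.size()) {  // pass 2: %{literal} -> typed literal
    if (rhs[pos] == '%' && rhs[pos + 1] == '{') {
      size_t length = rhs.find('}', pos + 2) - (pos + 2);
      std::string number_str = rhs.substr(pos + 2, length);
      if (rhs_type == "__half") number_str = "__float2half(" + number_str + ")";
      rhs.replace(pos, length + 3, number_str);
      pos += number_str.length();
    }
    pos++;
  }
  return rhs;
}

int main() {
  std::string sigmoid = "%{1.0} / (%{1.0} + Exp(- ${0}))";
  std::cout << Expand(sigmoid, {"tmp_0"}, "float") << "\n";
  // prints: 1.0 / (1.0 + Exp(- tmp_0))
  std::cout << Expand(sigmoid, {"tmp_0"}, "__half") << "\n";
  // prints: __float2half(1.0) / (__float2half(1.0) + Exp(- tmp_0))
}
```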
@@ -192,28 +194,24 @@ bool OperationExpression::IsSupport() const {
 // unique for the node which belong the group
 std::string OperationExpression::GetExpression(
     std::unordered_set<int>* used) const {
-  std::string half2fp32_statement;
   std::stringstream ret;
   if (IsSupport()) {
     for (size_t i = 0; i < output_ids_.size(); ++i) {
       std::string cast_str = "";
-      if ((lhs_type_ == rhs_type_ && rhs_type_ != "float16") ||
-          (lhs_type_ != rhs_type_ && rhs_type_ == "float16")) {
-        ret << GetLHS(i) << " = " << GetRHS(used, &half2fp32_statement, i)
-            << ";";
+      if (lhs_type_ == rhs_type_) {
+        ret << GetLHS(i) << " = " << GetRHS(used, i) << ";";
       } else {
-        if ((lhs_type_ == rhs_type_ && rhs_type_ == "float16") ||
-            lhs_type_ == "float16") {
+        if (lhs_type_ == "__half")
           cast_str = "__float2half";
-        } else {
+        else if (rhs_type_ == "__half")
+          cast_str = "__half2float";
+        else
           cast_str = "static_cast<" + lhs_type_ + ">";
-        }
-        ret << GetLHS(i) << " = " << cast_str << "("
-            << GetRHS(used, &half2fp32_statement, i) << ");";
+        ret << GetLHS(i) << " = " << cast_str << "(" << GetRHS(used, i) << ");";
       }
     }
   }
-  return half2fp32_statement + ret.str();
+  return ret.str();
 }
 }  // namespace fusion_group
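The simplified branch is easiest to read off as the statement shapes GetExpression now emits; operand and output names below are illustrative TmpName outputs:

```cpp
// lhs_type_ == rhs_type_               ->  tmp_2 = tmp_0 + tmp_1;
// lhs_type_ == "__half" (rhs float)    ->  tmp_2 = __float2half(tmp_0 + tmp_1);
// rhs_type_ == "__half" (lhs float)    ->  tmp_2 = __half2float(tmp_0 + tmp_1);
// any other mismatch (e.g. lhs double) ->  tmp_2 = static_cast<double>(tmp_0 + tmp_1);
```

The old code special-cased "float16" on both sides because fp16 math was done in fp32; with native __half arithmetic, the matching-type branch needs no cast at all.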
......
@@ -68,7 +68,6 @@ class OperationExpression {
  private:
   // TODO(wangchao): make offset more flexible we add stride and basic offset
   std::string GetRHS(std::unordered_set<int>* used,
-                     std::string* half2fp32_statement,
                      size_t exprs_index = 0) const;
   std::string GetLHS(size_t i = 0) const;
......
@@ -36,11 +36,6 @@ __device__ inline double Sqrt(double x) { return sqrt(x); }
 )";
 
 static constexpr char predefined_cuda_functions_fp16[] = R"(
-__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
-__device__ inline float Exp(float x) { return expf(x); }
-__device__ inline float Log(float x) { return logf(x); }
-__device__ inline float Sqrt(float x) { return sqrtf(x); }
-
 #define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
 #define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
@@ -65,6 +60,219 @@ __device__ float __half2float(const __half h) {
   return val;
 }
 
+#define __CUDA_FP16_DECL__ __host__ __device__
+/******************************************************************************
+*                             __half comparison                              *
+******************************************************************************/
+#define __COMPARISON_OP_HALF_MACRO(name) do {\
+   unsigned short val; \
+   asm( "{ .reg .pred __$temp3;\n" \
+        " setp."#name".f16 __$temp3, %1, %2;\n" \
+        " selp.u16 %0, 1, 0, __$temp3;}" \
+        : "=h"(val) : "h"(__HALF_TO_CUS(a)), "h"(__HALF_TO_CUS(b))); \
+   return val ? true : false; \
+} while(0);
+__CUDA_FP16_DECL__ bool __heq(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(eq);
+}
+__CUDA_FP16_DECL__ bool __hne(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(ne);
+}
+__CUDA_FP16_DECL__ bool __hle(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(le);
+}
+__CUDA_FP16_DECL__ bool __hge(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(ge);
+}
+__CUDA_FP16_DECL__ bool __hlt(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(lt);
+}
+__CUDA_FP16_DECL__ bool __hgt(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(gt);
+}
+__CUDA_FP16_DECL__ bool __hequ(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(equ);
+}
+__CUDA_FP16_DECL__ bool __hneu(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(neu);
+}
+__CUDA_FP16_DECL__ bool __hleu(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(leu);
+}
+__CUDA_FP16_DECL__ bool __hgeu(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(geu);
+}
+__CUDA_FP16_DECL__ bool __hltu(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(ltu);
+}
+__CUDA_FP16_DECL__ bool __hgtu(const __half a, const __half b)
+{
+    __COMPARISON_OP_HALF_MACRO(gtu);
+}
+#undef __COMPARISON_OP_HALF_MACRO
+/******************************************************************************
+*                             __half arithmetic                              *
+******************************************************************************/
+#define __BINARY_OP_HALF_MACRO(name) do {\
+   __half val; \
+   asm( "{"#name".f16 %0,%1,%2;\n}" \
+        :"=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)),"h"(__HALF_TO_CUS(b))); \
+   return val; \
+} while(0);
+__CUDA_FP16_DECL__ __half __hadd(const __half a, const __half b)
+{
+    __BINARY_OP_HALF_MACRO(add);
+}
+__CUDA_FP16_DECL__ __half __hsub(const __half a, const __half b)
+{
+    __BINARY_OP_HALF_MACRO(sub);
+}
+__CUDA_FP16_DECL__ __half __hmul(const __half a, const __half b)
+{
+    __BINARY_OP_HALF_MACRO(mul);
+}
+__CUDA_FP16_DECL__ __half __hadd_sat(const __half a, const __half b)
+{
+    __BINARY_OP_HALF_MACRO(add.sat);
+}
+__CUDA_FP16_DECL__ __half __hsub_sat(const __half a, const __half b)
+{
+    __BINARY_OP_HALF_MACRO(sub.sat);
+}
+__CUDA_FP16_DECL__ __half __hmul_sat(const __half a, const __half b)
+{
+    __BINARY_OP_HALF_MACRO(mul.sat);
+}
+#undef __BINARY_OP_HALF_MACRO
+#define __TERNARY_OP_HALF_MACRO(name) do {\
+   __half val; \
+   asm( "{"#name".f16 %0,%1,%2,%3;\n}" \
+        :"=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)),"h"(__HALF_TO_CUS(b)),"h"(__HALF_TO_CUS(c))); \
+   return val; \
+} while(0);
+__CUDA_FP16_DECL__ __half __hfma(const __half a, const __half b, const __half c)
+{
+    __TERNARY_OP_HALF_MACRO(fma.rn);
+}
+__CUDA_FP16_DECL__ __half __hfma_sat(const __half a, const __half b, const __half c)
+{
+    __TERNARY_OP_HALF_MACRO(fma.rn.sat);
+}
+#undef __TERNARY_OP_HALF2_MACRO
+__CUDA_FP16_DECL__ __half __hdiv(__half a, __half b) {
+  __half v, abs, den;
+  __HALF_TO_US(den) = 0x008F;
+  float fa, fb, fv, rcp;
+  fa = __half2float(a);
+  fb = __half2float(b);
+  asm("{rcp.approx.f32 %0, %1;\n}" :"=f"(rcp) : "f"(fb));
+  fv = rcp * fa;
+  v = __float2half(fv);
+  __HALF_TO_US(abs) = (unsigned short)(((unsigned int)__HALF_TO_CUS(v)) & 0x00007FFF);
+  if (__hlt(abs, den) && (!(__HALF_TO_CUS(abs) == 0x0000))) {
+    float err = __fmaf_rn(-fb, fv, fa);
+    fv = __fmaf_rn(rcp, err, fv);
+    v = __float2half(fv);
+  }
+  return v;
+}
+/* Some basic arithmetic operations expected of a builtin */
+__device__ __forceinline__ __half operator+(const __half &lh, const __half &rh) { return __hadd(lh, rh); }
+__device__ __forceinline__ __half operator-(const __half &lh, const __half &rh) { return __hsub(lh, rh); }
+__device__ __forceinline__ __half operator*(const __half &lh, const __half &rh) { return __hmul(lh, rh); }
+__device__ __forceinline__ __half operator/(const __half &lh, const __half &rh) { return __hdiv(lh, rh); }
+/* Some basic comparison operations to make it look like a builtin */
+__device__ __forceinline__ bool operator==(const __half &lh, const __half &rh) { return __heq(lh, rh); }
+__device__ __forceinline__ bool operator!=(const __half &lh, const __half &rh) { return __hne(lh, rh); }
+__device__ __forceinline__ bool operator> (const __half &lh, const __half &rh) { return __hgt(lh, rh); }
+__device__ __forceinline__ bool operator< (const __half &lh, const __half &rh) { return __hlt(lh, rh); }
+__device__ __forceinline__ bool operator>=(const __half &lh, const __half &rh) { return __hge(lh, rh); }
+__device__ __forceinline__ bool operator<=(const __half &lh, const __half &rh) { return __hle(lh, rh); }
+#define __SPEC_CASE(i,r, spc, ulp) \
+   "{.reg.b16 spc, ulp, p;\n"\
+   " mov.b16 spc,"#spc";\n"\
+   " mov.b16 ulp,"#ulp";\n"\
+   " set.eq.f16.f16 p,"#i", spc;\n"\
+   " fma.rn.f16 "#r",p,ulp,"#r";\n}\n"
+__CUDA_FP16_DECL__ __half hexp(const __half a) {
+  __half val;
+  asm("{.reg.b32 f, C; \n"
+      " .reg.b16 h,r; \n"
+      "  mov.b16 h,%1; \n"
+      "  cvt.f32.f16 f,h; \n"
+      "  mov.b32 C, 0x3fb8aa3b; \n"
+      "  mul.f32 f,f,C; \n"
+      "  ex2.approx.f32 f,f; \n"
+      "  cvt.rn.f16.f32 r,f; \n"
+      __SPEC_CASE(h, r, 0X1F79, 0x9400)
+      __SPEC_CASE(h, r, 0X25CF, 0x9400)
+      __SPEC_CASE(h, r, 0XC13B, 0x0400)
+      __SPEC_CASE(h, r, 0XC1EF, 0x0200)
+      "  mov.b16 %0,r; \n"
+      "}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));
+  return val;
+}
+__CUDA_FP16_DECL__ __half hlog(const __half a) {
+  __half val;
+  asm("{.reg.b32 f, C; \n"
+      " .reg.b16 r,h; \n"
+      "  mov.b16 h,%1; \n"
+      "  cvt.f32.f16 f,h; \n"
+      "  lg2.approx.f32 f,f; \n"
+      "  mov.b32 C, 0x3f317218; \n"
+      "  mul.f32 f,f,C; \n"
+      "  cvt.rn.f16.f32 r,f; \n"
+      __SPEC_CASE(h, r, 0X160D, 0x9C00)
+      __SPEC_CASE(h, r, 0X3BFE, 0x8010)
+      __SPEC_CASE(h, r, 0X3C0B, 0x8080)
+      __SPEC_CASE(h, r, 0X6051, 0x1C00)
+      "  mov.b16 %0,r; \n"
+      "}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));
+  return val;
+}
+#define __APPROX_FCAST(fun) do {\
+   __half val;\
+   asm("{.reg.b32 f; \n"\
+       " .reg.b16 r; \n"\
+       "  mov.b16 r,%1; \n"\
+       "  cvt.f32.f16 f,r; \n"\
+       "  "#fun".approx.f32 f,f; \n"\
+       "  cvt.rn.f16.f32 r,f; \n"\
+       "  mov.b16 %0,r; \n"\
+       "}": "=h"(__HALF_TO_US(val)) : "h"(__HALF_TO_CUS(a)));\
+   return val;\
+} while(0);
+__CUDA_FP16_DECL__ __half hsqrt(const __half a) {
+  __APPROX_FCAST(sqrt);
+}
+__device__ inline __half Exp(const __half x) { return hexp(x); }
+__device__ inline __half Log(const __half x) { return hlog(x); }
+__device__ inline __half Sqrt(const __half x) { return hsqrt(x); }
 #undef __HALF_TO_US
 #undef __HALF_TO_CUS
@@ -81,7 +289,6 @@ extern "C" __global__ void $func_name($parameters) {
   }
 }
 )";
 }  // namespace fusion_group
 }  // namespace ir
 }  // namespace framework
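Everything added above is injected as plain source ahead of the generated kernel, so NVRTC gets a self-contained software implementation of __half comparison, arithmetic, division, the operator overloads, and hexp/hlog/hsqrt without needing cuda_fp16.h. A hedged usage sketch (the kernel and its names are mine, assuming predefined_cuda_functions_fp16 is prepended to it):

```cpp
// Sigmoid on __half, written the way a generated fp16 kernel now can be.
__global__ void sigmoid_half(int n, const __half* x, __half* y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    __half one = __float2half(1.0f);
    // operator* and operator/ resolve to __hmul/__hdiv above; Exp resolves to
    // hexp, i.e. ex2.approx.f32 with the __SPEC_CASE ULP corrections.
    y[i] = one / (one + Exp(x[i] * __float2half(-1.0f)));
  }
}
```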
......
@@ -91,17 +91,18 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   // relu:
   //  out = f(x) = x > 0 ? x : 0
   //  dx = dout * (out > 0 ? 1 : 0)
-  insert_handler("relu", "${0} > 0 ? ${0} : 0", {"${1} > 0 ? ${2} : 0"});
+  insert_handler("relu", "${0} > %{0} ? ${0} : %{0.0}",
+                 {"${1} > %{0.0} ? ${2} : %{0.0}"});
   // sigmoid:
   //  out = f(x) = 1.0 / (1.0 + exp(-x))
   //  dx = dout * out * (1 - out)
-  insert_handler("sigmoid", "1.0 / (1.0 + Exp(- ${0}))",
-                 {"${2} * ${1} * (1.0 - ${1})"});
+  insert_handler("sigmoid", "%{1.0} / (%{1.0} + Exp(- ${0}))",
+                 {"${2} * ${1} * (%{1.0} - ${1})"});
   // tanh:
   //  out = f(x) = 2.0 / (1.0 + exp(-2.0 * x)) - 1.0;
   //  dx = dout * (1 - out * out)
-  insert_handler("tanh", "2.0 / (1.0 + Exp(-2.0 * ${0})) - 1.0",
-                 {"${2} * (1.0 - ${1} * ${1})"});
+  insert_handler("tanh", "%{2.0} / (%{1.0} + Exp(-%{2.0} * ${0})) - %{1.0}",
+                 {"${2} * (%{1.0} - ${1} * ${1})"});
 
   // cast:
   //  out = static_cast<T>(x)
@@ -112,22 +113,23 @@ void OperationMap::InsertUnaryElementwiseOperations() {
   // sqrt:
   //  out = x^(1/2)
   //  dx = dout * 0.5 / out
-  insert_handler("sqrt", "Sqrt(${0})", {"${2} * 0.5 / ${1}"});
+  insert_handler("sqrt", "Sqrt(${0})", {"${2} * %{0.5} / ${1}"});
   // square:
   //  out = x^2
   //  dx = dout * 2.0 * x
-  insert_handler("square", "${0} * ${0}", {"${2} * 2.0 * ${0}"});
+  insert_handler("square", "${0} * ${0}", {"${2} * %{2.0} * ${0}"});
   // scale
   //  out = (bias_after_scale) ? scale * X + bias : scale(X + bias)
   //  here we use '=' operator to seperate th default value
   // TODO(wangchaochaohu): Later we need to support Tensor input for scale and
   // bias.
-  insert_handler("scale",
-                 "${bias_after_scale=true} ? (${scale=1.0} * ${0} + "
-                 "${bias=0.0}) : (${scale=1.0} * (${0} + ${bias=0.0}))",
-                 {});
+  insert_handler(
+      "scale",
+      "${bias_after_scale=true} ? (${scale=%{1.0}} * ${0} + "
+      "${bias=%{0.0}}) : (${scale=%{1.0}} * (${0} + ${bias=%{0.0}}))",
+      {});
 }
 
 void OperationMap::InsertBinaryElementwiseOperations() {
......
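On the template side, numeric literals move from bare tokens (0, 1.0, 2.0) to the new %{...} marker, which the second GetRHS pass consumes: for a float expression %{1.0} expands to plain 1.0, while for an fp16 expression it becomes __float2half(1.0). A bare literal would not compile against the header above, since the __half operators are declared only for (__half, __half) pairs. Illustrative expansion of the sigmoid forward template (tmp names hypothetical):

```cpp
// rhs_type_ == "float":
//   tmp_1 = 1.0 / (1.0 + Exp(- tmp_0));
// rhs_type_ == "__half":
//   tmp_1 = __float2half(1.0) / (__float2half(1.0) + Exp(- tmp_0));
```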