未验证 提交 38886829 编写于 作者: L lanxianghit 提交者: GitHub

add cinn bf16 support (#53637)

添加CINN与Paddle框架的BFloat16类型映射
上级 82c73884
...@@ -88,6 +88,7 @@ namespace cpp = ::cinn::frontend::paddle::cpp; ...@@ -88,6 +88,7 @@ namespace cpp = ::cinn::frontend::paddle::cpp;
SET_DATA_TYPE_CASE_ITEM(INT16); SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32); SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64); SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(BF16);
SET_DATA_TYPE_CASE_ITEM(FP16); SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32); SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64); SET_DATA_TYPE_CASE_ITEM(FP64);
...@@ -139,6 +140,7 @@ std::string VarDataTypeToString( ...@@ -139,6 +140,7 @@ std::string VarDataTypeToString(
SET_DATA_TYPE_CASE_ITEM(INT16); SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32); SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64); SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(BF16);
SET_DATA_TYPE_CASE_ITEM(FP16); SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32); SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64); SET_DATA_TYPE_CASE_ITEM(FP64);
......
...@@ -40,6 +40,7 @@ namespace paddle::framework::paddle2cinn { ...@@ -40,6 +40,7 @@ namespace paddle::framework::paddle2cinn {
SET_TYPE_CASE_ITEM(UI32, UINT32) SET_TYPE_CASE_ITEM(UI32, UINT32)
SET_TYPE_CASE_ITEM(UI64, UINT64) SET_TYPE_CASE_ITEM(UI64, UINT64)
SET_TYPE_CASE_ITEM(BF16, BFLOAT16)
SET_TYPE_CASE_ITEM(F16, FLOAT16) SET_TYPE_CASE_ITEM(F16, FLOAT16)
SET_TYPE_CASE_ITEM(F32, FLOAT32) SET_TYPE_CASE_ITEM(F32, FLOAT32)
SET_TYPE_CASE_ITEM(F64, FLOAT64) SET_TYPE_CASE_ITEM(F64, FLOAT64)
...@@ -70,6 +71,9 @@ namespace paddle::framework::paddle2cinn { ...@@ -70,6 +71,9 @@ namespace paddle::framework::paddle2cinn {
SET_TYPE_CASE_ITEM(cinn_float32_t, FLOAT32) SET_TYPE_CASE_ITEM(cinn_float32_t, FLOAT32)
SET_TYPE_CASE_ITEM(cinn_float64_t, FLOAT64) SET_TYPE_CASE_ITEM(cinn_float64_t, FLOAT64)
#ifdef CINN_COMMON_BFLOAT16_H
SET_TYPE_CASE_ITEM(cinn_bfloat16_t, BFLOAT16)
#endif // CINN_COMMON_BFLOAT16_H
#ifdef CINN_COMMON_FLOAT16_H #ifdef CINN_COMMON_FLOAT16_H
SET_TYPE_CASE_ITEM(cinn_float16_t, FLOAT16) SET_TYPE_CASE_ITEM(cinn_float16_t, FLOAT16)
#endif // CINN_COMMON_FLOAT16_H #endif // CINN_COMMON_FLOAT16_H
......
...@@ -39,6 +39,8 @@ TEST(TransToPaddleDataType, common_type) { ...@@ -39,6 +39,8 @@ TEST(TransToPaddleDataType, common_type) {
TransToPaddleDataType(::cinn::common::UI32())); TransToPaddleDataType(::cinn::common::UI32()));
ASSERT_EQ(::phi::DataType::UINT64, ASSERT_EQ(::phi::DataType::UINT64,
TransToPaddleDataType(::cinn::common::UI64())); TransToPaddleDataType(::cinn::common::UI64()));
ASSERT_EQ(::phi::DataType::BFLOAT16,
TransToPaddleDataType(::cinn::common::BF16()));
ASSERT_EQ(::phi::DataType::FLOAT16, ASSERT_EQ(::phi::DataType::FLOAT16,
TransToPaddleDataType(::cinn::common::F16())); TransToPaddleDataType(::cinn::common::F16()));
ASSERT_EQ(::phi::DataType::FLOAT32, ASSERT_EQ(::phi::DataType::FLOAT32,
......
...@@ -635,6 +635,11 @@ if '${WITH_CINN}' == 'ON': ...@@ -635,6 +635,11 @@ if '${WITH_CINN}' == 'ON':
shutil.copy(cinn_fp16_file, libs_path) shutil.copy(cinn_fp16_file, libs_path)
package_data['paddle.libs']+=['float16.h'] package_data['paddle.libs']+=['float16.h']
cinn_bf16_file = '${CINN_INCLUDE_DIR}/cinn/runtime/cuda/bfloat16.h'
if os.path.exists(cinn_bf16_file):
shutil.copy(cinn_bf16_file, libs_path)
package_data['paddle.libs']+=['bfloat16.h']
if '${CMAKE_BUILD_TYPE}' == 'Release' and os.name != 'nt': if '${CMAKE_BUILD_TYPE}' == 'Release' and os.name != 'nt':
command = "patchelf --set-rpath '$ORIGIN/' %s/${CINN_LIB_NAME}" % libs_path command = "patchelf --set-rpath '$ORIGIN/' %s/${CINN_LIB_NAME}" % libs_path
if os.system(command) != 0: if os.system(command) != 0:
......
...@@ -1070,6 +1070,12 @@ def get_package_data_and_package_dir(): ...@@ -1070,6 +1070,12 @@ def get_package_data_and_package_dir():
if os.path.exists(cinn_fp16_file): if os.path.exists(cinn_fp16_file):
shutil.copy(cinn_fp16_file, libs_path) shutil.copy(cinn_fp16_file, libs_path)
package_data['paddle.libs'] += ['float16.h'] package_data['paddle.libs'] += ['float16.h']
cinn_bf16_file = (
env_dict.get("CINN_INCLUDE_DIR") + '/cinn/runtime/cuda/bfloat16.h'
)
if os.path.exists(cinn_bf16_file):
shutil.copy(cinn_bf16_file, libs_path)
package_data['paddle.libs'] += ['bfloat16.h']
if env_dict.get("CMAKE_BUILD_TYPE") == 'Release' and os.name != 'nt': if env_dict.get("CMAKE_BUILD_TYPE") == 'Release' and os.name != 'nt':
command = ( command = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册