未验证 提交 38886829 编写于 作者: L lanxianghit 提交者: GitHub

add cinn bf16 support (#53637)

添加CINN与Paddle框架的BFloat16类型映射
上级 82c73884
......@@ -88,6 +88,7 @@ namespace cpp = ::cinn::frontend::paddle::cpp;
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(BF16);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
......@@ -139,6 +140,7 @@ std::string VarDataTypeToString(
SET_DATA_TYPE_CASE_ITEM(INT16);
SET_DATA_TYPE_CASE_ITEM(INT32);
SET_DATA_TYPE_CASE_ITEM(INT64);
SET_DATA_TYPE_CASE_ITEM(BF16);
SET_DATA_TYPE_CASE_ITEM(FP16);
SET_DATA_TYPE_CASE_ITEM(FP32);
SET_DATA_TYPE_CASE_ITEM(FP64);
......
......@@ -40,6 +40,7 @@ namespace paddle::framework::paddle2cinn {
SET_TYPE_CASE_ITEM(UI32, UINT32)
SET_TYPE_CASE_ITEM(UI64, UINT64)
SET_TYPE_CASE_ITEM(BF16, BFLOAT16)
SET_TYPE_CASE_ITEM(F16, FLOAT16)
SET_TYPE_CASE_ITEM(F32, FLOAT32)
SET_TYPE_CASE_ITEM(F64, FLOAT64)
......@@ -70,6 +71,9 @@ namespace paddle::framework::paddle2cinn {
SET_TYPE_CASE_ITEM(cinn_float32_t, FLOAT32)
SET_TYPE_CASE_ITEM(cinn_float64_t, FLOAT64)
#ifdef CINN_COMMON_BFLOAT16_H
SET_TYPE_CASE_ITEM(cinn_bfloat16_t, BFLOAT16)
#endif // CINN_COMMON_BFLOAT16_H
#ifdef CINN_COMMON_FLOAT16_H
SET_TYPE_CASE_ITEM(cinn_float16_t, FLOAT16)
#endif // CINN_COMMON_FLOAT16_H
......
......@@ -39,6 +39,8 @@ TEST(TransToPaddleDataType, common_type) {
TransToPaddleDataType(::cinn::common::UI32()));
ASSERT_EQ(::phi::DataType::UINT64,
TransToPaddleDataType(::cinn::common::UI64()));
ASSERT_EQ(::phi::DataType::BFLOAT16,
TransToPaddleDataType(::cinn::common::BF16()));
ASSERT_EQ(::phi::DataType::FLOAT16,
TransToPaddleDataType(::cinn::common::F16()));
ASSERT_EQ(::phi::DataType::FLOAT32,
......
......@@ -635,6 +635,11 @@ if '${WITH_CINN}' == 'ON':
shutil.copy(cinn_fp16_file, libs_path)
package_data['paddle.libs']+=['float16.h']
cinn_bf16_file = '${CINN_INCLUDE_DIR}/cinn/runtime/cuda/bfloat16.h'
if os.path.exists(cinn_bf16_file):
shutil.copy(cinn_bf16_file, libs_path)
package_data['paddle.libs']+=['bfloat16.h']
if '${CMAKE_BUILD_TYPE}' == 'Release' and os.name != 'nt':
command = "patchelf --set-rpath '$ORIGIN/' %s/${CINN_LIB_NAME}" % libs_path
if os.system(command) != 0:
......
......@@ -1070,6 +1070,12 @@ def get_package_data_and_package_dir():
if os.path.exists(cinn_fp16_file):
shutil.copy(cinn_fp16_file, libs_path)
package_data['paddle.libs'] += ['float16.h']
cinn_bf16_file = (
env_dict.get("CINN_INCLUDE_DIR") + '/cinn/runtime/cuda/bfloat16.h'
)
if os.path.exists(cinn_bf16_file):
shutil.copy(cinn_bf16_file, libs_path)
package_data['paddle.libs'] += ['bfloat16.h']
if env_dict.get("CMAKE_BUILD_TYPE") == 'Release' and os.name != 'nt':
command = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册