未验证 提交 ae542dc7 编写于 作者: H haosicheng 提交者: GitHub

add support of cast kernel from fp32 to uint8 *test=kunlun (#45557)

上级 eb5b83e7
...@@ -41,7 +41,7 @@ void CastKernel(const Context& dev_ctx, ...@@ -41,7 +41,7 @@ void CastKernel(const Context& dev_ctx,
r = xpu::cast_v2<XPUInTDType, float>( r = xpu::cast_v2<XPUInTDType, float>(
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<float>(dev_ctx.GetPlace()), dev_ctx.template Alloc<float>(out),
numel); numel);
break; break;
case phi::DataType::FLOAT16: case phi::DataType::FLOAT16:
...@@ -49,28 +49,35 @@ void CastKernel(const Context& dev_ctx, ...@@ -49,28 +49,35 @@ void CastKernel(const Context& dev_ctx,
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data), reinterpret_cast<const XPUInTDType*>(in_data),
reinterpret_cast<float16*>( reinterpret_cast<float16*>(
out->mutable_data<phi::dtype::float16>(dev_ctx.GetPlace())), dev_ctx.template Alloc<phi::dtype::float16>(out)),
numel); numel);
break; break;
case phi::DataType::INT64: case phi::DataType::INT64:
r = xpu::cast_v2<XPUInTDType, int64_t>( r = xpu::cast_v2<XPUInTDType, int64_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int64_t>(dev_ctx.GetPlace()), dev_ctx.template Alloc<int64_t>(out),
numel); numel);
break; break;
case phi::DataType::INT32: case phi::DataType::INT32:
r = xpu::cast_v2<XPUInTDType, int32_t>( r = xpu::cast_v2<XPUInTDType, int32_t>(
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int>(dev_ctx.GetPlace()), dev_ctx.template Alloc<int>(out),
numel); numel);
break; break;
case phi::DataType::BOOL: case phi::DataType::BOOL:
r = xpu::cast_v2<XPUInTDType, bool>( r = xpu::cast_v2<XPUInTDType, bool>(
dev_ctx.x_context(), dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<bool>(dev_ctx.GetPlace()), dev_ctx.template Alloc<bool>(out),
numel);
break;
case phi::DataType::UINT8:
r = xpu::cast_v2<XPUInTDType, uint8_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
dev_ctx.template Alloc<uint8_t>(out),
numel); numel);
break; break;
default: default:
......
...@@ -33,6 +33,7 @@ typeid_dict = { ...@@ -33,6 +33,7 @@ typeid_dict = {
'float32': int(core.VarDesc.VarType.FP32), 'float32': int(core.VarDesc.VarType.FP32),
'float16': int(core.VarDesc.VarType.FP16), 'float16': int(core.VarDesc.VarType.FP16),
'bool': int(core.VarDesc.VarType.BOOL), 'bool': int(core.VarDesc.VarType.BOOL),
'uint8': int(core.VarDesc.VarType.UINT8),
} }
...@@ -45,7 +46,7 @@ class XPUTestCastOp(XPUOpTestWrapper): ...@@ -45,7 +46,7 @@ class XPUTestCastOp(XPUOpTestWrapper):
def dynamic_create_class(self): def dynamic_create_class(self):
base_class = self.TestCastOp base_class = self.TestCastOp
classes = [] classes = []
for out_type in {'float16', 'float32', 'int32', 'int64'}: for out_type in {'float16', 'float32', 'int32', 'int64', 'uint8'}:
class_name = 'XPUTestCastOp_outtype_' + out_type class_name = 'XPUTestCastOp_outtype_' + out_type
attr_dict = {'out_typename': out_type} attr_dict = {'out_typename': out_type}
classes.append([class_name, attr_dict]) classes.append([class_name, attr_dict])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册