未验证 提交 ae542dc7 编写于 作者: H haosicheng 提交者: GitHub

add support of cast kernel from fp32 to uint8 *test=kunlun (#45557)

上级 eb5b83e7
......@@ -41,7 +41,7 @@ void CastKernel(const Context& dev_ctx,
r = xpu::cast_v2<XPUInTDType, float>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<float>(dev_ctx.GetPlace()),
dev_ctx.template Alloc<float>(out),
numel);
break;
case phi::DataType::FLOAT16:
......@@ -49,28 +49,35 @@ void CastKernel(const Context& dev_ctx,
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
reinterpret_cast<float16*>(
out->mutable_data<phi::dtype::float16>(dev_ctx.GetPlace())),
dev_ctx.template Alloc<phi::dtype::float16>(out)),
numel);
break;
case phi::DataType::INT64:
r = xpu::cast_v2<XPUInTDType, int64_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int64_t>(dev_ctx.GetPlace()),
dev_ctx.template Alloc<int64_t>(out),
numel);
break;
case phi::DataType::INT32:
r = xpu::cast_v2<XPUInTDType, int32_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int>(dev_ctx.GetPlace()),
dev_ctx.template Alloc<int>(out),
numel);
break;
case phi::DataType::BOOL:
r = xpu::cast_v2<XPUInTDType, bool>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<bool>(dev_ctx.GetPlace()),
dev_ctx.template Alloc<bool>(out),
numel);
break;
case phi::DataType::UINT8:
r = xpu::cast_v2<XPUInTDType, uint8_t>(
dev_ctx.x_context(),
reinterpret_cast<const XPUInTDType*>(in_data),
dev_ctx.template Alloc<uint8_t>(out),
numel);
break;
default:
......
......@@ -33,6 +33,7 @@ typeid_dict = {
'float32': int(core.VarDesc.VarType.FP32),
'float16': int(core.VarDesc.VarType.FP16),
'bool': int(core.VarDesc.VarType.BOOL),
'uint8': int(core.VarDesc.VarType.UINT8),
}
......@@ -45,7 +46,7 @@ class XPUTestCastOp(XPUOpTestWrapper):
def dynamic_create_class(self):
base_class = self.TestCastOp
classes = []
for out_type in {'float16', 'float32', 'int32', 'int64'}:
for out_type in {'float16', 'float32', 'int32', 'int64', 'uint8'}:
class_name = 'XPUTestCastOp_outtype_' + out_type
attr_dict = {'out_typename': out_type}
classes.append([class_name, attr_dict])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册