diff --git a/paddle/phi/kernels/xpu/cast_kernel.cc b/paddle/phi/kernels/xpu/cast_kernel.cc index 9aa503d58736defa477414df43cd812d75cfca36..502b8324522e666ab7c062309cd73e21f8956cd4 100644 --- a/paddle/phi/kernels/xpu/cast_kernel.cc +++ b/paddle/phi/kernels/xpu/cast_kernel.cc @@ -41,7 +41,7 @@ void CastKernel(const Context& dev_ctx, r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), numel); break; case phi::DataType::FLOAT16: @@ -49,28 +49,35 @@ void CastKernel(const Context& dev_ctx, dev_ctx.x_context(), reinterpret_cast(in_data), reinterpret_cast( - out->mutable_data(dev_ctx.GetPlace())), + dev_ctx.template Alloc(out)), numel); break; case phi::DataType::INT64: r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), numel); break; case phi::DataType::INT32: r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), numel); break; case phi::DataType::BOOL: r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), + numel); + break; + case phi::DataType::UINT8: + r = xpu::cast_v2( + dev_ctx.x_context(), + reinterpret_cast(in_data), + dev_ctx.template Alloc(out), numel); break; default: diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index cd7062f66d995f6a3acc89a7478069e8f836425b..abeb83d20c740ba6d1085759312e55054e633fa9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -33,6 +33,7 @@ typeid_dict = { 'float32': int(core.VarDesc.VarType.FP32), 'float16': int(core.VarDesc.VarType.FP16), 'bool': int(core.VarDesc.VarType.BOOL), + 'uint8': int(core.VarDesc.VarType.UINT8), } @@ -45,7 +46,7 @@ class XPUTestCastOp(XPUOpTestWrapper): def dynamic_create_class(self): base_class = self.TestCastOp classes = [] - for out_type in {'float16', 'float32', 'int32', 'int64'}: + for out_type in {'float16', 'float32', 'int32', 'int64', 'uint8'}: class_name = 'XPUTestCastOp_outtype_' + out_type attr_dict = {'out_typename': out_type} classes.append([class_name, attr_dict])