From ae542dc78c25e34c68ef6e7fca68847e781a250d Mon Sep 17 00:00:00 2001 From: haosicheng <47998305+HarperCy@users.noreply.github.com> Date: Thu, 1 Sep 2022 11:31:43 +0800 Subject: [PATCH] add support of cast kernel from fp32 to uint8 *test=kunlun (#45557) --- paddle/phi/kernels/xpu/cast_kernel.cc | 17 ++++++++++++----- .../tests/unittests/xpu/test_cast_op_xpu.py | 3 ++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/paddle/phi/kernels/xpu/cast_kernel.cc b/paddle/phi/kernels/xpu/cast_kernel.cc index 9aa503d5873..502b8324522 100644 --- a/paddle/phi/kernels/xpu/cast_kernel.cc +++ b/paddle/phi/kernels/xpu/cast_kernel.cc @@ -41,7 +41,7 @@ void CastKernel(const Context& dev_ctx, r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), numel); break; case phi::DataType::FLOAT16: @@ -49,28 +49,35 @@ void CastKernel(const Context& dev_ctx, dev_ctx.x_context(), reinterpret_cast(in_data), reinterpret_cast( - out->mutable_data(dev_ctx.GetPlace())), + dev_ctx.template Alloc(out)), numel); break; case phi::DataType::INT64: r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), numel); break; case phi::DataType::INT32: r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), numel); break; case phi::DataType::BOOL: r = xpu::cast_v2( dev_ctx.x_context(), reinterpret_cast(in_data), - out->mutable_data(dev_ctx.GetPlace()), + dev_ctx.template Alloc(out), + numel); + break; + case phi::DataType::UINT8: + r = xpu::cast_v2( + dev_ctx.x_context(), + reinterpret_cast(in_data), + dev_ctx.template Alloc(out), numel); break; default: diff --git a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py index cd7062f66d9..abeb83d20c7 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_cast_op_xpu.py @@ -33,6 +33,7 @@ typeid_dict = { 'float32': int(core.VarDesc.VarType.FP32), 'float16': int(core.VarDesc.VarType.FP16), 'bool': int(core.VarDesc.VarType.BOOL), + 'uint8': int(core.VarDesc.VarType.UINT8), } @@ -45,7 +46,7 @@ class XPUTestCastOp(XPUOpTestWrapper): def dynamic_create_class(self): base_class = self.TestCastOp classes = [] - for out_type in {'float16', 'float32', 'int32', 'int64'}: + for out_type in {'float16', 'float32', 'int32', 'int64', 'uint8'}: class_name = 'XPUTestCastOp_outtype_' + out_type attr_dict = {'out_typename': out_type} classes.append([class_name, attr_dict]) -- GitLab