diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc
index c3e23831e65286d3d65e8ae80933e0725ad85f27..0ec5ca707a6b5978ea84b5b15670cba0a303cbbf 100644
--- a/paddle/fluid/operators/cast_op_xpu.cc
+++ b/paddle/fluid/operators/cast_op_xpu.cc
@@ -20,6 +20,8 @@ limitations under the License. */
 #include "paddle/fluid/platform/float16.h"
 #include "xpu/refactor/math.h"
 
+#include "paddle/pten/kernels/cast_kernel.h"
+
 namespace paddle {
 namespace operators {
 
@@ -35,47 +37,19 @@ class CastXPUKernel : public framework::OpKernel<InT> {
   void Compute(const framework::ExecutionContext& context) const override {
     auto* in = context.Input<framework::Tensor>("X");
     auto* out = context.Output<framework::Tensor>("Out");
-    auto in_type = static_cast<var_type::Type>(context.Attr<int>("in_dtype"));
-    auto out_type = static_cast<var_type::Type>(context.Attr<int>("out_dtype"));
-    auto* in_data = in->data<InT>();
+    auto out_dtype =
+        static_cast<framework::proto::VarType::Type>(context.Attr<int>("out_dtype"));
-    auto numel = in->numel();
     auto& dev_ctx = context.template device_context<platform::XPUDeviceContext>();
-    int r = -1;
-    switch (out_type) {
-      case var_type::FP32:
-        r = xpu::cast_v2<XPUInTDType, float>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
-            out->mutable_data<float>(context.GetPlace()), numel);
-        break;
-      case var_type::FP16:
-        r = xpu::cast_v2<XPUInTDType, float16>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
-            reinterpret_cast<float16*>(
-                out->mutable_data<plat::float16>(context.GetPlace())),
-            numel);
-        break;
-      case var_type::INT64:
-        r = xpu::cast_v2<XPUInTDType, int64_t>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
-            out->mutable_data<int64_t>(context.GetPlace()), numel);
-        break;
-      case var_type::INT32:
-        r = xpu::cast_v2<XPUInTDType, int32_t>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
-            out->mutable_data<int32_t>(context.GetPlace()), numel);
-        break;
-      case var_type::BOOL:
-        r = xpu::cast_v2<XPUInTDType, bool>(
-            dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
-            out->mutable_data<bool>(context.GetPlace()), numel);
-        break;
-      default:
-        PADDLE_THROW(platform::errors::Unavailable(
-            "Not supported cast %d -> %d", in_type, out_type));
-    }
-    PADDLE_ENFORCE_EQ(
-        r, XPU_SUCCESS,
-        platform::errors::External("XPU CAST API return wrong value[%d %s].", r,
-                                   XPUAPIErrorMsg[r]));
+
+    out->mutable_data(dev_ctx.GetPlace(),
+                      static_cast<framework::proto::VarType::Type>(out_dtype));
+
+    auto pt_out_dtype = pten::TransToPtenDataType(
+        static_cast<framework::proto::VarType::Type>(out_dtype));
+    // call pten kernel
+    pten::CastKernel<InT>(
+        static_cast<const typename paddle::framework::ConvertToPtenContext<
+            platform::XPUDeviceContext>::TYPE&>(dev_ctx),
+        *in, pt_out_dtype, out);
   }
 };
diff --git a/paddle/pten/kernels/xpu/cast_kernel.cc b/paddle/pten/kernels/xpu/cast_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fc1ba021e22bc0847d25e8c50e5f87dcac2ae41e
--- /dev/null
+++ b/paddle/pten/kernels/xpu/cast_kernel.cc
@@ -0,0 +1,99 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/pten/kernels/cast_kernel.h"
+
+#include "paddle/fluid/platform/device/xpu/xpu_header.h"
+#include "paddle/pten/backends/xpu/xpu_context.h"
+#include "paddle/pten/core/kernel_registry.h"
+
+#include "paddle/pten/core/enforce.h"
+
+namespace pten {
+
+// Casts tensor `x` to `out_dtype` on the XPU device via xpu::cast_v2.
+// Throws Unavailable for target dtypes without a supported cast path.
+template <typename T, typename Context>
+void CastKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                DataType out_dtype,
+                DenseTensor* out) {
+  using XPUInTDType = typename XPUTypeTrait<T>::Type;
+  using float16 = typename XPUTypeTrait<pten::platform::float16>::Type;
+
+  auto* in_data = x.data<T>();
+  auto numel = x.numel();
+
+  int r = -1;
+  switch (out_dtype) {
+    case pten::DataType::FLOAT32:
+      r = xpu::cast_v2<XPUInTDType, float>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUInTDType*>(in_data),
+          out->mutable_data<float>(dev_ctx.GetPlace()),
+          numel);
+      break;
+    case pten::DataType::FLOAT16:
+      r = xpu::cast_v2<XPUInTDType, float16>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUInTDType*>(in_data),
+          reinterpret_cast<float16*>(
+              out->mutable_data<pten::platform::float16>(dev_ctx.GetPlace())),
+          numel);
+      break;
+    case pten::DataType::INT64:
+      r = xpu::cast_v2<XPUInTDType, int64_t>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUInTDType*>(in_data),
+          out->mutable_data<int64_t>(dev_ctx.GetPlace()),
+          numel);
+      break;
+    case pten::DataType::INT32:
+      r = xpu::cast_v2<XPUInTDType, int32_t>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUInTDType*>(in_data),
+          out->mutable_data<int32_t>(dev_ctx.GetPlace()),
+          numel);
+      break;
+    case pten::DataType::BOOL:
+      r = xpu::cast_v2<XPUInTDType, bool>(
+          dev_ctx.x_context(),
+          reinterpret_cast<const XPUInTDType*>(in_data),
+          out->mutable_data<bool>(dev_ctx.GetPlace()),
+          numel);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unavailable(
+          "Not supported cast %d -> %d", x.dtype(), out_dtype));
+  }
+
+  PADDLE_ENFORCE_EQ(
+      r,
+      XPU_SUCCESS,
+      pten::errors::External(
+          "XPU CAST API return wrong value[%d %s].", r, XPUAPIErrorMsg[r]));
+}
+}  // namespace pten
+
+PT_REGISTER_KERNEL(cast,
+                   XPU,
+                   ALL_LAYOUT,
+                   pten::CastKernel,
+                   int32_t,
+                   float,
+                   pten::platform::float16,
+                   int64_t,
+                   bool) {
+  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
+}