Unverified commit 93d2f0a6, authored by chentianyu03, committed via GitHub

[pten] Cast xpu kernel (#39179)

* cast xpu kernel init

* cast xpu kernel

* replace with raw cast xpu kernel

* fix cast kernel bug

* add the missing break

* modify namespace and header file
Parent commit: 2c0160e5
......@@ -20,6 +20,8 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "xpu/refactor/math.h"
#include "paddle/pten/kernels/cast_kernel.h"
namespace paddle {
namespace operators {
......@@ -35,49 +37,21 @@ class CastXPUKernel : public framework::OpKernel<InT> {
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto in_type = static_cast<var_type::Type>(context.Attr<int>("in_dtype"));
auto out_type = static_cast<var_type::Type>(context.Attr<int>("out_dtype"));
auto* in_data = in->data<InT>();
auto out_dtype =
static_cast<var_type::Type>(context.Attr<int>("out_dtype"));
auto numel = in->numel();
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = -1;
switch (out_type) {
case var_type::FP32:
r = xpu::cast_v2<XPUInTDType, float>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<float>(context.GetPlace()), numel);
break;
case var_type::FP16:
r = xpu::cast_v2<XPUInTDType, float16>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
reinterpret_cast<float16*>(
out->mutable_data<plat::float16>(context.GetPlace())),
numel);
break;
case var_type::INT64:
r = xpu::cast_v2<XPUInTDType, int64_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int64_t>(context.GetPlace()), numel);
break;
case var_type::INT32:
r = xpu::cast_v2<XPUInTDType, int32_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<int>(context.GetPlace()), numel);
break;
case var_type::BOOL:
r = xpu::cast_v2<XPUInTDType, bool>(
dev_ctx.x_context(), reinterpret_cast<const XPUInTDType*>(in_data),
out->mutable_data<bool>(context.GetPlace()), numel);
break;
default:
PADDLE_THROW(platform::errors::Unavailable(
"Not supported cast %d -> %d", in_type, out_type));
}
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU CAST API return wrong value[%d %s].", r,
XPUAPIErrorMsg[r]));
out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_out_dtype = pten::TransToPtenDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call pten kernel
pten::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPtenContext<
DeviceContext>::TYPE&>(dev_ctx),
*in, pt_out_dtype, out);
}
};
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/pten/backends/xpu/xpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/enforce.h"
namespace pten {
// Casts every element of `x` to `out_dtype` on an XPU device, writing the
// converted values into `out`.
//
// T       : input element type (one kernel instantiation is registered per
//           input type; see PT_REGISTER_KERNEL below in this file).
// Context : XPU device context type providing x_context() and GetPlace().
//
// Supported output dtypes are FLOAT32, FLOAT16, INT64, INT32 and BOOL; any
// other requested dtype raises an Unavailable error. A failing XPU runtime
// call raises an External error carrying the runtime's status code.
template <typename T, typename Context>
void CastKernel(const Context& dev_ctx,
                const DenseTensor& x,
                DataType out_dtype,
                DenseTensor* out) {
  // Map the framework-level element types to the XPU runtime's on-device
  // representations (e.g. paddle float16 -> the XPU half type).
  using XPUInTDType = typename XPUTypeTrait<T>::Type;
  using float16 = typename XPUTypeTrait<pten::platform::float16>::Type;
  auto* in_data = x.data<T>();
  auto numel = x.numel();
  int r = -1;  // XPU API status; expected to be XPU_SUCCESS after the call.
  switch (out_dtype) {
    case pten::DataType::FLOAT32:
      r = xpu::cast_v2<XPUInTDType, float>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          out->mutable_data<float>(dev_ctx.GetPlace()),
          numel);
      break;
    case pten::DataType::FLOAT16:
      // The output buffer is allocated as the framework float16 type but the
      // XPU API expects its own half type, hence the reinterpret_cast.
      r = xpu::cast_v2<XPUInTDType, float16>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          reinterpret_cast<float16*>(
              out->mutable_data<pten::platform::float16>(dev_ctx.GetPlace())),
          numel);
      break;
    case pten::DataType::INT64:
      r = xpu::cast_v2<XPUInTDType, int64_t>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          out->mutable_data<int64_t>(dev_ctx.GetPlace()),
          numel);
      break;
    case pten::DataType::INT32:
      r = xpu::cast_v2<XPUInTDType, int32_t>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          out->mutable_data<int>(dev_ctx.GetPlace()),
          numel);
      break;
    case pten::DataType::BOOL:
      r = xpu::cast_v2<XPUInTDType, bool>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUInTDType*>(in_data),
          out->mutable_data<bool>(dev_ctx.GetPlace()),
          numel);
      break;
    default:
      // Use pten::errors (declared by pten/core/enforce.h, the only enforce
      // header this file includes) for consistency with the ENFORCE below;
      // the fluid `platform::errors` namespace is not in scope here.
      PADDLE_THROW(pten::errors::Unavailable(
          "Not supported cast %d -> %d", x.dtype(), out_dtype));
  }
  PADDLE_ENFORCE_EQ(
      r,
      XPU_SUCCESS,
      pten::errors::External(
          "XPU CAST API return wrong value[%d %s].", r, XPUAPIErrorMsg[r]));
}
} // namespace pten
// Register the XPU cast kernel for every supported input element type.
// The output dtype is deliberately set to UNDEFINED at registration time:
// it is only known at run time, from the `out_dtype` argument passed to
// CastKernel, so the framework must not pre-assign a fixed output type.
PT_REGISTER_KERNEL(cast,
                   XPU,
                   ALL_LAYOUT,
                   pten::CastKernel,
                   int32_t,
                   float,
                   pten::platform::float16,
                   int64_t,
                   bool) {
  kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册