From 9dad4f79c972a1d99e149cc956833fd8cf092731 Mon Sep 17 00:00:00 2001 From: risemeup1 <62429225+risemeup1@users.noreply.github.com> Date: Tue, 30 Aug 2022 11:09:48 +0800 Subject: [PATCH] move cast XPU kernel to PHI,test=kunlun (#45534) --- paddle/fluid/operators/cast_op_xpu.cc | 72 --------------------------- 1 file changed, 72 deletions(-) delete mode 100644 paddle/fluid/operators/cast_op_xpu.cc diff --git a/paddle/fluid/operators/cast_op_xpu.cc b/paddle/fluid/operators/cast_op_xpu.cc deleted file mode 100644 index 4956581cc8c..00000000000 --- a/paddle/fluid/operators/cast_op_xpu.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#ifdef PADDLE_WITH_XPU -#include - -#include "paddle/fluid/framework/convert_utils.h" -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/cast_op.h" -#include "paddle/fluid/platform/float16.h" -#include "paddle/phi/kernels/cast_kernel.h" -#include "xpu/refactor/math.h" - -namespace paddle { -namespace operators { - -using var_type = framework::proto::VarType; -namespace plat = paddle::platform; - -template -class CastXPUKernel : public framework::OpKernel { - using XPUInTDType = typename XPUTypeTrait::Type; - using float16 = typename XPUTypeTrait::Type; - - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* in = context.Input("X"); - auto* out = context.Output("Out"); - auto out_dtype = - static_cast(context.Attr("out_dtype")); - - auto& dev_ctx = context.template device_context(); - - out->mutable_data(dev_ctx.GetPlace(), - static_cast(out_dtype)); - - auto pt_out_dtype = framework::TransToPhiDataType( - static_cast(out_dtype)); - // call phi kernel - phi::CastKernel( - static_cast::TYPE&>(dev_ctx), - *in, - pt_out_dtype, - out); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - cast, - ops::CastXPUKernel, - ops::CastXPUKernel, - ops::CastXPUKernel, - ops::CastXPUKernel, - ops::CastXPUKernel); -#endif -- GitLab