diff --git a/paddle/fluid/operators/clip_op_xpu.cc b/paddle/fluid/operators/clip_op_xpu.cc
deleted file mode 100644
index 03afed9adc1e328950ca8ea4d41e74fa86feb47e..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/clip_op_xpu.cc
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#ifdef PADDLE_WITH_XPU
-
-#include "paddle/fluid/framework/op_registry.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class ClipXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<Tensor>("X");
-    auto* out = ctx.Output<Tensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
-
-    auto max = static_cast<T>(ctx.Attr<float>("max"));
-    if (ctx.HasInput("Max")) {
-      Tensor max_cpu;
-      auto* max_t = ctx.Input<Tensor>("Max");
-      auto* max_data = max_t->data<T>();
-      if (platform::is_xpu_place(max_t->place())) {
-        paddle::framework::TensorCopySync(
-            *max_t, platform::CPUPlace(), &max_cpu);
-        max_data = max_cpu.data<T>();
-      }
-      max = max_data[0];
-    }
-
-    auto min = ctx.Attr<float>("min");
-    if (ctx.HasInput("Min")) {
-      Tensor min_cpu;
-      auto* min_t = ctx.Input<Tensor>("Min");
-      auto* min_data = min_t->data<T>();
-      if (platform::is_xpu_place(min_t->place())) {
-        paddle::framework::TensorCopySync(
-            *min_t, platform::CPUPlace(), &min_cpu);
-        min_data = min_cpu.data<T>();
-      }
-      min = min_data[0];
-    }
-
-    using XPUDataType = typename XPUTypeTrait<T>::Type;
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto x_data = reinterpret_cast<const XPUDataType*>(x->data<T>());
-    auto out_data = reinterpret_cast<XPUDataType*>(out->data<T>());
-    int r = xpu::clip_v2(
-        dev_ctx.x_context(), x_data, out_data, x->numel(), min, max);
-    PADDLE_ENFORCE_EQ(
-        r,
-        XPU_SUCCESS,
-        platform::errors::External("XPU API(clip_v2) return wrong "
-                                   "value[%d %s]",
-                                   r,
-                                   XPUAPIErrorMsg[r]));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
-
-REGISTER_OP_XPU_KERNEL(clip, ops::ClipXPUKernel<plat::XPUDeviceContext, float>);
-
-#endif
diff --git a/paddle/phi/kernels/xpu/clip_kernel.cc b/paddle/phi/kernels/xpu/clip_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..80e7e5f0c493dde1b0cbc37df14e91655699214a
--- /dev/null
+++ b/paddle/phi/kernels/xpu/clip_kernel.cc
@@ -0,0 +1,49 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/clip_kernel.h"
+#include "paddle/phi/backends/xpu/xpu_context.h"
+#include "paddle/phi/backends/xpu/xpu_header.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ClipKernel(const Context& dev_ctx,
+                const DenseTensor& x,
+                const Scalar& min,
+                const Scalar& max,
+                DenseTensor* out) {
+  dev_ctx.template Alloc<T>(out);
+  using XPUDataType = typename XPUTypeTrait<T>::Type;
+  auto x_data = reinterpret_cast<const XPUDataType*>(x.data<T>());
+  auto out_data = reinterpret_cast<XPUDataType*>(out->data<T>());
+  int r = xpu::clip_v2(dev_ctx.x_context(),
+                       x_data,
+                       out_data,
+                       x.numel(),
+                       min.to<XPUDataType>(),
+                       max.to<XPUDataType>());
+
+  PADDLE_ENFORCE_EQ(r,
+                    XPU_SUCCESS,
+                    phi::errors::External("XPU API(clip_v2) return wrong "
+                                          "value[%d %s]",
+                                          r,
+                                          XPUAPIErrorMsg[r]));
+}
+
+}  // namespace phi
+
+PD_REGISTER_KERNEL(clip, XPU, ALL_LAYOUT, phi::ClipKernel, float) {}