From 938397173ab7fe02761b8fc7c210f8b68195ede0 Mon Sep 17 00:00:00 2001 From: chentianyu03 Date: Thu, 27 Jan 2022 14:44:39 +0800 Subject: [PATCH] [pten] add full xpu kernel (#39172) * add full_kernel xpu * fix full xpu register device type error * fix full kernel bug * add fulllike kernel impl and replace with raw kernel * fix dev_ctx convert template args error * modify namespace and header file * add isinf check * fix input type args in TensorSetConstantXPU error --- .../fluid/operators/fill_any_like_op_xpu.cc | 15 +- paddle/pten/kernels/xpu/full_kernel.cc | 133 ++++++++++++++++++ 2 files changed, 141 insertions(+), 7 deletions(-) create mode 100644 paddle/pten/kernels/xpu/full_kernel.cc diff --git a/paddle/fluid/operators/fill_any_like_op_xpu.cc b/paddle/fluid/operators/fill_any_like_op_xpu.cc index 76cf339fbf..b4788d0445 100644 --- a/paddle/fluid/operators/fill_any_like_op_xpu.cc +++ b/paddle/fluid/operators/fill_any_like_op_xpu.cc @@ -16,6 +16,8 @@ limitations under the License. */ #include "paddle/fluid/operators/fill_any_like_op.h" +#include "paddle/pten/kernels/full_kernel.h" + namespace paddle { namespace operators { @@ -56,13 +58,12 @@ class FillAnyLikeXPUKernel : public framework::OpKernel { auto& dev_ctx = context.template device_context(); - auto out_data = reinterpret_cast(out->data()); - int ret = xpu::constant(dev_ctx.x_context(), out_data, out->numel(), - static_cast(value)); - PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, - platform::errors::External( - "XPU CONSTANT API return wrong value[%d %s].", ret, - XPUAPIErrorMsg[ret])); + + // call pten kernel + pten::FullLikeKernel( + static_cast::TYPE&>(dev_ctx), + value, out); } }; diff --git a/paddle/pten/kernels/xpu/full_kernel.cc b/paddle/pten/kernels/xpu/full_kernel.cc new file mode 100644 index 0000000000..71d2b8e3ad --- /dev/null +++ b/paddle/pten/kernels/xpu/full_kernel.cc @@ -0,0 +1,133 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/full_kernel.h" + +#include "paddle/pten/api/ext/dispatch.h" +#include "paddle/pten/backends/xpu/xpu_context.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace pten { + +template +void TensorSetConstantXPU(pten::DenseTensor* tensor, + InType value, + pten::Place place) { + auto* begin = tensor->mutable_data(place); + int64_t numel = tensor->numel(); + std::unique_ptr data_cpu(new OutType[numel]); + std::fill( + data_cpu.get(), data_cpu.get() + numel, static_cast(value)); + paddle::memory::Copy(place, + begin, + pten::CPUPlace(), + static_cast(data_cpu.get()), + numel * sizeof(OutType)); +} + +template +void FullValueXPU(const Context& dev_ctx, DenseTensor* tensor, VType val) { + tensor->mutable_data(dev_ctx.GetPlace()); + + PD_VISIT_ALL_TYPES(tensor->dtype(), "FullValueXPU", ([&] { + TensorSetConstantXPU( + tensor, val, dev_ctx.GetPlace()); + })); +} + +template +void FullKernel(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out) { + out->ResizeAndAllocate(pten::framework::make_ddim(shape.GetData())); + FullValueXPU(dev_ctx, out, val.to()); +} + +template +void FullLikeKernel(const Context& dev_ctx, + const Scalar& val, + DenseTensor* out) { + auto value = val.to(); + using XPUInTDType = typename XPUTypeTrait::Type; + using CommonType = typename std::common_type< + float, + typename std::conditional::value, + float, + T>::type>::type; + + auto common_type_value = static_cast(value); + + PADDLE_ENFORCE_EQ( + (common_type_value >= + static_cast(std::numeric_limits::lowest())) && + (common_type_value <= + static_cast(std::numeric_limits::max())), + true, + pten::errors::InvalidArgument( + "The filled value is out of range for target type, " + "current kernel type is %s, the range should between %f " + "and %f, but now value is %f.", + typeid(T).name(), + static_cast(std::numeric_limits::lowest()), + static_cast(std::numeric_limits::max()), + static_cast(value))); + + PADDLE_ENFORCE_EQ(std::isnan(value), + false, + pten::errors::InvalidArgument("The filled value is NaN.")); + PADDLE_ENFORCE_EQ(std::isinf(value), + false, + pten::errors::InvalidArgument("The filled value is Inf.")); + + auto out_data = reinterpret_cast(out->data()); + int ret = xpu::constant(dev_ctx.x_context(), + out_data, + out->numel(), + static_cast(value)); + PADDLE_ENFORCE_EQ( + ret, + XPU_SUCCESS, + pten::errors::External("XPU CONSTANT API return wrong value[%d %s].", + ret, + XPUAPIErrorMsg[ret])); +} + +} // namespace pten + +PT_REGISTER_KERNEL(full, + XPU, + ALL_LAYOUT, + pten::FullKernel, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + pten::platform::float16, + pten::platform::bfloat16, + pten::platform::complex, + pten::platform::complex) {} + +PT_REGISTER_KERNEL(full_like, + XPU, + ALL_LAYOUT, + pten::FullLikeKernel, + float, + int, + int64_t, + pten::platform::float16) {} -- GitLab