From 0c02d2ed7c840360ed42023902dc5da96552b3fd Mon Sep 17 00:00:00 2001 From: zyfncg Date: Thu, 6 Jan 2022 09:55:18 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PTen=E3=80=91Adjust=20the=20format=20o?= =?UTF-8?q?f=20full=20kernel=20(#38596)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * adjust the full kernel * remove creation.h * use Empty to create tensor in full --- paddle/fluid/operators/fill_any_like_op.h | 4 +- paddle/pten/all.h | 1 - paddle/pten/include/creation.h | 59 ------------------- paddle/pten/kernels/cpu/full_kernel.cc | 4 +- paddle/pten/kernels/full_kernel.h | 42 +++++++++++-- paddle/pten/kernels/gpu/full_kernel.cu | 4 +- paddle/pten/kernels/impl/full_kernel_impl.h | 18 +++--- .../tests/kernels/test_creation_dev_api.cc | 3 +- 8 files changed, 55 insertions(+), 80 deletions(-) delete mode 100644 paddle/pten/include/creation.h diff --git a/paddle/fluid/operators/fill_any_like_op.h b/paddle/fluid/operators/fill_any_like_op.h index 3ad56827f83..287bbbfa3b3 100644 --- a/paddle/fluid/operators/fill_any_like_op.h +++ b/paddle/fluid/operators/fill_any_like_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/pten_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/creation.h" +#include "paddle/pten/kernels/full_kernel.h" namespace paddle { namespace operators { @@ -65,7 +65,7 @@ class FillAnyLikeKernel : public framework::OpKernel { const auto& dev_ctx = context.template device_context(); // call new kernel - pten::FullLike(dev_ctx, value, pt_out.get()); + pten::FullLikeKernel(dev_ctx, value, pt_out.get()); } }; diff --git a/paddle/pten/all.h b/paddle/pten/all.h index b7ef1c1ec26..844114c341d 100644 --- a/paddle/pten/all.h +++ b/paddle/pten/all.h @@ -16,7 +16,6 @@ limitations under the License. */ // developer apis #include "paddle/pten/include/core.h" -#include "paddle/pten/include/creation.h" #include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/linalg.h" #include "paddle/pten/include/manipulation.h" diff --git a/paddle/pten/include/creation.h b/paddle/pten/include/creation.h deleted file mode 100644 index fa5bd49ca30..00000000000 --- a/paddle/pten/include/creation.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pten/api/lib/utils/storage.h" -#include "paddle/pten/include/infermeta.h" -#include "paddle/pten/kernels/empty_kernel.h" -#include "paddle/pten/kernels/full_kernel.h" - -namespace pten { - -// TODO(YuanRisheng) This function name should be same as User API name. -// TODO(zyfncg) Automatic code generation -template -DenseTensor Full(const ContextT& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DataType dtype = DataType::FLOAT32, - Backend backend = Backend::CPU, // Is backend needed here? - DataLayout layout = DataLayout::NCHW) { - auto out_meta = CreateInferMeta(shape, dtype, layout); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Full(dev_ctx, shape, val, &dense_out); - return dense_out; -} - -template -DenseTensor FullLike( - const ContextT& dev_ctx, - const DenseTensor& x, - const Scalar& val, - DataType dtype = DataType::UNDEFINED, - Backend backend = Backend::UNDEFINED, // Is backend needed here? - DataLayout layout = DataLayout::UNDEFINED) { - auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - FullLike(dev_ctx, val, &dense_out); - return dense_out; -} - -} // namespace pten diff --git a/paddle/pten/kernels/cpu/full_kernel.cc b/paddle/pten/kernels/cpu/full_kernel.cc index 4912656bb2a..1ae8001d79d 100644 --- a/paddle/pten/kernels/cpu/full_kernel.cc +++ b/paddle/pten/kernels/cpu/full_kernel.cc @@ -21,7 +21,7 @@ limitations under the License. */ PT_REGISTER_CTX_KERNEL(full, CPU, ALL_LAYOUT, - pten::Full, + pten::FullKernel, float, double, uint8_t, @@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_CTX_KERNEL(full_like, CPU, ALL_LAYOUT, - pten::FullLike, + pten::FullLikeKernel, float, double, int, diff --git a/paddle/pten/kernels/full_kernel.h b/paddle/pten/kernels/full_kernel.h index 5bf6e37c36e..bc484fb4edf 100644 --- a/paddle/pten/kernels/full_kernel.h +++ b/paddle/pten/kernels/full_kernel.h @@ -18,15 +18,47 @@ #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/infermeta/nullary.h" +#include "paddle/pten/kernels/empty_kernel.h" + namespace pten { template -void Full(const Context& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out); +void FullKernel(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out); + +template +void FullLikeKernel(const Context& dev_ctx, + const Scalar& val, + DenseTensor* out); + +template +DenseTensor Full(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DataType dtype = DataType::FLOAT32, + Backend backend = Backend::CPU, // Is backend needed here? + DataLayout layout = DataLayout::NCHW) { + auto out_meta = CreateInferMeta(shape, dtype, layout); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + FullKernel(dev_ctx, shape, val, &dense_out); + return dense_out; +} template -void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out); +DenseTensor FullLike( + const Context& dev_ctx, + const DenseTensor& x, + const Scalar& val, + DataType dtype = DataType::UNDEFINED, + Backend backend = Backend::UNDEFINED, // Is backend needed here? + DataLayout layout = DataLayout::UNDEFINED) { + auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + FullLikeKernel(dev_ctx, val, &dense_out); + return dense_out; +} } // namespace pten diff --git a/paddle/pten/kernels/gpu/full_kernel.cu b/paddle/pten/kernels/gpu/full_kernel.cu index 16389d7749b..ae1f8529db3 100644 --- a/paddle/pten/kernels/gpu/full_kernel.cu +++ b/paddle/pten/kernels/gpu/full_kernel.cu @@ -21,7 +21,7 @@ limitations under the License. */ PT_REGISTER_CTX_KERNEL(full, GPU, ALL_LAYOUT, - pten::Full, + pten::FullKernel, float, double, uint8_t, @@ -36,7 +36,7 @@ PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_CTX_KERNEL(full_like, GPU, ALL_LAYOUT, - pten::FullLike, + pten::FullLikeKernel, float, double, int, diff --git a/paddle/pten/kernels/impl/full_kernel_impl.h b/paddle/pten/kernels/impl/full_kernel_impl.h index c77b7a70778..9be40e22a03 100644 --- a/paddle/pten/kernels/impl/full_kernel_impl.h +++ b/paddle/pten/kernels/impl/full_kernel_impl.h @@ -24,7 +24,7 @@ limitations under the License. */ namespace pten { -template +template void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) { tensor->mutable_data(); auto t = pten::EigenVector::Flatten(*tensor); @@ -32,16 +32,18 @@ void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) { } template -void Full(const Context& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out) { +void FullKernel(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out) { out->Resize(paddle::framework::make_ddim(shape.GetData())); - FullValue(dev_ctx, out, val.to()); + FullValue(dev_ctx, out, val.to()); } template -void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) { +void FullLikeKernel(const Context& dev_ctx, + const Scalar& val, + DenseTensor* out) { auto value = val.to(); using CommonType = typename std::common_type< float, @@ -66,7 +68,7 @@ void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) { static_cast(std::numeric_limits::lowest()), static_cast(std::numeric_limits::max()), static_cast(value))); - FullValue(dev_ctx, out, value); + FullValue(dev_ctx, out, value); } } // namespace pten diff --git a/paddle/pten/tests/kernels/test_creation_dev_api.cc b/paddle/pten/tests/kernels/test_creation_dev_api.cc index 8469b94b797..4d753f7d09b 100644 --- a/paddle/pten/tests/kernels/test_creation_dev_api.cc +++ b/paddle/pten/tests/kernels/test_creation_dev_api.cc @@ -15,7 +15,8 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/creation.h" +#include "paddle/pten/kernels/empty_kernel.h" +#include "paddle/pten/kernels/full_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h" -- GitLab