diff --git a/paddle/fluid/operators/fill_any_like_op.h b/paddle/fluid/operators/fill_any_like_op.h index 3ad56827f8344417d6919301a2495bb2d41bbaa4..287bbbfa3b34350b1e6f5da381a6a938a4570ecc 100644 --- a/paddle/fluid/operators/fill_any_like_op.h +++ b/paddle/fluid/operators/fill_any_like_op.h @@ -20,7 +20,7 @@ limitations under the License. */ #include "paddle/fluid/framework/pten_utils.h" #include "paddle/pten/include/core.h" -#include "paddle/pten/include/creation.h" +#include "paddle/pten/kernels/full_kernel.h" namespace paddle { namespace operators { @@ -65,7 +65,7 @@ class FillAnyLikeKernel : public framework::OpKernel { const auto& dev_ctx = context.template device_context(); // call new kernel - pten::FullLike(dev_ctx, value, pt_out.get()); + pten::FullLikeKernel(dev_ctx, value, pt_out.get()); } }; diff --git a/paddle/pten/all.h b/paddle/pten/all.h index b7ef1c1ec2611d490fe3104ac829367c6310a674..844114c341d67004d5a02a12f91ac7670e3ba856 100644 --- a/paddle/pten/all.h +++ b/paddle/pten/all.h @@ -16,7 +16,6 @@ limitations under the License. */ // developer apis #include "paddle/pten/include/core.h" -#include "paddle/pten/include/creation.h" #include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/linalg.h" #include "paddle/pten/include/manipulation.h" diff --git a/paddle/pten/include/creation.h b/paddle/pten/include/creation.h deleted file mode 100644 index fa5bd49ca3026f26afdcbb67ac6f50036eded6cc..0000000000000000000000000000000000000000 --- a/paddle/pten/include/creation.h +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#pragma once - -#include "paddle/pten/api/lib/utils/storage.h" -#include "paddle/pten/include/infermeta.h" -#include "paddle/pten/kernels/empty_kernel.h" -#include "paddle/pten/kernels/full_kernel.h" - -namespace pten { - -// TODO(YuanRisheng) This function name should be same as User API name. -// TODO(zyfncg) Automatic code generation -template -DenseTensor Full(const ContextT& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DataType dtype = DataType::FLOAT32, - Backend backend = Backend::CPU, // Is backend needed here? - DataLayout layout = DataLayout::NCHW) { - auto out_meta = CreateInferMeta(shape, dtype, layout); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - Full(dev_ctx, shape, val, &dense_out); - return dense_out; -} - -template -DenseTensor FullLike( - const ContextT& dev_ctx, - const DenseTensor& x, - const Scalar& val, - DataType dtype = DataType::UNDEFINED, - Backend backend = Backend::UNDEFINED, // Is backend needed here? - DataLayout layout = DataLayout::UNDEFINED) { - auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout); - pten::DenseTensor dense_out( - pten::make_intrusive( - dev_ctx.GetPlace()), - std::move(out_meta)); - FullLike(dev_ctx, val, &dense_out); - return dense_out; -} - -} // namespace pten diff --git a/paddle/pten/kernels/cpu/full_kernel.cc b/paddle/pten/kernels/cpu/full_kernel.cc index 4912656bb2aefe6076296a70122cec540b8689ad..1ae8001d79dc7140d8155f900d6993c40ac163c1 100644 --- a/paddle/pten/kernels/cpu/full_kernel.cc +++ b/paddle/pten/kernels/cpu/full_kernel.cc @@ -21,7 +21,7 @@ limitations under the License. */ PT_REGISTER_CTX_KERNEL(full, CPU, ALL_LAYOUT, - pten::Full, + pten::FullKernel, float, double, uint8_t, @@ -37,7 +37,7 @@ PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_CTX_KERNEL(full_like, CPU, ALL_LAYOUT, - pten::FullLike, + pten::FullLikeKernel, float, double, int, diff --git a/paddle/pten/kernels/full_kernel.h b/paddle/pten/kernels/full_kernel.h index 5bf6e37c36e576594096b8d5cadec63f4e514621..bc484fb4edffa95abf25c66172b7e6d3c603a500 100644 --- a/paddle/pten/kernels/full_kernel.h +++ b/paddle/pten/kernels/full_kernel.h @@ -18,15 +18,47 @@ #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/infermeta/nullary.h" +#include "paddle/pten/kernels/empty_kernel.h" + namespace pten { template -void Full(const Context& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out); +void FullKernel(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out); + +template +void FullLikeKernel(const Context& dev_ctx, + const Scalar& val, + DenseTensor* out); + +template +DenseTensor Full(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DataType dtype = DataType::FLOAT32, + Backend backend = Backend::CPU, // Is backend needed here? + DataLayout layout = DataLayout::NCHW) { + auto out_meta = CreateInferMeta(shape, dtype, layout); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + FullKernel(dev_ctx, shape, val, &dense_out); + return dense_out; +} template -void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out); +DenseTensor FullLike( + const Context& dev_ctx, + const DenseTensor& x, + const Scalar& val, + DataType dtype = DataType::UNDEFINED, + Backend backend = Backend::UNDEFINED, // Is backend needed here? + DataLayout layout = DataLayout::UNDEFINED) { + auto out_meta = CreateLikeInferMeta(x.meta(), dtype, layout); + auto dense_out = Empty(dev_ctx, std::move(out_meta)); + FullLikeKernel(dev_ctx, val, &dense_out); + return dense_out; +} } // namespace pten diff --git a/paddle/pten/kernels/gpu/full_kernel.cu b/paddle/pten/kernels/gpu/full_kernel.cu index 16389d7749bf1d8edf5e224f6c2411c72cc9adb7..ae1f8529db3de8a42a0c8b43781fa35c7c9f7ef1 100644 --- a/paddle/pten/kernels/gpu/full_kernel.cu +++ b/paddle/pten/kernels/gpu/full_kernel.cu @@ -21,7 +21,7 @@ limitations under the License. */ PT_REGISTER_CTX_KERNEL(full, GPU, ALL_LAYOUT, - pten::Full, + pten::FullKernel, float, double, uint8_t, @@ -36,7 +36,7 @@ PT_REGISTER_CTX_KERNEL(full, PT_REGISTER_CTX_KERNEL(full_like, GPU, ALL_LAYOUT, - pten::FullLike, + pten::FullLikeKernel, float, double, int, diff --git a/paddle/pten/kernels/impl/full_kernel_impl.h b/paddle/pten/kernels/impl/full_kernel_impl.h index c77b7a7077824fbe44d5730835136fba6b7f929f..9be40e22a0360ee23050f59b8a7598333def38ad 100644 --- a/paddle/pten/kernels/impl/full_kernel_impl.h +++ b/paddle/pten/kernels/impl/full_kernel_impl.h @@ -24,7 +24,7 @@ limitations under the License. */ namespace pten { -template +template void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) { tensor->mutable_data(); auto t = pten::EigenVector::Flatten(*tensor); @@ -32,16 +32,18 @@ void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) { } template -void Full(const Context& dev_ctx, - const ScalarArray& shape, - const Scalar& val, - DenseTensor* out) { +void FullKernel(const Context& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out) { out->Resize(paddle::framework::make_ddim(shape.GetData())); - FullValue(dev_ctx, out, val.to()); + FullValue(dev_ctx, out, val.to()); } template -void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) { +void FullLikeKernel(const Context& dev_ctx, + const Scalar& val, + DenseTensor* out) { auto value = val.to(); using CommonType = typename std::common_type< float, @@ -66,7 +68,7 @@ void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) { static_cast(std::numeric_limits::lowest()), static_cast(std::numeric_limits::max()), static_cast(value))); - FullValue(dev_ctx, out, value); + FullValue(dev_ctx, out, value); } } // namespace pten diff --git a/paddle/pten/tests/kernels/test_creation_dev_api.cc b/paddle/pten/tests/kernels/test_creation_dev_api.cc index 8469b94b797c8757b51aeaa8814e151aacfb9778..4d753f7d09b8e0cf9c4e485426911cead82c2cd9 100644 --- a/paddle/pten/tests/kernels/test_creation_dev_api.cc +++ b/paddle/pten/tests/kernels/test_creation_dev_api.cc @@ -15,7 +15,8 @@ limitations under the License. */ #include #include -#include "paddle/pten/include/creation.h" +#include "paddle/pten/kernels/empty_kernel.h" +#include "paddle/pten/kernels/full_kernel.h" #include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/core/dense_tensor.h"