From 0f24de8320181c0b83f382813f9da4d9ef5e94fe Mon Sep 17 00:00:00 2001 From: zyfncg Date: Wed, 24 Nov 2021 16:01:26 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90PTen=E3=80=91Add=20Scalar=20and=20Scal?= =?UTF-8?q?arArray=20in=20pten=20(#37409)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add scalar and scalar_array * remove DenseTensor include from Scalar and ScalarArray * remove inner header from scalar_array * refactor the method of fill_constant and add some comment --- paddle/fluid/operators/cumsum_op.cu | 2 +- paddle/fluid/operators/math/tree2col.cc | 3 +- paddle/fluid/operators/math/tree2col.h | 4 +- paddle/pten/api/all.h | 1 + paddle/pten/api/include/creation.h | 3 +- paddle/pten/api/lib/creation.cc | 9 +- paddle/pten/api/lib/tensor.cc | 3 + paddle/pten/common/scalar.h | 212 ++++++++++++++++++++---- paddle/pten/common/scalar_array.h | 148 +++++++++++++++++ paddle/pten/core/kernel_utils.h | 3 + paddle/pten/infermeta/nary.cc | 7 + paddle/pten/infermeta/nary.h | 6 +- paddle/pten/kernels/cpu/creation.cc | 25 +++ paddle/pten/kernels/cpu/creation.h | 7 + paddle/pten/kernels/cuda/creation.cu | 24 +++ paddle/pten/kernels/cuda/creation.h | 7 + paddle/pten/tests/api/test_fill_api.cc | 86 +++++++++- 17 files changed, 503 insertions(+), 47 deletions(-) create mode 100644 paddle/pten/common/scalar_array.h diff --git a/paddle/fluid/operators/cumsum_op.cu b/paddle/fluid/operators/cumsum_op.cu index 854be76f24e..d9e19eb7f61 100644 --- a/paddle/fluid/operators/cumsum_op.cu +++ b/paddle/fluid/operators/cumsum_op.cu @@ -254,7 +254,7 @@ class CumCUDAKernel : public framework::OpKernel { dim3 transpose_grids((width + tile_size - 1) / tile_size, (height + tile_size - 1) / tile_size); auto& dev_ctx = context.template device_context(); - Tensor tmp; + framework::Tensor tmp; tmp.Resize(out_dims); auto* tmp_data = tmp.mutable_data(context.GetPlace()); T* next_in_data = out_data; diff --git a/paddle/fluid/operators/math/tree2col.cc b/paddle/fluid/operators/math/tree2col.cc index 0344226ea66..97ab2c5f52a 100644 --- a/paddle/fluid/operators/math/tree2col.cc +++ b/paddle/fluid/operators/math/tree2col.cc @@ -19,7 +19,6 @@ namespace paddle { namespace operators { namespace math { -using Tensor = framework::Tensor; std::vector Tree2ColUtil::construct_patch( size_t root, int max_depth, const std::vector> &tr) { std::stack> stack; @@ -51,7 +50,7 @@ std::vector Tree2ColUtil::construct_patch( return patch; } -void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet, +void Tree2ColUtil::construct_tree(const framework::Tensor &EdgeSet, std::vector> *tr, size_t *node_count) { auto edge_set_dims = EdgeSet.dims(); diff --git a/paddle/fluid/operators/math/tree2col.h b/paddle/fluid/operators/math/tree2col.h index 478ba78e259..632777c9cd9 100644 --- a/paddle/fluid/operators/math/tree2col.h +++ b/paddle/fluid/operators/math/tree2col.h @@ -21,8 +21,6 @@ #include "paddle/fluid/operators/math/math_function.h" namespace paddle { -using Tensor = framework::Tensor; -using DDim = framework::DDim; namespace operators { namespace math { class TreeNode { @@ -64,7 +62,7 @@ class Tree2ColUtil { static std::vector construct_patch( size_t root, int max_depth, const std::vector> &tr); - static void construct_tree(const Tensor &EdgeSet, + static void construct_tree(const framework::Tensor &EdgeSet, std::vector> *tr, size_t *node_count); }; diff --git a/paddle/pten/api/all.h b/paddle/pten/api/all.h index 22cbe5fa1fd..2c647786379 100644 --- a/paddle/pten/api/all.h +++ b/paddle/pten/api/all.h @@ -37,6 +37,7 @@ limitations under the License. */ #include "paddle/pten/common/data_type.h" #include "paddle/pten/common/layout.h" #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" // original custom op headers #include "paddle/pten/api/ext/dispatch.h" diff --git a/paddle/pten/api/include/creation.h b/paddle/pten/api/include/creation.h index bcd3d4355cf..b4e4bd0fd05 100644 --- a/paddle/pten/api/include/creation.h +++ b/paddle/pten/api/include/creation.h @@ -18,11 +18,12 @@ #include "paddle/pten/common/backend.h" #include "paddle/pten/common/data_type.h" #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" namespace paddle { namespace experimental { -PD_DLL_DECL Tensor full(const std::vector& shape, +PD_DLL_DECL Tensor full(const ScalarArray& shape, const Scalar& value, DataType dtype = DataType::FLOAT32, Backend backend = Backend::CPU, diff --git a/paddle/pten/api/lib/creation.cc b/paddle/pten/api/lib/creation.cc index 523c2f6bd10..088ff919596 100644 --- a/paddle/pten/api/lib/creation.cc +++ b/paddle/pten/api/lib/creation.cc @@ -34,7 +34,7 @@ PT_DECLARE_MODULE(CreationCUDA); namespace paddle { namespace experimental { -PD_DLL_DECL Tensor full(const std::vector& shape, +PD_DLL_DECL Tensor full(const ScalarArray& shape, const Scalar& value, DataType dtype, Backend backend, @@ -42,14 +42,15 @@ PD_DLL_DECL Tensor full(const std::vector& shape, // 1. Get kernel signature and kernel pten::KernelKey kernel_key{backend, layout, dtype}; auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( - "fill_constant.scalar", kernel_key); + "fill_constant", kernel_key); // 2. Get Device Context auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend()); auto kernel_context = pten::KernelContext(dev_ctx); // 3. Auto data transform - kernel_context.EmplaceBackAttr(value); + kernel_context.EmplaceBackAttr(pten::ScalarArray(shape)); + kernel_context.EmplaceBackAttr(pten::Scalar(value)); // 4. InferShape auto out_meta = pten::FullInferShape(shape, dtype, layout); @@ -94,7 +95,7 @@ PD_DLL_DECL Tensor full_like(const Tensor& x, // 3. Auto data transform auto dense_x = std::dynamic_pointer_cast(x.impl()); - kernel_context.EmplaceBackAttr(value); + kernel_context.EmplaceBackAttr(pten::Scalar(value)); // 4. InferShape auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout); diff --git a/paddle/pten/api/lib/tensor.cc b/paddle/pten/api/lib/tensor.cc index 8a9ce5e859e..3f0966d369d 100644 --- a/paddle/pten/api/lib/tensor.cc +++ b/paddle/pten/api/lib/tensor.cc @@ -219,6 +219,7 @@ template PD_DLL_DECL const int32_t *Tensor::data() const; template PD_DLL_DECL const uint8_t *Tensor::data() const; template PD_DLL_DECL const int8_t *Tensor::data() const; template PD_DLL_DECL const int16_t *Tensor::data() const; +template PD_DLL_DECL const uint16_t *Tensor::data() const; template PD_DLL_DECL const bool *Tensor::data() const; template PD_DLL_DECL const paddle::platform::complex *Tensor::data>() const; @@ -226,6 +227,8 @@ template PD_DLL_DECL const paddle::platform::complex *Tensor::data>() const; template PD_DLL_DECL const paddle::platform::float16 * Tensor::data() const; +template PD_DLL_DECL const paddle::platform::bfloat16 * +Tensor::data() const; template T *Tensor::data() { diff --git a/paddle/pten/common/scalar.h b/paddle/pten/common/scalar.h index bc2488024f1..36205a0e4c2 100644 --- a/paddle/pten/common/scalar.h +++ b/paddle/pten/common/scalar.h @@ -18,69 +18,217 @@ limitations under the License. */ #include #include "paddle/pten/api/ext/exception.h" - +#include "paddle/pten/api/include/tensor.h" namespace paddle { namespace experimental { -class Scalar { +template +class ScalarBase { public: // Constructor support implicit - Scalar(float val) : tag(Tag::HAS_F) { data_.f = val; } // NOLINT + ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT + data_.f64 = val; + } + + ScalarBase(float val) : dtype_(DataType::FLOAT32) { // NOLINT + data_.f32 = val; + } + + ScalarBase(float16 val) : dtype_(DataType::FLOAT16) { // NOLINT + data_.f16 = val; + } - Scalar(double val) : tag(Tag::HAS_D) { data_.d = val; } // NOLINT + ScalarBase(bfloat16 val) : dtype_(DataType::BFLOAT16) { // NOLINT + data_.bf16 = val; + } - Scalar(int32_t val) : tag(Tag::HAS_I32) { data_.i32 = val; } // NOLINT + ScalarBase(int64_t val) : dtype_(DataType::INT64) { // NOLINT + data_.i64 = val; + } - Scalar(int64_t val) : tag(Tag::HAS_I64) { data_.i64 = val; } // NOLINT + ScalarBase(int32_t val) : dtype_(DataType::INT32) { // NOLINT + data_.i32 = val; + } - Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; } // NOLINT + ScalarBase(int16_t val) : dtype_(DataType::INT16) { // NOLINT + data_.i16 = val; + } - Scalar(const std::string& str_value) : tag(Tag::HAS_D) { // NOLINT + ScalarBase(int8_t val) : dtype_(DataType::INT8) { // NOLINT + data_.i8 = val; + } + + ScalarBase(uint64_t val) : dtype_(DataType::UINT64) { // NOLINT + data_.ui64 = val; + } + + ScalarBase(uint32_t val) : dtype_(DataType::UINT32) { // NOLINT + data_.ui32 = val; + } + + ScalarBase(uint16_t val) : dtype_(DataType::UINT16) { // NOLINT + data_.ui16 = val; + } + + ScalarBase(uint8_t val) : dtype_(DataType::UINT8) { // NOLINT + data_.ui8 = val; + } + + ScalarBase(bool val) : dtype_(DataType::BOOL) { // NOLINT + data_.b = val; + } + + ScalarBase(complex64 val) : dtype_(DataType::COMPLEX64) { // NOLINT + data_.c64 = val; + } + + ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) { // NOLINT + data_.c128 = val; + } + + // The compatible method for fliud operators, + // and it will be removed in the future. + explicit ScalarBase(const std::string& str_value) + : dtype_(DataType::FLOAT64) { if (str_value == "inf") { - data_.d = std::numeric_limits::infinity(); + data_.f64 = std::numeric_limits::infinity(); } else if (str_value == "-inf") { - data_.d = -std::numeric_limits::infinity(); + data_.f64 = -std::numeric_limits::infinity(); } else if (str_value == "nan") { - data_.d = std::numeric_limits::quiet_NaN(); + data_.f64 = std::numeric_limits::quiet_NaN(); } else { - data_.d = std::stod(str_value); + data_.f64 = std::stod(str_value); + } + } + + // The Tensor must have one dim + ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT + PD_CHECK( + tensor.numel() == 1, + "The Scalar only supports Tensor with 1 element, but now Tensor has `", + tensor.numel(), + "` element."); + switch (dtype_) { + case DataType::FLOAT32: + data_.f32 = tensor.template data()[0]; + break; + case DataType::FLOAT64: + data_.f64 = tensor.template data()[0]; + break; + case DataType::FLOAT16: + data_.f16 = tensor.template data()[0]; + break; + case DataType::BFLOAT16: + data_.bf16 = tensor.template data()[0]; + break; + case DataType::INT32: + data_.i32 = tensor.template data()[0]; + break; + case DataType::INT64: + data_.i64 = tensor.template data()[0]; + break; + case DataType::INT16: + data_.i16 = tensor.template data()[0]; + break; + case DataType::INT8: + data_.i8 = tensor.template data()[0]; + break; + case DataType::UINT16: + data_.ui16 = tensor.template data()[0]; + break; + case DataType::UINT8: + data_.ui8 = tensor.template data()[0]; + break; + case DataType::BOOL: + data_.b = tensor.template data()[0]; + break; + case DataType::COMPLEX64: + data_.c64 = tensor.template data()[0]; + break; + case DataType::COMPLEX128: + data_.c128 = tensor.template data()[0]; + break; + default: + PD_THROW("Invalid tensor data type `", dtype_, "`."); } } - template - inline T to() const { - switch (tag) { - case Tag::HAS_F: - return static_cast(data_.f); - case Tag::HAS_D: - return static_cast(data_.d); - case Tag::HAS_I32: - return static_cast(data_.i32); - case Tag::HAS_I64: - return static_cast(data_.i64); - case Tag::HAS_B: - return static_cast(data_.b); + template + ScalarBase(const ScalarBase& other) { + CopyScalar(other, this); + } + + template + inline RT to() const { + switch (dtype_) { + case DataType::FLOAT32: + return static_cast(data_.f32); + case DataType::FLOAT64: + return static_cast(data_.f64); + case DataType::FLOAT16: + return static_cast(data_.f16); + case DataType::BFLOAT16: + return static_cast(data_.bf16); + case DataType::INT32: + return static_cast(data_.i32); + case DataType::INT64: + return static_cast(data_.i64); + case DataType::INT16: + return static_cast(data_.i16); + case DataType::INT8: + return static_cast(data_.i8); + case DataType::UINT16: + return static_cast(data_.ui16); + case DataType::UINT8: + return static_cast(data_.ui8); + case DataType::BOOL: + return static_cast(data_.b); + case DataType::COMPLEX64: + return static_cast(data_.c64); + case DataType::COMPLEX128: + return static_cast(data_.c128); default: - PD_THROW("Invalid enum scalar type tag `", static_cast(tag), "`."); + PD_THROW("Invalid enum scalar data type `", dtype_, "`."); } } private: - enum class Tag { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B }; - Tag tag; + template + friend void CopyScalar(const ScalarBase& src, ScalarBase* dst); + private: + DataType dtype_; union data { - float f; - double d; + bool b; + int8_t i8; + int16_t i16; int32_t i32; int64_t i64; - bool b; + uint8_t ui8; + uint16_t ui16; + uint32_t ui32; + uint64_t ui64; + bfloat16 bf16; + float16 f16; + float f32; + double f64; + complex64 c64; + complex128 c128; } data_; }; +template +void CopyScalar(const ScalarBase& src, ScalarBase* dst) { + dst->dtype_ = src.dtype_; + dst->data_.c128 = src.data_.c128; +} + +using Scalar = paddle::experimental::ScalarBase; + } // namespace experimental } // namespace paddle namespace pten { -using Scalar = paddle::experimental::Scalar; +class DenseTensor; +using Scalar = paddle::experimental::ScalarBase; } // namespace pten diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h new file mode 100644 index 00000000000..701f777d4a0 --- /dev/null +++ b/paddle/pten/common/scalar_array.h @@ -0,0 +1,148 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/pten/api/ext/exception.h" +#include "paddle/pten/api/include/tensor.h" + +namespace paddle { +namespace experimental { + +template +class ScalarArrayBase { + public: + // Constructor support implicit + ScalarArrayBase() = default; + + ScalarArrayBase(const std::vector& vec) : array_(vec) {} // NOLINT + + ScalarArrayBase(std::initializer_list array_list) + : array_(array_list) {} + + ScalarArrayBase(const int64_t* date_value, int64_t n) { + AssignData(date_value, n); + } + + ScalarArrayBase(const int32_t* date_value, int64_t n) { + AssignData(date_value, n); + } + + // The Tensor must have one dim + ScalarArrayBase(const T& tensor) { // NOLINT + size_t n = tensor.numel(); + array_.reserve(n); + switch (tensor.type()) { + case DataType::INT32: + AssignData(tensor.template data(), n); + break; + case DataType::INT64: + AssignData(tensor.template data(), n); + break; + default: + PD_THROW( + "Data type error. Currently, The data type of ScalarArrayBase " + "only supports Tensor with int32 and int64, " + "but now received `", + tensor.type(), + "`."); + } + } + + // The Tensor in vec must have only one element + ScalarArrayBase(const std::vector& tensor_list) { // NOLINT + auto n = tensor_list.size(); + array_.reserve(n); + if (!tensor_list.empty()) { + DataType data_type = tensor_list[0].dtype(); + switch (data_type) { + case DataType::INT32: { + for (size_t i = 0; i < n; ++i) { + PD_CHECK(tensor_list[i].dtype() == data_type, + "The data_type of tensors in the list isn't consistent." + "the first tensor is`", + data_type, + "` but `", + i, + "`th tensor is`", + tensor_list[i].dtype(), + "`."); + array_.push_back(*tensor_list[i].template data()); + } + break; + } + case DataType::INT64: { + for (size_t i = 0; i < n; ++i) { + PD_CHECK(tensor_list[i].dtype() == data_type, + "The data_type of tensors in the list isn't consistent." + "the first tensor is`", + data_type, + "` but `", + i, + "`th tensor is`", + tensor_list[i].dtype(), + "`."); + array_.push_back(*tensor_list[i].template data()); + } + break; + } + default: + PD_THROW( + "Data type error. Currently, The data type of ScalarArrayBase " + "only supports Tensor with int32 and int64, " + "but now received `", + data_type, + "`."); + } + } + } + + template + ScalarArrayBase(const ScalarArrayBase& other) + : array_(other.GetData()) {} + + const std::vector& GetData() const { return array_; } + + private: + /// \brief Assign the data_ from const data pointer value of type T. + template + void AssignData(const TYPE* value_data, int64_t n) { + if (value_data) { + array_.reserve(n); + for (auto i = 0; i < n; ++i) { + array_.push_back(static_cast(value_data[i])); + } + } else { + PD_THROW("The input data pointer is null."); + } + } + + private: + // TODO(zhangyunfei) Replace std::vector with a more efficient container + // structure. + std::vector array_; +}; + +using ScalarArray = + paddle::experimental::ScalarArrayBase; + +} // namespace experimental +} // namespace paddle + +namespace pten { + +class DenseTensor; +using ScalarArray = paddle::experimental::ScalarArrayBase; + +} // namespace pten diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index 794857dba73..7e6be1c3914 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_def.h" @@ -209,6 +210,8 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); /* Output Helpers */ diff --git a/paddle/pten/infermeta/nary.cc b/paddle/pten/infermeta/nary.cc index 0ae078b13c0..d79945a384a 100644 --- a/paddle/pten/infermeta/nary.cc +++ b/paddle/pten/infermeta/nary.cc @@ -24,4 +24,11 @@ DenseTensorMeta FullInferShape(const std::vector& shape, return {dtype, out_dims, layout}; } +DenseTensorMeta FullInferShape(const ScalarArray& shape, + DataType dtype, + DataLayout layout) { + const auto& out_dims = paddle::framework::make_ddim(shape.GetData()); + return {dtype, out_dims, layout}; +} + } // namespace pten diff --git a/paddle/pten/infermeta/nary.h b/paddle/pten/infermeta/nary.h index 8900e0ed71c..c526583d7ba 100644 --- a/paddle/pten/infermeta/nary.h +++ b/paddle/pten/infermeta/nary.h @@ -14,7 +14,7 @@ limitations under the License. */ #pragma once -// See Note [ Why still include the fluid headers? ] +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/tensor_meta.h" namespace pten { @@ -31,4 +31,8 @@ DenseTensorMeta FullInferShape(const std::vector& shape, DataType dtype, DataLayout layout); +DenseTensorMeta FullInferShape(const ScalarArray& shape, + DataType dtype, + DataLayout layout); + } // namespace pten diff --git a/paddle/pten/kernels/cpu/creation.cc b/paddle/pten/kernels/cpu/creation.cc index 0de359a6d6b..84db03a78ec 100644 --- a/paddle/pten/kernels/cpu/creation.cc +++ b/paddle/pten/kernels/cpu/creation.cc @@ -57,6 +57,15 @@ void FillConstant(const CPUContext& dev_ctx, eigen::fill(dev_ctx, out, val.to()); } +template +void FillConstantDynamicShape(const CPUContext& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out) { + out->Resize(paddle::framework::make_ddim(shape.GetData())); + eigen::fill(dev_ctx, out, val.to()); +} + } // namespace pten PT_REGISTER_MODULE(CreationCPU); @@ -87,3 +96,19 @@ PT_REGISTER_KERNEL("fill_constant.scalar", paddle::platform::bfloat16, paddle::platform::complex, paddle::platform::complex) {} + +PT_REGISTER_KERNEL("fill_constant", + CPU, + ANY, + pten::FillConstantDynamicShape, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::bfloat16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cpu/creation.h b/paddle/pten/kernels/cpu/creation.h index 6ace3118a14..668e242be9c 100644 --- a/paddle/pten/kernels/cpu/creation.h +++ b/paddle/pten/kernels/cpu/creation.h @@ -15,6 +15,7 @@ #pragma once #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -33,4 +34,10 @@ void FillConstant(const CPUContext& dev_ctx, const Scalar& val, DenseTensor* out); +template +void FillConstantDynamicShape(const CPUContext& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out); + } // namespace pten diff --git a/paddle/pten/kernels/cuda/creation.cu b/paddle/pten/kernels/cuda/creation.cu index 61562d2312d..203562a6205 100644 --- a/paddle/pten/kernels/cuda/creation.cu +++ b/paddle/pten/kernels/cuda/creation.cu @@ -58,6 +58,15 @@ void FillConstant(const CUDAContext& dev_ctx, eigen::fill(dev_ctx, out, val.to()); } +template +void FillConstantDynamicShape(const CUDAContext& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out) { + out->Resize(paddle::framework::make_ddim(shape.GetData())); + eigen::fill(dev_ctx, out, val.to()); +} + } // namespace pten PT_REGISTER_MODULE(CreationCUDA); @@ -87,3 +96,18 @@ PT_REGISTER_KERNEL("fill_constant.scalar", paddle::platform::float16, paddle::platform::complex, paddle::platform::complex) {} + +PT_REGISTER_KERNEL("fill_constant", + CUDA, + ANY, + pten::FillConstantDynamicShape, + float, + double, + uint8_t, + int16_t, + int, + int64_t, + bool, + paddle::platform::float16, + paddle::platform::complex, + paddle::platform::complex) {} diff --git a/paddle/pten/kernels/cuda/creation.h b/paddle/pten/kernels/cuda/creation.h index 09c0d505fc0..45ea5348e21 100644 --- a/paddle/pten/kernels/cuda/creation.h +++ b/paddle/pten/kernels/cuda/creation.h @@ -18,6 +18,7 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/fluid/platform/device_context.h" @@ -36,6 +37,12 @@ void FillConstant(const CUDAContext& dev_ctx, const Scalar& val, DenseTensor* out); +template +void FillConstantDynamicShape(const CUDAContext& dev_ctx, + const ScalarArray& shape, + const Scalar& val, + DenseTensor* out); + } // namespace pten #endif diff --git a/paddle/pten/tests/api/test_fill_api.cc b/paddle/pten/tests/api/test_fill_api.cc index 552c5e0ef96..1ebfc8e6746 100644 --- a/paddle/pten/tests/api/test_fill_api.cc +++ b/paddle/pten/tests/api/test_fill_api.cc @@ -129,19 +129,40 @@ TEST(API, ones_like) { } } -TEST(API, full) { +TEST(API, full1) { // 1. create tensor const auto alloc = std::make_shared( paddle::platform::CPUPlace()); + auto dense_shape = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::INT64, + framework::make_ddim({2}), + pten::DataLayout::NCHW)); + auto* shape_data = dense_shape->mutable_data(); + shape_data[0] = 2; + shape_data[1] = 3; + + auto dense_scalar = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::FLOAT32, + framework::make_ddim({1}), + pten::DataLayout::NCHW)); + dense_scalar->mutable_data()[0] = 1.0; + + paddle::experimental::Tensor value(dense_scalar); + + paddle::experimental::Tensor tensor_shape(dense_shape); + float val = 1.0; // 2. test API - auto out = paddle::experimental::full({3, 2}, val, pten::DataType::FLOAT32); + auto out = + paddle::experimental::full(tensor_shape, value, pten::DataType::FLOAT32); // 3. check result ASSERT_EQ(out.shape().size(), 2UL); - ASSERT_EQ(out.shape()[0], 3); + ASSERT_EQ(out.shape()[0], 2); ASSERT_EQ(out.numel(), 6); ASSERT_EQ(out.is_cpu(), true); ASSERT_EQ(out.type(), pten::DataType::FLOAT32); @@ -155,5 +176,64 @@ TEST(API, full) { } } +TEST(API, full2) { + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + + auto dense_scalar = std::make_shared( + alloc, + pten::DenseTensorMeta(pten::DataType::INT32, + framework::make_ddim({1}), + pten::DataLayout::NCHW)); + dense_scalar->mutable_data()[0] = 2; + + paddle::experimental::Tensor shape_scalar1(dense_scalar); + paddle::experimental::Tensor shape_scalar2(dense_scalar); + std::vector list_shape{shape_scalar1, + shape_scalar2}; + + float val = 1.0; + + auto out = + paddle::experimental::full(list_shape, val, pten::DataType::FLOAT32); + + ASSERT_EQ(out.shape().size(), 2UL); + ASSERT_EQ(out.shape()[0], 2); + ASSERT_EQ(out.numel(), 4); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::FLOAT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto* actual_result = dense_out->data(); + for (auto i = 0; i < 4; i++) { + ASSERT_NEAR(actual_result[i], val, 1e-6f); + } +} + +TEST(API, full3) { + std::vector vector_shape{2, 3}; + + float val = 1.0; + + auto out = + paddle::experimental::full(vector_shape, val, pten::DataType::INT32); + + ASSERT_EQ(out.shape().size(), 2UL); + ASSERT_EQ(out.shape()[0], 2); + ASSERT_EQ(out.numel(), 6); + ASSERT_EQ(out.is_cpu(), true); + ASSERT_EQ(out.type(), pten::DataType::INT32); + ASSERT_EQ(out.layout(), pten::DataLayout::NCHW); + ASSERT_EQ(out.initialized(), true); + + auto dense_out = std::dynamic_pointer_cast(out.impl()); + auto* actual_result = dense_out->data(); + for (auto i = 0; i < 6; i++) { + ASSERT_EQ(actual_result[i], 1); + } +} + } // namespace tests } // namespace paddle -- GitLab