Unverified commit 0f24de83 authored by zyfncg, committed by GitHub

【PTen】Add Scalar and ScalarArray in pten (#37409)

* add scalar and scalar_array

* remove DenseTensor include from Scalar and ScalarArray

* remove inner header from scalar_array

* refactor the method of fill_constant and add some comment
Parent 16590799
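For context, the reworked paddle::experimental::full now takes its shape as a ScalarArray, so callers can pass a brace list, a std::vector<int64_t>, a 1-D shape Tensor, or a list of scalar Tensors (exercised by the full1/full2/full3 tests at the end of this diff). A minimal illustrative sketch, not part of the commit itself:

// Illustrative only: each call resolves to full(const ScalarArray& shape, ...).
auto a = paddle::experimental::full({2, 3}, 1.0f, pten::DataType::FLOAT32);
std::vector<int64_t> shape_vec{2, 3};
auto b = paddle::experimental::full(shape_vec, 1.0f, pten::DataType::FLOAT32);
// A 1-D INT32/INT64 Tensor holding {2, 3} converts implicitly as well.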
@@ -254,7 +254,7 @@ class CumCUDAKernel : public framework::OpKernel<T> {
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
auto& dev_ctx = context.template device_context<DeviceContext>();
Tensor tmp;
framework::Tensor tmp;
tmp.Resize(out_dims);
auto* tmp_data = tmp.mutable_data<T>(context.GetPlace());
T* next_in_data = out_data;
......
@@ -19,7 +19,6 @@
namespace paddle {
namespace operators {
namespace math {
using Tensor = framework::Tensor;
std::vector<TreeNode> Tree2ColUtil::construct_patch(
size_t root, int max_depth, const std::vector<std::vector<int>> &tr) {
std::stack<TreeNode, std::deque<TreeNode>> stack;
@@ -51,7 +50,7 @@ std::vector<TreeNode> Tree2ColUtil::construct_patch(
return patch;
}
void Tree2ColUtil::construct_tree(const paddle::Tensor &EdgeSet,
void Tree2ColUtil::construct_tree(const framework::Tensor &EdgeSet,
std::vector<std::vector<int>> *tr,
size_t *node_count) {
auto edge_set_dims = EdgeSet.dims();
......
@@ -21,8 +21,6 @@
#include "paddle/fluid/operators/math/math_function.h"
namespace paddle {
using Tensor = framework::Tensor;
using DDim = framework::DDim;
namespace operators {
namespace math {
class TreeNode {
@@ -64,7 +62,7 @@ class Tree2ColUtil {
static std::vector<TreeNode> construct_patch(
size_t root, int max_depth, const std::vector<std::vector<int>> &tr);
static void construct_tree(const Tensor &EdgeSet,
static void construct_tree(const framework::Tensor &EdgeSet,
std::vector<std::vector<int>> *tr,
size_t *node_count);
};
......
@@ -37,6 +37,7 @@ limitations under the License. */
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/layout.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
// original custom op headers
#include "paddle/pten/api/ext/dispatch.h"
......
@@ -18,11 +18,12 @@
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/data_type.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
namespace paddle {
namespace experimental {
PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
PD_DLL_DECL Tensor full(const ScalarArray& shape,
const Scalar& value,
DataType dtype = DataType::FLOAT32,
Backend backend = Backend::CPU,
......
@@ -34,7 +34,7 @@ PT_DECLARE_MODULE(CreationCUDA);
namespace paddle {
namespace experimental {
PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
PD_DLL_DECL Tensor full(const ScalarArray& shape,
const Scalar& value,
DataType dtype,
Backend backend,
@@ -42,14 +42,15 @@ PD_DLL_DECL Tensor full(const std::vector<int64_t>& shape,
// 1. Get kernel signature and kernel
pten::KernelKey kernel_key{backend, layout, dtype};
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"fill_constant.scalar", kernel_key);
"fill_constant", kernel_key);
// 2. Get Device Context
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
// 3. Auto data transform
kernel_context.EmplaceBackAttr(value);
kernel_context.EmplaceBackAttr(pten::ScalarArray(shape));
kernel_context.EmplaceBackAttr(pten::Scalar(value));
// 4. InferShape
auto out_meta = pten::FullInferShape(shape, dtype, layout);
@@ -94,7 +95,7 @@ PD_DLL_DECL Tensor full_like(const Tensor& x,
// 3. Auto data transform
auto dense_x = std::dynamic_pointer_cast<pten::DenseTensor>(x.impl());
kernel_context.EmplaceBackAttr(value);
kernel_context.EmplaceBackAttr(pten::Scalar(value));
// 4. InferShape
auto out_meta = FullLikeInferShape(dense_x->meta(), dtype, layout);
......
@@ -219,6 +219,7 @@ template PD_DLL_DECL const int32_t *Tensor::data<int32_t>() const;
template PD_DLL_DECL const uint8_t *Tensor::data<uint8_t>() const;
template PD_DLL_DECL const int8_t *Tensor::data<int8_t>() const;
template PD_DLL_DECL const int16_t *Tensor::data<int16_t>() const;
template PD_DLL_DECL const uint16_t *Tensor::data<uint16_t>() const;
template PD_DLL_DECL const bool *Tensor::data<bool>() const;
template PD_DLL_DECL const paddle::platform::complex<float>
*Tensor::data<paddle::platform::complex<float>>() const;
@@ -226,6 +227,8 @@ template PD_DLL_DECL const paddle::platform::complex<double>
*Tensor::data<paddle::platform::complex<double>>() const;
template PD_DLL_DECL const paddle::platform::float16 *
Tensor::data<paddle::platform::float16>() const;
template PD_DLL_DECL const paddle::platform::bfloat16 *
Tensor::data<paddle::platform::bfloat16>() const;
template <typename T>
T *Tensor::data() {
......
@@ -18,69 +18,217 @@ limitations under the License. */
#include <limits>
#include "paddle/pten/api/ext/exception.h"
#include "paddle/pten/api/include/tensor.h"
namespace paddle {
namespace experimental {
class Scalar {
template <typename T>
class ScalarBase {
public:
// Constructors that support implicit conversion
Scalar(float val) : tag(Tag::HAS_F) { data_.f = val; } // NOLINT
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val;
}
ScalarBase(float val) : dtype_(DataType::FLOAT32) { // NOLINT
data_.f32 = val;
}
ScalarBase(float16 val) : dtype_(DataType::FLOAT16) { // NOLINT
data_.f16 = val;
}
Scalar(double val) : tag(Tag::HAS_D) { data_.d = val; } // NOLINT
ScalarBase(bfloat16 val) : dtype_(DataType::BFLOAT16) { // NOLINT
data_.bf16 = val;
}
Scalar(int32_t val) : tag(Tag::HAS_I32) { data_.i32 = val; } // NOLINT
ScalarBase(int64_t val) : dtype_(DataType::INT64) { // NOLINT
data_.i64 = val;
}
Scalar(int64_t val) : tag(Tag::HAS_I64) { data_.i64 = val; } // NOLINT
ScalarBase(int32_t val) : dtype_(DataType::INT32) { // NOLINT
data_.i32 = val;
}
Scalar(bool val) : tag(Tag::HAS_B) { data_.b = val; } // NOLINT
ScalarBase(int16_t val) : dtype_(DataType::INT16) { // NOLINT
data_.i16 = val;
}
Scalar(const std::string& str_value) : tag(Tag::HAS_D) { // NOLINT
ScalarBase(int8_t val) : dtype_(DataType::INT8) { // NOLINT
data_.i8 = val;
}
ScalarBase(uint64_t val) : dtype_(DataType::UINT64) { // NOLINT
data_.ui64 = val;
}
ScalarBase(uint32_t val) : dtype_(DataType::UINT32) { // NOLINT
data_.ui32 = val;
}
ScalarBase(uint16_t val) : dtype_(DataType::UINT16) { // NOLINT
data_.ui16 = val;
}
ScalarBase(uint8_t val) : dtype_(DataType::UINT8) { // NOLINT
data_.ui8 = val;
}
ScalarBase(bool val) : dtype_(DataType::BOOL) { // NOLINT
data_.b = val;
}
ScalarBase(complex64 val) : dtype_(DataType::COMPLEX64) { // NOLINT
data_.c64 = val;
}
ScalarBase(complex128 val) : dtype_(DataType::COMPLEX128) { // NOLINT
data_.c128 = val;
}
// The compatibility method for fluid operators,
// and it will be removed in the future.
explicit ScalarBase(const std::string& str_value)
: dtype_(DataType::FLOAT64) {
if (str_value == "inf") {
data_.d = std::numeric_limits<double>::infinity();
data_.f64 = std::numeric_limits<double>::infinity();
} else if (str_value == "-inf") {
data_.d = -std::numeric_limits<double>::infinity();
data_.f64 = -std::numeric_limits<double>::infinity();
} else if (str_value == "nan") {
data_.d = std::numeric_limits<double>::quiet_NaN();
data_.f64 = std::numeric_limits<double>::quiet_NaN();
} else {
data_.d = std::stod(str_value);
data_.f64 = std::stod(str_value);
}
}
// The Tensor must contain exactly one element
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
PD_CHECK(
tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
tensor.numel(),
"` element.");
switch (dtype_) {
case DataType::FLOAT32:
data_.f32 = tensor.template data<float>()[0];
break;
case DataType::FLOAT64:
data_.f64 = tensor.template data<double>()[0];
break;
case DataType::FLOAT16:
data_.f16 = tensor.template data<float16>()[0];
break;
case DataType::BFLOAT16:
data_.bf16 = tensor.template data<bfloat16>()[0];
break;
case DataType::INT32:
data_.i32 = tensor.template data<int32_t>()[0];
break;
case DataType::INT64:
data_.i64 = tensor.template data<int64_t>()[0];
break;
case DataType::INT16:
data_.i16 = tensor.template data<int16_t>()[0];
break;
case DataType::INT8:
data_.i8 = tensor.template data<int8_t>()[0];
break;
case DataType::UINT16:
data_.ui16 = tensor.template data<uint16_t>()[0];
break;
case DataType::UINT8:
data_.ui8 = tensor.template data<uint8_t>()[0];
break;
case DataType::BOOL:
data_.b = tensor.template data<bool>()[0];
break;
case DataType::COMPLEX64:
data_.c64 = tensor.template data<complex64>()[0];
break;
case DataType::COMPLEX128:
data_.c128 = tensor.template data<complex128>()[0];
break;
default:
PD_THROW("Invalid tensor data type `", dtype_, "`.");
}
}
template <typename T>
inline T to() const {
switch (tag) {
case Tag::HAS_F:
return static_cast<T>(data_.f);
case Tag::HAS_D:
return static_cast<T>(data_.d);
case Tag::HAS_I32:
return static_cast<T>(data_.i32);
case Tag::HAS_I64:
return static_cast<T>(data_.i64);
case Tag::HAS_B:
return static_cast<T>(data_.b);
template <typename OtherT>
ScalarBase(const ScalarBase<OtherT>& other) {
CopyScalar(other, this);
}
template <typename RT>
inline RT to() const {
switch (dtype_) {
case DataType::FLOAT32:
return static_cast<RT>(data_.f32);
case DataType::FLOAT64:
return static_cast<RT>(data_.f64);
case DataType::FLOAT16:
return static_cast<RT>(data_.f16);
case DataType::BFLOAT16:
return static_cast<RT>(data_.bf16);
case DataType::INT32:
return static_cast<RT>(data_.i32);
case DataType::INT64:
return static_cast<RT>(data_.i64);
case DataType::INT16:
return static_cast<RT>(data_.i16);
case DataType::INT8:
return static_cast<RT>(data_.i8);
case DataType::UINT16:
return static_cast<RT>(data_.ui16);
case DataType::UINT8:
return static_cast<RT>(data_.ui8);
case DataType::BOOL:
return static_cast<RT>(data_.b);
case DataType::COMPLEX64:
return static_cast<RT>(data_.c64);
case DataType::COMPLEX128:
return static_cast<RT>(data_.c128);
default:
PD_THROW("Invalid enum scalar type tag `", static_cast<int>(tag), "`.");
PD_THROW("Invalid enum scalar data type `", dtype_, "`.");
}
}
private:
enum class Tag { HAS_F, HAS_D, HAS_I32, HAS_I64, HAS_B };
Tag tag;
template <typename T1, typename T2>
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
private:
DataType dtype_;
union data {
float f;
double d;
bool b;
int8_t i8;
int16_t i16;
int32_t i32;
int64_t i64;
bool b;
uint8_t ui8;
uint16_t ui16;
uint32_t ui32;
uint64_t ui64;
bfloat16 bf16;
float16 f16;
float f32;
double f64;
complex64 c64;
complex128 c128;
} data_;
};
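// Note: assigning c128, the largest union member (16 bytes), copies the
// union's entire payload regardless of which dtype is currently active.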
template <typename T1, typename T2>
void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst) {
dst->dtype_ = src.dtype_;
dst->data_.c128 = src.data_.c128;
}
using Scalar = paddle::experimental::ScalarBase<paddle::experimental::Tensor>;
} // namespace experimental
} // namespace paddle
namespace pten {
using Scalar = paddle::experimental::Scalar;
class DenseTensor;
using Scalar = paddle::experimental::ScalarBase<DenseTensor>;
} // namespace pten
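Taken together, ScalarBase keeps a single value of any supported dtype in a tagged union and converts on request through to<RT>(). A short usage sketch (illustrative, not part of the diff):

pten::Scalar from_double(3.14);                  // dtype_ == DataType::FLOAT64
pten::Scalar from_str(std::string("inf"));       // fluid-compat path, stored as FLOAT64
float f = from_double.to<float>();               // static_cast through the union
int64_t one = pten::Scalar(true).to<int64_t>();  // bool -> 1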
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/pten/api/ext/exception.h"
#include "paddle/pten/api/include/tensor.h"
namespace paddle {
namespace experimental {
template <typename T>
class ScalarArrayBase {
public:
// Constructors that support implicit conversion
ScalarArrayBase() = default;
ScalarArrayBase(const std::vector<int64_t>& vec) : array_(vec) {} // NOLINT
ScalarArrayBase(std::initializer_list<int64_t> array_list)
: array_(array_list) {}
ScalarArrayBase(const int64_t* data_value, int64_t n) {
AssignData(data_value, n);
}
ScalarArrayBase(const int32_t* data_value, int64_t n) {
AssignData(data_value, n);
}
// The Tensor must have one dim
ScalarArrayBase(const T& tensor) { // NOLINT
size_t n = tensor.numel();
array_.reserve(n);
switch (tensor.type()) {
case DataType::INT32:
AssignData(tensor.template data<int32_t>(), n);
break;
case DataType::INT64:
AssignData(tensor.template data<int64_t>(), n);
break;
default:
PD_THROW(
"Data type error. Currently, The data type of ScalarArrayBase "
"only supports Tensor with int32 and int64, "
"but now received `",
tensor.type(),
"`.");
}
}
// Each Tensor in tensor_list must have only one element
ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT
auto n = tensor_list.size();
array_.reserve(n);
if (!tensor_list.empty()) {
DataType data_type = tensor_list[0].dtype();
switch (data_type) {
case DataType::INT32: {
for (size_t i = 0; i < n; ++i) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
array_.push_back(*tensor_list[i].template data<int32_t>());
}
break;
}
case DataType::INT64: {
for (size_t i = 0; i < n; ++i) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
array_.push_back(*tensor_list[i].template data<int64_t>());
}
break;
}
default:
PD_THROW(
"Data type error. Currently, The data type of ScalarArrayBase "
"only supports Tensor with int32 and int64, "
"but now received `",
data_type,
"`.");
}
}
}
template <typename OtherT>
ScalarArrayBase(const ScalarArrayBase<OtherT>& other)
: array_(other.GetData()) {}
const std::vector<int64_t>& GetData() const { return array_; }
private:
/// \brief Assign array_ from a const data pointer of type TYPE.
template <typename TYPE>
void AssignData(const TYPE* value_data, int64_t n) {
if (value_data) {
array_.reserve(n);
for (int64_t i = 0; i < n; ++i) {
array_.push_back(static_cast<int64_t>(value_data[i]));
}
} else {
PD_THROW("The input data pointer is null.");
}
}
private:
// TODO(zhangyunfei) Replace std::vector with a more efficient container
// structure.
std::vector<int64_t> array_;
};
using ScalarArray =
paddle::experimental::ScalarArrayBase<paddle::experimental::Tensor>;
} // namespace experimental
} // namespace paddle
namespace pten {
class DenseTensor;
using ScalarArray = paddle::experimental::ScalarArrayBase<DenseTensor>;
} // namespace pten
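ScalarArrayBase normalizes every accepted shape source into a std::vector<int64_t>; a brief sketch (illustrative, not part of the diff):

std::vector<int32_t> dims32{2, 3};
pten::ScalarArray a(dims32.data(), dims32.size());  // int32 data widened to int64_t
pten::ScalarArray b({4, 5, 6});                     // from an initializer list
const std::vector<int64_t>& dims = b.GetData();     // -> {4, 5, 6}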
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_def.h"
@@ -209,6 +210,8 @@ struct KernelImpl<Return (*)(Args...), kernel_fn> {
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(DataType);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int64_t>&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const ScalarArray&);
PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector<int>&);
/* Output Helpers */
......
@@ -24,4 +24,11 @@ DenseTensorMeta FullInferShape(const std::vector<int64_t>& shape,
return {dtype, out_dims, layout};
}
DenseTensorMeta FullInferShape(const ScalarArray& shape,
DataType dtype,
DataLayout layout) {
const auto& out_dims = paddle::framework::make_ddim(shape.GetData());
return {dtype, out_dims, layout};
}
} // namespace pten
@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/tensor_meta.h"
namespace pten {
@@ -31,4 +31,8 @@ DenseTensorMeta FullInferShape(const std::vector<int64_t>& shape,
DataType dtype,
DataLayout layout);
DenseTensorMeta FullInferShape(const ScalarArray& shape,
DataType dtype,
DataLayout layout);
} // namespace pten
@@ -57,6 +57,15 @@ void FillConstant(const CPUContext& dev_ctx,
eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}
template <typename T>
void FillConstantDynamicShape(const CPUContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
eigen::fill<CPUContext, T>(dev_ctx, out, val.to<T>());
}
} // namespace pten
PT_REGISTER_MODULE(CreationCPU);
@@ -87,3 +96,19 @@ PT_REGISTER_KERNEL("fill_constant.scalar",
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL("fill_constant",
CPU,
ANY,
pten::FillConstantDynamicShape,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
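On the dispatch side this registration is what the reworked paddle::experimental::full resolves to; schematically (condensed from the full() implementation earlier in this diff, with shape, value, backend, layout, and dtype as caller-supplied arguments):

pten::KernelKey kernel_key{backend, layout, dtype};
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
    "fill_constant", kernel_key);
auto* dev_ctx = GetDeviceContextByBackend(kernel_key.backend());
auto kernel_context = pten::KernelContext(dev_ctx);
kernel_context.EmplaceBackAttr(pten::ScalarArray(shape));  // runtime shape
kernel_context.EmplaceBackAttr(pten::Scalar(value));       // fill value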
@@ -15,6 +15,7 @@
#pragma once
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/platform/device_context.h"
@@ -33,4 +34,10 @@ void FillConstant(const CPUContext& dev_ctx,
const Scalar& val,
DenseTensor* out);
template <typename T>
void FillConstantDynamicShape(const CPUContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
} // namespace pten
@@ -58,6 +58,15 @@ void FillConstant(const CUDAContext& dev_ctx,
eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
}
template <typename T>
void FillConstantDynamicShape(const CUDAContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
eigen::fill<CUDAContext, T>(dev_ctx, out, val.to<T>());
}
} // namespace pten
PT_REGISTER_MODULE(CreationCUDA);
@@ -87,3 +96,18 @@ PT_REGISTER_KERNEL("fill_constant.scalar",
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
PT_REGISTER_KERNEL("fill_constant",
CUDA,
ANY,
pten::FillConstantDynamicShape,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
paddle::platform::float16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {}
@@ -18,6 +18,7 @@
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/fluid/platform/device_context.h"
@@ -36,6 +37,12 @@ void FillConstant(const CUDAContext& dev_ctx,
const Scalar& val,
DenseTensor* out);
template <typename T>
void FillConstantDynamicShape(const CUDAContext& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
} // namespace pten
#endif
@@ -129,19 +129,40 @@ TEST(API, ones_like)
}
}
TEST(API, full) {
TEST(API, full1) {
// 1. create tensor
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_shape = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::INT64,
framework::make_ddim({2}),
pten::DataLayout::NCHW));
auto* shape_data = dense_shape->mutable_data<int64_t>();
shape_data[0] = 2;
shape_data[1] = 3;
auto dense_scalar = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::FLOAT32,
framework::make_ddim({1}),
pten::DataLayout::NCHW));
dense_scalar->mutable_data<float>()[0] = 1.0;
paddle::experimental::Tensor value(dense_scalar);
paddle::experimental::Tensor tensor_shape(dense_shape);
float val = 1.0;
// 2. test API
auto out = paddle::experimental::full({3, 2}, val, pten::DataType::FLOAT32);
auto out =
paddle::experimental::full(tensor_shape, value, pten::DataType::FLOAT32);
// 3. check result
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 3);
ASSERT_EQ(out.shape()[0], 2);
ASSERT_EQ(out.numel(), 6);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
@@ -155,5 +176,64 @@ TEST(API, full) {
}
}
TEST(API, full2) {
const auto alloc = std::make_shared<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_scalar = std::make_shared<pten::DenseTensor>(
alloc,
pten::DenseTensorMeta(pten::DataType::INT32,
framework::make_ddim({1}),
pten::DataLayout::NCHW));
dense_scalar->mutable_data<int32_t>()[0] = 2;
paddle::experimental::Tensor shape_scalar1(dense_scalar);
paddle::experimental::Tensor shape_scalar2(dense_scalar);
std::vector<paddle::experimental::Tensor> list_shape{shape_scalar1,
shape_scalar2};
float val = 1.0;
auto out =
paddle::experimental::full(list_shape, val, pten::DataType::FLOAT32);
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 2);
ASSERT_EQ(out.numel(), 4);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::FLOAT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto* actual_result = dense_out->data<float>();
for (auto i = 0; i < 4; i++) {
ASSERT_NEAR(actual_result[i], val, 1e-6f);
}
}
TEST(API, full3) {
std::vector<int64_t> vector_shape{2, 3};
float val = 1.0;
auto out =
paddle::experimental::full(vector_shape, val, pten::DataType::INT32);
ASSERT_EQ(out.shape().size(), 2UL);
ASSERT_EQ(out.shape()[0], 2);
ASSERT_EQ(out.numel(), 6);
ASSERT_EQ(out.is_cpu(), true);
ASSERT_EQ(out.type(), pten::DataType::INT32);
ASSERT_EQ(out.layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out.initialized(), true);
auto dense_out = std::dynamic_pointer_cast<pten::DenseTensor>(out.impl());
auto* actual_result = dense_out->data<int>();
for (auto i = 0; i < 6; i++) {
ASSERT_EQ(actual_result[i], 1);
}
}
} // namespace tests
} // namespace paddle