Unverified commit d0df5632, authored by chentianyu03, committed by GitHub

[pten] add split kernel (#39060)

* add split kernel

* add split kernel signature

* fix split bug

* modify MakePtenScalarArrayFromVarList

* modify MakePtenScalarArrayFromVarList

* fix split windows register error

* add test case for split kernel

* replace raw split kernel with pten kernel

* fix makeScalar/ScalarArray bug

* remove debug log

* remove int64_t type in BuildPtenKernelContext

* update by code review

* fix split dev test failed

* change DenseTensorMeta to MetaTensor

* change split api code from auto gen to manual

* split cuda kernel support bfloat16 type

* fix conflict

* rm raw split kernel

* merge develop branch

* change to pten::errors
Parent commit d12c3636
......@@ -22,9 +22,12 @@ limitations under the License. */
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/extension.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_kernel_info_helper.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_context.h"
#include "paddle/pten/core/kernel_factory.h"
......@@ -183,14 +186,14 @@ TEST(CustomKernel, custom_kernel_dot) {
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_x_data =
dense_x->mutable_data<uint8_t>(paddle::platform::CPUPlace());
auto dense_y = std::make_shared<pten::DenseTensor>(
alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8,
paddle::framework::make_ddim({2, 3}),
pten::framework::make_ddim({2, 3}),
pten::DataLayout::NCHW));
auto* dense_y_data =
dense_y->mutable_data<uint8_t>(paddle::platform::CPUPlace());
......@@ -231,8 +234,7 @@ TEST(CustomKernel, custom_kernel_dot) {
pten::DataType fake_attr_dtype = pten::DataType::UINT32;
paddle::framework::LoDTensor tmp_tensor;
tmp_tensor.mutable_data<uint8_t>({1}, pten::TransToPtenPlace(backend));
pten::Scalar fake_attr_scalar =
paddle::experimental::MakePtenScalar(tmp_tensor);
pten::Scalar fake_attr_scalar{tmp_tensor};
pten::ScalarArray fake_attr_scalar_array;
std::vector<int64_t> fake_attr_int64_vec;
std::vector<int> fake_attr_int_vec;
......
......@@ -2099,6 +2099,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(int32_t))) {
pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray(
&BOOST_GET_CONST(int32_t, attr_iter->second), 1)));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
......
......@@ -346,6 +346,14 @@ void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int64_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int64_t, attr), 1)));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int32_t))) {
kernel_ctx->EmplaceBackAttr(
std::move(pten::ScalarArray(&BOOST_GET_CONST(int32_t, attr), 1)));
} else if (attr_defs[i].type_index ==
std::type_index(typeid(std::vector<int32_t>))) {
const auto& vector_int_attr = BOOST_GET_CONST(std::vector<int>, attr);
......
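The two hunks above extend the attribute-to-ScalarArray conversion in the static-graph and dygraph kernel-context builders, so that a plain int32_t (or int64_t) attribute can also feed a ScalarArray-typed kernel argument. A minimal sketch of what the new branches construct, using hypothetical values (the real code reads the integer from the op's attribute map via BOOST_GET_CONST and emplaces the result into the pten kernel context):
int32_t num = 3;                      // e.g. split's "num" attribute
pten::ScalarArray sections(&num, 1);  // one-element ScalarArray holding {3},
                                      // built from a (data pointer, length) pair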
......@@ -217,7 +217,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) {
} // namespace imperative
} // namespace paddle
USE_OP(split);
USE_OP_ITSELF(split);
USE_OP(relu);
#ifdef PADDLE_WITH_MKLDNN
USE_OP_DEVICE_KERNEL(relu, MKLDNN);
......
......@@ -172,11 +172,3 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
namespace plat = paddle::platform;
REGISTER_OP_CPU_KERNEL(
split, ops::SplitOpKernel<plat::CPUDeviceContext, double>,
ops::SplitOpKernel<plat::CPUDeviceContext, float>,
ops::SplitOpKernel<plat::CPUDeviceContext, int64_t>,
ops::SplitOpKernel<plat::CPUDeviceContext, int>,
ops::SplitOpKernel<plat::CPUDeviceContext, bool>,
ops::SplitOpKernel<plat::CPUDeviceContext, plat::float16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/split_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
split, ops::SplitOpKernel<plat::CUDADeviceContext, double>,
ops::SplitOpKernel<plat::CUDADeviceContext, float>,
ops::SplitOpKernel<plat::CUDADeviceContext, int64_t>,
ops::SplitOpKernel<plat::CUDADeviceContext, int>,
ops::SplitOpKernel<plat::CUDADeviceContext, bool>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::SplitOpKernel<plat::CUDADeviceContext, plat::bfloat16>);
......@@ -19,10 +19,8 @@ limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/pten/kernels/split_kernel.h"
namespace paddle {
namespace operators {
static inline std::vector<framework::DDim> UpdateOutsDims(
......@@ -108,56 +106,6 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
}
return outs_dims;
}
template <typename DeviceContext, typename T>
class SplitOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* in = ctx.Input<framework::Tensor>("X");
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
int num = ctx.Attr<int>("num");
std::vector<int> sections = ctx.Attr<std::vector<int>>("sections");
int axis = ctx.Attr<int>("axis");
auto in_dims = in->dims();
auto outs_number = outs.size();
bool need_resize_outs_dims = false;
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
axis = GetDataFromTensor(axis_tensor)[0];
need_resize_outs_dims = true;
}
auto sections_tensor_list =
ctx.MultiInput<framework::Tensor>("SectionsTensorList");
if (sections_tensor_list.size() > 0) {
sections = GetDataFromTensorList(sections_tensor_list);
need_resize_outs_dims = true;
}
if (need_resize_outs_dims) {
std::vector<framework::DDim> outs_dims =
UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number);
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->Resize(outs_dims[j]);
}
}
std::vector<const framework::Tensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
outs[j]->mutable_data<T>(ctx.GetPlace());
shape_refer.emplace_back(outs[j]);
}
auto& dev_ctx = ctx.template device_context<DeviceContext>();
// Sometimes direct copies will be faster, this maybe need deeply analysis.
if (axis == 0 && outs.size() < 10) {
StridedMemcpyWithAxis0<T>(dev_ctx, *in, shape_refer, &outs);
} else {
math::SplitFunctor<DeviceContext, T> functor;
functor(dev_ctx, *in, shape_refer, axis, &outs);
}
}
};
template <typename T>
class SplitGradMaker : public framework::SingleGradOpMaker<T> {
......
......@@ -16,6 +16,8 @@ limitations under the License. */
#include "paddle/pten/api/include/tensor.h"
#include "paddle/pten/common/backend.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
/**
* This file stores some special APIs that are implemented manually
......@@ -28,5 +30,11 @@ namespace experimental {
// TODO(chenweihang): Replace backend by place when place is ready
PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking);
// TODO(chentianyu03): Split API has extra logic to calculate the number of
// outputs, which api_gen does not support yet
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis);
} // namespace experimental
} // namespace paddle
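The manually written split API declared above can be called as in the sketch below, assuming an already constructed float Tensor x of shape [4, 10] (the same setup used in test_split_api.cc later in this commit); splitting into two 2-row sections along axis 0 yields two tensors of shape [2, 10]:
#include <vector>
#include "paddle/pten/api/include/manual_api.h"
std::vector<paddle::experimental::Tensor> SplitExample(
    const paddle::experimental::Tensor& x) {
  // {2, 2} gives explicit section sizes; the axis Scalar is 0.
  return paddle::experimental::split(x, {2, 2}, 0);
}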
......@@ -19,9 +19,12 @@ limitations under the License. */
#include "glog/logging.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/data_transform.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/core/meta_tensor.h"
#include "paddle/pten/infermeta/unary.h"
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);
......@@ -75,6 +78,71 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) {
return out;
}
PADDLE_API std::vector<Tensor> split(const Tensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError(
"split", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "split API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "split API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto dense_x = PrepareData(x, kernel.InputAt(0), {});
// Calculate the number of out tensors
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<pten::MetaTensor> meta_outs;
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
}
pten::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);
using kernel_signature = void (*)(const platform::DeviceContext&,
const pten::DenseTensor&,
const pten::ScalarArray&,
const pten::Scalar&,
std::vector<pten::DenseTensor*>&);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)(*dev_ctx,
*dense_x,
pten::ScalarArray(num_or_sections),
pten::Scalar(axis),
dense_outs);
return out;
}
} // namespace experimental
} // namespace paddle
......
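Note how the implementation above derives the number of outputs from num_or_sections: a single value is treated as a split count (for example, split(x, {3}, 1) produces three outputs), while a longer list gives explicit section sizes (split(x, {2, 2}, 0) produces two outputs).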
......@@ -36,45 +36,6 @@ std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
return std::make_unique<pten::DenseTensor>(src);
}
pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src) {
PADDLE_ENFORCE_EQ(src.numel(),
1,
paddle::platform::errors::InvalidArgument(
"The Scalar only supports Tensor with 1 element, "
"but now Tensor has %d element.",
src.numel()));
switch (src.type()) {
case paddle::framework::proto::VarType::FP32:
return {src.template data<float>()[0]};
case paddle::framework::proto::VarType::FP64:
return {src.template data<double>()[0]};
case paddle::framework::proto::VarType::FP16:
return {src.template data<float16>()[0]};
case paddle::framework::proto::VarType::BF16:
return {src.template data<bfloat16>()[0]};
case paddle::framework::proto::VarType::INT32:
return {src.template data<int32_t>()[0]};
case paddle::framework::proto::VarType::INT64:
return {src.template data<int64_t>()[0]};
case paddle::framework::proto::VarType::INT16:
return {src.template data<int16_t>()[0]};
case paddle::framework::proto::VarType::INT8:
return {src.template data<int8_t>()[0]};
case paddle::framework::proto::VarType::UINT8:
return {src.template data<uint8_t>()[0]};
case paddle::framework::proto::VarType::BOOL:
return {src.template data<bool>()[0]};
case paddle::framework::proto::VarType::COMPLEX64:
return {src.template data<complex64>()[0]};
case paddle::framework::proto::VarType::COMPLEX128:
return {src.template data<complex128>()[0]};
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. Don't support casting a %d LoDTensor to Scalar.",
src.type()));
}
}
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU);
if (variable.IsType<framework::LoDTensor>()) {
......@@ -82,9 +43,9 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
if (!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
return MakePtenScalar(tmp_tensor);
return {tmp_tensor};
} else {
return MakePtenScalar(tensor);
return {tensor};
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
......@@ -95,17 +56,7 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) {
}
pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src) {
if (src.type() == paddle::framework::proto::VarType::INT64) {
return {src.data<int64_t>(), src.numel()};
} else if (src.type() == paddle::framework::proto::VarType::INT32) {
return {src.data<int32_t>(), src.numel()};
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to ScalarArray, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
src.type()));
}
return {src};
}
pten::ScalarArray MakePtenScalarArrayFromVar(
......@@ -128,6 +79,7 @@ pten::ScalarArray MakePtenScalarArrayFromVar(
}
}
// TODO(chentianyu03): Inplace with ScalarArray constructor
pten::ScalarArray MakePtenScalarArrayFromVarList(
const std::vector<framework::Variable*>& variable_list) {
if (variable_list.size() == 0) {
......@@ -135,45 +87,28 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
}
auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU);
paddle::framework::proto::VarType::Type data_type;
auto* first_var = variable_list.front();
if (first_var->IsType<framework::LoDTensor>()) {
const auto& tensor = first_var->Get<framework::LoDTensor>();
data_type = tensor.type();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(first_var->Type())));
}
std::vector<int64_t> vector_data;
vector_data.reserve(variable_list.size());
if (data_type == paddle::framework::proto::VarType::INT64) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
for (auto* var : variable_list) {
paddle::framework::proto::VarType::Type data_type;
if (var->IsType<framework::LoDTensor>()) {
const auto& tensor = var->Get<framework::LoDTensor>();
data_type = tensor.type();
if (data_type == paddle::framework::proto::VarType::INT64) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
if (tensor.IsInitialized() &&
!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int64_t>());
} else {
vector_data.push_back(*tensor.data<int64_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
}
} else if (data_type == paddle::framework::proto::VarType::INT32) {
for (auto* var : variable_list) {
if (var->IsType<framework::LoDTensor>()) {
} else if (data_type == paddle::framework::proto::VarType::INT32) {
const auto& tensor = var->Get<framework::LoDTensor>();
if (!platform::is_same_place(tensor.place(), expected_place)) {
if (tensor.IsInitialized() &&
!platform::is_same_place(tensor.place(), expected_place)) {
framework::LoDTensor tmp_tensor;
framework::TensorCopySync(tensor, expected_place, &tmp_tensor);
vector_data.push_back(*tmp_tensor.data<int32_t>());
......@@ -181,21 +116,24 @@ pten::ScalarArray MakePtenScalarArrayFromVarList(
vector_data.push_back(*tensor.data<int32_t>());
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
PADDLE_THROW(pten::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to VectorTensor, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
data_type));
}
} else {
PADDLE_THROW(pten::errors::Unimplemented(
"Unsupport casting input `%s` type to VectorTensor when call pt "
"kernel.",
framework::ToTypeName(var->Type())));
}
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Data type error. When cast a LoDTensor to VectorTensor, "
"the data type of LoDTensor must be int32 or int64, "
"but now data type is %s.",
data_type));
}
return {vector_data};
pten::ScalarArray result{vector_data};
result.setInitByTensor(true);
return result;
}
void ResetTensorDtypeAndLayoutByArgDef(pten::TensorBase* dst,
......
......@@ -33,8 +33,6 @@ namespace experimental {
std::unique_ptr<pten::DenseTensor> MakePtenDenseTensor(
const paddle::framework::Tensor& src);
pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src);
pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src);
pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable);
......
......@@ -25,6 +25,7 @@ namespace experimental {
template <typename T>
class ScalarBase {
public:
bool IsInitByTensor() const { return is_init_by_tensor_; }
// Constructor support implicit
ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT
data_.f64 = val;
......@@ -103,6 +104,7 @@ class ScalarBase {
// The Tensor must have one dim
ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT
is_init_by_tensor_ = true;
PD_CHECK(
tensor.numel() == 1,
"The Scalar only supports Tensor with 1 element, but now Tensor has `",
......@@ -194,6 +196,7 @@ class ScalarBase {
friend void CopyScalar(const ScalarBase<T1>& src, ScalarBase<T2>* dst);
private:
bool is_init_by_tensor_{false};
DataType dtype_;
union data {
bool b;
......
......@@ -43,8 +43,13 @@ class ScalarArrayBase {
AssignData(date_value, n);
}
bool IsInitByTensor() const { return is_init_by_tensor_; }
void setInitByTensor(bool val) { is_init_by_tensor_ = val; }
// The Tensor must have one dim
ScalarArrayBase(const T& tensor) { // NOLINT
is_init_by_tensor_ = true;
size_t n = tensor.numel();
array_.reserve(n);
switch (tensor.dtype()) {
......@@ -66,41 +71,17 @@ class ScalarArrayBase {
// The Tensor in vec must have only one element
ScalarArrayBase(const std::vector<T>& tensor_list) { // NOLINT
auto n = tensor_list.size();
array_.reserve(n);
if (!tensor_list.empty()) {
DataType data_type = tensor_list[0].dtype();
is_init_by_tensor_ = true;
for (size_t i = 0; i < tensor_list.size(); ++i) {
DataType data_type = tensor_list[i].dtype();
switch (data_type) {
case DataType::INT32: {
for (size_t i = 0; i < n; ++i) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
array_.push_back(*tensor_list[i].template data<int32_t>());
}
case DataType::INT32:
array_.push_back(*tensor_list[i].template data<int32_t>());
break;
}
case DataType::INT64: {
for (size_t i = 0; i < n; ++i) {
PD_CHECK(tensor_list[i].dtype() == data_type,
"The data_type of tensors in the list isn't consistent."
"the first tensor is`",
data_type,
"` but `",
i,
"`th tensor is`",
tensor_list[i].dtype(),
"`.");
array_.push_back(*tensor_list[i].template data<int64_t>());
}
case DataType::INT64:
array_.push_back(*tensor_list[i].template data<int64_t>());
break;
}
default:
PD_THROW(
"Data type error. Currently, The data type of ScalarArrayBase "
......@@ -136,6 +117,7 @@ class ScalarArrayBase {
// TODO(zhangyunfei) Replace std::vector with a more efficient container
// structure.
std::vector<int64_t> array_;
bool is_init_by_tensor_{false};
};
using ScalarArray =
......
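The is_init_by_tensor_ flag introduced here (and its counterpart in ScalarBase) records that the value was built from runtime tensors rather than from compile-time attributes; the CPU and GPU SplitKernel implementations later in this commit check IsInitByTensor() and re-run SplitInferMeta so that output shapes can be resolved once the tensor values are known.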
......@@ -315,4 +315,137 @@ void TransferLayoutInferMeta(const MetaTensor& x,
out->set_layout(layout);
}
void SplitInferMeta(const MetaTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<MetaTensor>* out,
MetaConfig config) {
int axis_value = axis.to<int>();
int rank = x.dims().size();
PADDLE_ENFORCE_EQ(
axis_value >= -rank && axis_value < rank,
true,
paddle::platform::errors::InvalidArgument(
"The axis is expected to be in range of [%d, %d), but got %d",
-rank,
rank,
axis_value));
if (axis_value < 0) {
axis_value = axis_value + rank;
}
auto input_axis_dim = x.dims().at(axis_value);
auto num_or_sections_data = num_or_sections.GetData();
// step1: get formatted sections
std::vector<int64_t> sections;
// num_or_sections is a number
if (num_or_sections_data.size() == 1) {
int num = num_or_sections_data.at(0);
PADDLE_ENFORCE_EQ(input_axis_dim % num,
0,
paddle::platform::errors::InvalidArgument(
"The input's size along the split dimension "
"must be evenly divisible by Attr(num_or_sections). "
"But received Attr(num_or_sections) "
"= %d, input(X)'s shape = [%s], Attr(dim) = %d.",
num,
x.dims(),
axis_value));
for (int i = 0; i < num; ++i) {
sections.push_back(input_axis_dim / num);
}
} else {
// num_or_sections is a sections
const int unknow_dim_val = -1;
int unknow_dim_idx = -1;
int num_of_unknow = 0;
int sum_of_section = 0;
for (size_t i = 0; i < num_or_sections_data.size(); ++i) {
sections.push_back(num_or_sections_data[i]);
if (num_or_sections_data[i] == unknow_dim_val) {
num_of_unknow++;
unknow_dim_idx = i;
} else {
sum_of_section += num_or_sections_data[i];
}
}
if (config.is_runtime) {
PADDLE_ENFORCE_LE(num_of_unknow,
1,
paddle::platform::errors::InvalidArgument(
"Only one dimension value of Attr(num_or_sections) "
"in SplitOp can be -1. "
"But received Attr(num_or_sections) = [%s].",
pten::framework::make_ddim(num_or_sections_data)));
}
if (unknow_dim_idx != -1) {
// for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1].
// input_axis_dim = 5, sum_of_sections = 5.
// the following check will fail.
PADDLE_ENFORCE_LT(
sum_of_section,
input_axis_dim,
paddle::platform::errors::InvalidArgument(
"Sum of Attr(num_or_sections) other than unknown section "
"must be less than the input's "
"size "
"along the split dimension. But received Attr(num_or_sections) "
"= [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
pten::framework::make_ddim(num_or_sections_data),
x.dims(),
axis_value));
if (config.is_runtime) {
sections[unknow_dim_idx] = input_axis_dim - sum_of_section;
}
} else {
PADDLE_ENFORCE_EQ(
sum_of_section,
input_axis_dim,
paddle::platform::errors::InvalidArgument(
"Sum of Attr(num_or_sections) must be equal to the input's "
"size "
"along the split dimension. But received Attr(num_or_sections)"
" = [%s], input(X)'s shape = [%s], Attr(dim) = %d.",
pten::framework::make_ddim(num_or_sections_data),
x.dims(),
axis_value));
}
}
// step2: fill out dims
std::vector<pten::DDim> out_dims(sections.size(), x.dims());
if (config.is_runtime || input_axis_dim > 0) {
for (size_t i = 0; i < sections.size(); ++i) {
out_dims[i][axis_value] = sections[i];
}
} else {
for (size_t i = 0; i < sections.size(); ++i) {
out_dims[i][axis_value] = -1;
}
}
for (size_t i = 0; i < sections.size(); ++i) {
if (axis_value != 0) {
// Only pass LoD when not spliting along the first dim.
(*out)[i].set_dtype(x.dtype());
(*out)[i].set_dims(out_dims[i]);
(*out)[i].set_layout(x.layout());
} else {
(*out)[i].set_dtype(x.dtype());
(*out)[i].set_dims(out_dims[i]);
(*out)[i].set_layout(x.layout());
(*out)[i].share_lod(x);
}
}
return;
}
} // namespace pten
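As a worked example of the inference above: at runtime, with x.dims() = [4, 10], axis = 1 and num_or_sections = {3, -1}, the unknown entry resolves to 10 - 3 = 7, so the outputs get shapes [4, 3] and [4, 7]; with num_or_sections = {2}, the input is split evenly into two tensors of shape [4, 5].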
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/meta_tensor.h"
......@@ -74,4 +75,9 @@ void TransferLayoutInferMeta(const MetaTensor& x,
DataLayout layout,
MetaTensor* out);
void SplitInferMeta(const MetaTensor& x_meta,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<MetaTensor>* out,
MetaConfig config = MetaConfig());
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/split_kernel.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/cpu/concat_and_split.h"
namespace pten {
template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
}
pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true);
for (size_t i = 0; i < out_metas.size(); ++i) {
outs[i]->Resize(out_metas[i].dims());
}
}
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.Alloc(outs[j]);
shape_refer.emplace_back(outs[j]);
}
int axis = axis_scalar.to<int>();
// Sometimes direct copies will be faster; this may need deeper analysis.
if (axis == 0 && outs.size() < 10) {
paddle::operators::StridedMemcpyWithAxis0<T>(
dev_ctx, x, shape_refer, &outs);
} else {
SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
}
}
} // namespace pten
PT_REGISTER_KERNEL(split,
CPU,
ALL_LAYOUT,
pten::SplitKernel,
float,
double,
int64_t,
int,
bool,
pten::dtype::float16) {}
......@@ -134,12 +134,12 @@ __global__ void ConcatKernel_(const T** inputs_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t* out_cols,
int out_cols_size,
T** outputs_data) {
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t* out_cols,
int out_cols_size,
T** outputs_data) {
int tid_x = blockIdx.x * blockDim.x + threadIdx.x;
int curr_segment = 0;
int curr_offset = out_cols[0];
......@@ -184,21 +184,21 @@ __device__ void SplitKernelDetail(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T** outputs_data) {
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T** outputs_data) {
SplitKernelDetail<T>(input_data, in_row, in_col, fixed_out_col, outputs_data);
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1) {
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1) {
T* outputs_data[2];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
......@@ -206,13 +206,13 @@ __global__ void SplitKernel(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1,
T* outputs_addr2) {
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1,
T* outputs_addr2) {
T* outputs_data[3];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
......@@ -221,14 +221,14 @@ __global__ void SplitKernel(const T* input_data,
}
template <typename T>
__global__ void SplitKernel(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1,
T* outputs_addr2,
T* outputs_addr3) {
__global__ void SplitKernel_(const T* input_data,
const int64_t in_row,
const int64_t in_col,
const int64_t fixed_out_col,
T* outputs_addr0,
T* outputs_addr1,
T* outputs_addr2,
T* outputs_addr3) {
T* outputs_data[4];
outputs_data[0] = outputs_addr0;
outputs_data[1] = outputs_addr1;
......@@ -497,7 +497,7 @@ void SplitImpl(const Context& context,
if (has_same_shape) {
if (o_num == 2) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......@@ -505,7 +505,7 @@ void SplitImpl(const Context& context,
outputs_data[0],
outputs_data[1]);
} else if (o_num == 3) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......@@ -514,7 +514,7 @@ void SplitImpl(const Context& context,
outputs_data[1],
outputs_data[2]);
} else if (o_num == 4) {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......@@ -524,7 +524,7 @@ void SplitImpl(const Context& context,
outputs_data[2],
outputs_data[3]);
} else {
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(), in_row, in_col, out0_col, dev_out_gpu_data);
}
} else {
......@@ -542,7 +542,7 @@ void SplitImpl(const Context& context,
int64_t* dev_outs_col_data =
reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr());
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>(
SplitKernel_<<<grid_dims, block_dims, 0, context.stream()>>>(
input.data<T>(),
in_row,
in_col,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/split_kernel.h"
#include "paddle/fluid/operators/strided_memcpy.h"
#include "paddle/pten/common/float16.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/kernels/gpu/concat_and_split.h"
namespace pten {
template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis_scalar,
std::vector<DenseTensor*> outs) {
// need to infershape output
if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) {
std::vector<MetaTensor> out_metas;
for (size_t i = 0; i < outs.size(); ++i) {
out_metas.push_back(outs[i]);
}
pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true);
for (size_t i = 0; i < out_metas.size(); ++i) {
outs[i]->Resize(out_metas[i].dims());
}
}
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
dev_ctx.Alloc(outs[j]);
shape_refer.emplace_back(outs[j]);
}
int axis = axis_scalar.to<int>();
// Sometimes direct copies will be faster; this may need deeper analysis.
if (axis == 0 && outs.size() < 10) {
paddle::operators::StridedMemcpyWithAxis0<T>(
dev_ctx, x, shape_refer, &outs);
} else {
SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
}
}
} // namespace pten
PT_REGISTER_KERNEL(split,
GPU,
ALL_LAYOUT,
pten::SplitKernel,
float,
double,
int64_t,
int,
bool,
pten::dtype::float16,
pten::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis,
std::vector<DenseTensor*> out);
template <typename T, typename Context>
std::vector<DenseTensor> Split(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& num_or_sections,
const Scalar& axis) {
size_t out_number;
if (num_or_sections.GetData().size() == 1) {
out_number = num_or_sections.GetData()[0];
} else {
out_number = num_or_sections.GetData().size();
}
std::vector<MetaTensor> out_meta;
out_meta.reserve(out_number);
std::vector<DenseTensor> result;
result.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
auto dense_out = pten::Empty<T, Context>(dev_ctx);
MetaTensor tmp_meta(&dense_out);
result.push_back(dense_out);
out_meta.push_back(&result.back());
}
SplitInferMeta(x, num_or_sections, axis, &out_meta);
std::vector<DenseTensor*> outs;
outs.reserve(out_meta.size());
for (size_t i = 0; i < out_meta.size(); ++i) {
outs.push_back(&result[i]);
}
SplitKernel<T, Context>(dev_ctx, x, num_or_sections, axis, outs);
return result;
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature SplitOpArgumentMapping(const ArgumentMappingContext& ctx) {
// priority: num > SectionsTensorList > sections
// priority: AxisTensor > axis
if (paddle::any_cast<int>(ctx.Attr("num")) > 0) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("split", {"X"}, {"num", "AxisTensor"}, {"Out"});
} else {
return KernelSignature("split", {"X"}, {"num", "axis"}, {"Out"});
}
}
if (ctx.InputSize("SectionsTensorList") > 0) {
if (ctx.HasInput("AxisTensor")) {
return KernelSignature(
"split", {"X"}, {"SectionsTensorList", "AxisTensor"}, {"Out"});
} else {
return KernelSignature(
"split", {"X"}, {"SectionsTensorList", "axis"}, {"Out"});
}
}
if (ctx.HasInput("AxisTensor")) {
return KernelSignature("split", {"X"}, {"sections", "AxisTensor"}, {"Out"});
} else {
return KernelSignature("split", {"X"}, {"sections", "axis"}, {"Out"});
}
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(split, pten::SplitOpArgumentMapping);
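For example, a split op configured with num = 0, an empty SectionsTensorList and no AxisTensor input maps to KernelSignature("split", {"X"}, {"sections", "axis"}, {"Out"}); supplying a runtime AxisTensor makes the mapping pick "AxisTensor" for the axis argument instead.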
......@@ -22,6 +22,6 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api
cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_split_api SRCS test_split_api.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_data_transform SRCS test_data_transform.cc DEPS pten_tensor pten_api pten_api_utils)
cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/api/include/api.h"
#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace paddle {
namespace tests {
namespace framework = paddle::framework;
using DDim = pten::framework::DDim;
// TODO(chentianyu03): Remove this test after the API is used in the dygraph
TEST(API, split) {
// 1. create tensor
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace());
auto dense_x = std::make_shared<pten::DenseTensor>(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
pten::framework::make_ddim({4, 10}),
pten::DataLayout::NCHW));
auto* dense_x_data =
dense_x->mutable_data<float>(paddle::platform::CPUPlace());
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
}
}
paddle::experimental::Tensor x(dense_x);
// 2. test API
auto out = paddle::experimental::split(x, {2, 2}, 0);
// 3. check result
ASSERT_EQ(out.size(), static_cast<size_t>(2));
ASSERT_EQ(out[0].dims().size(), 2);
ASSERT_EQ(out[0].dims()[0], 2);
ASSERT_EQ(out[0].dims()[1], 10);
ASSERT_EQ(out[0].type(), pten::DataType::FLOAT32);
ASSERT_EQ(out[0].layout(), pten::DataLayout::NCHW);
ASSERT_EQ(out[1].dims().size(), 2);
ASSERT_EQ(out[1].dims()[0], 2);
ASSERT_EQ(out[1].dims()[1], 10);
ASSERT_EQ(out[1].type(), pten::DataType::FLOAT32);
ASSERT_EQ(out[1].layout(), pten::DataLayout::NCHW);
auto out_data_0 = std::dynamic_pointer_cast<pten::DenseTensor>(out[0].impl())
->data<float>();
auto out_data_1 = std::dynamic_pointer_cast<pten::DenseTensor>(out[1].impl())
->data<float>();
for (size_t i = 0; i < 40; ++i) {
if (i < 20) {
ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6);
} else {
ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6);
}
}
}
} // namespace tests
} // namespace paddle
......@@ -11,4 +11,5 @@ cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_uti
cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS pten pten_api_utils)
cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils)
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/kernels/split_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/pten/api/include/manual_api.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
namespace tests {
namespace framework = paddle::framework;
using DDim = pten::framework::DDim;
TEST(DEV_API, split) {
// 1. create tensor
const auto alloc = std::make_unique<paddle::experimental::DefaultAllocator>(
pten::CPUPlace());
pten::DenseTensor dense_x(
alloc.get(),
pten::DenseTensorMeta(pten::DataType::FLOAT32,
pten::framework::make_ddim({4, 10}),
pten::DataLayout::NCHW));
pten::CPUContext dev_ctx;
dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance()
.GetAllocator(paddle::platform::CPUPlace())
.get());
dev_ctx.Init();
auto* dense_x_data = dev_ctx.Alloc<float>(&dense_x);
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < 10; ++j) {
dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0;
}
}
// 2. test API
auto out = pten::Split<float>(dev_ctx, dense_x, {2, 2}, 0);
// 3. check result
ASSERT_EQ(out.size(), static_cast<size_t>(2));
ASSERT_EQ(out[0].dims().size(), 2);
ASSERT_EQ(out[0].dims()[0], 2);
ASSERT_EQ(out[0].dims()[1], 10);
ASSERT_EQ(out[0].meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(out[0].meta().layout, pten::DataLayout::NCHW);
ASSERT_EQ(out[1].dims().size(), 2);
ASSERT_EQ(out[1].dims()[0], 2);
ASSERT_EQ(out[1].dims()[1], 10);
ASSERT_EQ(out[1].meta().dtype, pten::DataType::FLOAT32);
ASSERT_EQ(out[1].meta().layout, pten::DataLayout::NCHW);
auto out_data_0 = out[0].data<float>();
auto out_data_1 = out[1].data<float>();
for (size_t i = 0; i < 40; ++i) {
if (i < 20) {
ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6);
} else {
ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6);
}
}
}
} // namespace tests
} // namespace pten