diff --git a/paddle/fluid/framework/custom_kernel_test.cc b/paddle/fluid/framework/custom_kernel_test.cc index 29072551c80768d2313ba4399952c21644e079e6..63dd583504d601813d8e546d0c3e8611012ae2af 100644 --- a/paddle/fluid/framework/custom_kernel_test.cc +++ b/paddle/fluid/framework/custom_kernel_test.cc @@ -22,9 +22,12 @@ limitations under the License. */ #include #include #include "paddle/extension.h" +#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_kernel_info_helper.h" +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/fluid/platform/device_context.h" #include "paddle/pten/api/lib/utils/allocator.h" -#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_context.h" #include "paddle/pten/core/kernel_factory.h" @@ -183,14 +186,14 @@ TEST(CustomKernel, custom_kernel_dot) { paddle::platform::CPUPlace()); auto dense_x = std::make_shared( alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8, - paddle::framework::make_ddim({2, 3}), + pten::framework::make_ddim({2, 3}), pten::DataLayout::NCHW)); auto* dense_x_data = dense_x->mutable_data(paddle::platform::CPUPlace()); auto dense_y = std::make_shared( alloc.get(), pten::DenseTensorMeta(pten::DataType::UINT8, - paddle::framework::make_ddim({2, 3}), + pten::framework::make_ddim({2, 3}), pten::DataLayout::NCHW)); auto* dense_y_data = dense_y->mutable_data(paddle::platform::CPUPlace()); @@ -231,8 +234,7 @@ TEST(CustomKernel, custom_kernel_dot) { pten::DataType fake_attr_dtype = pten::DataType::UINT32; paddle::framework::LoDTensor tmp_tensor; tmp_tensor.mutable_data({1}, pten::TransToPtenPlace(backend)); - pten::Scalar fake_attr_scalar = - paddle::experimental::MakePtenScalar(tmp_tensor); + pten::Scalar fake_attr_scalar{tmp_tensor}; pten::ScalarArray fake_attr_scalar_array; std::vector fake_attr_int64_vec; std::vector fake_attr_int_vec; diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 754a0947ae4b88521eddf1e5d01688cfb413a144..392047c150dc125b3d90a4ede323e3c45e2525ce 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -2099,6 +2099,10 @@ void OperatorWithKernel::BuildPtenKernelContext( std::type_index(typeid(std::vector))) { pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray( BOOST_GET_CONST(std::vector, attr_iter->second)))); + } else if (std::type_index(attr_iter->second.type()) == + std::type_index(typeid(int32_t))) { + pt_kernel_context->EmplaceBackAttr(std::move(pten::ScalarArray( + &BOOST_GET_CONST(int32_t, attr_iter->second), 1))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to ScalarArray when " diff --git a/paddle/fluid/imperative/prepared_operator.h b/paddle/fluid/imperative/prepared_operator.h index 90b75680ef766e1fcb3f8cd8a111b78706ed457a..b3b3725ba993e373446dea30fe1d35ce6e26ee23 100644 --- a/paddle/fluid/imperative/prepared_operator.h +++ b/paddle/fluid/imperative/prepared_operator.h @@ -346,6 +346,14 @@ void BuildDygraphPtenKernelContext( std::type_index(typeid(std::vector))) { kernel_ctx->EmplaceBackAttr(std::move( pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int64_t))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::ScalarArray(&BOOST_GET_CONST(int64_t, attr), 1))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(int32_t))) { + kernel_ctx->EmplaceBackAttr( + std::move(pten::ScalarArray(&BOOST_GET_CONST(int32_t, attr), 1))); } else if (attr_defs[i].type_index == std::type_index(typeid(std::vector))) { const auto& vector_int_attr = BOOST_GET_CONST(std::vector, attr); diff --git a/paddle/fluid/imperative/tests/test_prepare_op.cc b/paddle/fluid/imperative/tests/test_prepare_op.cc index b4ff3cff38217a57c0b1091c3e003043ca4c9673..fa52aa6d0af61578e18d51e8b95c13b5d383c858 100644 --- a/paddle/fluid/imperative/tests/test_prepare_op.cc +++ b/paddle/fluid/imperative/tests/test_prepare_op.cc @@ -217,7 +217,7 @@ TEST(test_prepare_op, test_prepare_data_cpu_mkldnn) { } // namespace imperative } // namespace paddle -USE_OP(split); +USE_OP_ITSELF(split); USE_OP(relu); #ifdef PADDLE_WITH_MKLDNN USE_OP_DEVICE_KERNEL(relu, MKLDNN); diff --git a/paddle/fluid/operators/split_op.cc b/paddle/fluid/operators/split_op.cc index 5bd699e08abbcad5524a500c28ec7d7768dc18f0..79636aced0333284d00f0bbcd96616c01ffd88c9 100644 --- a/paddle/fluid/operators/split_op.cc +++ b/paddle/fluid/operators/split_op.cc @@ -172,11 +172,3 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker, ops::SplitGradMaker, ops::SplitGradMaker); -namespace plat = paddle::platform; -REGISTER_OP_CPU_KERNEL( - split, ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel); diff --git a/paddle/fluid/operators/split_op.cu.cc b/paddle/fluid/operators/split_op.cu.cc deleted file mode 100644 index a8a1383614bddb24b285734edb6f74e2789fdfeb..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/split_op.cu.cc +++ /dev/null @@ -1,25 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/split_op.h" -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - split, ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel, - ops::SplitOpKernel); diff --git a/paddle/fluid/operators/split_op.h b/paddle/fluid/operators/split_op.h index 96ac2c7a1bd086c2ca937d26160a6ac9316c92cc..0538fad08278e774889cc55cae9f5b72da0d27e3 100644 --- a/paddle/fluid/operators/split_op.h +++ b/paddle/fluid/operators/split_op.h @@ -19,10 +19,8 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/concat_and_split.h" -#include "paddle/fluid/operators/strided_memcpy.h" #include "paddle/fluid/operators/utils.h" - +#include "paddle/pten/kernels/split_kernel.h" namespace paddle { namespace operators { static inline std::vector UpdateOutsDims( @@ -108,56 +106,6 @@ static inline std::vector UpdateOutsDims( } return outs_dims; } -template -class SplitOpKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* in = ctx.Input("X"); - auto outs = ctx.MultiOutput("Out"); - int num = ctx.Attr("num"); - std::vector sections = ctx.Attr>("sections"); - int axis = ctx.Attr("axis"); - - auto in_dims = in->dims(); - auto outs_number = outs.size(); - - bool need_resize_outs_dims = false; - if (ctx.HasInput("AxisTensor")) { - auto* axis_tensor = ctx.Input("AxisTensor"); - axis = GetDataFromTensor(axis_tensor)[0]; - need_resize_outs_dims = true; - } - auto sections_tensor_list = - ctx.MultiInput("SectionsTensorList"); - if (sections_tensor_list.size() > 0) { - sections = GetDataFromTensorList(sections_tensor_list); - need_resize_outs_dims = true; - } - - if (need_resize_outs_dims) { - std::vector outs_dims = - UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number); - for (size_t j = 0; j < outs.size(); ++j) { - outs[j]->Resize(outs_dims[j]); - } - } - - std::vector shape_refer; - for (size_t j = 0; j < outs.size(); ++j) { - outs[j]->mutable_data(ctx.GetPlace()); - shape_refer.emplace_back(outs[j]); - } - - auto& dev_ctx = ctx.template device_context(); - // Sometimes direct copies will be faster, this maybe need deeply analysis. - if (axis == 0 && outs.size() < 10) { - StridedMemcpyWithAxis0(dev_ctx, *in, shape_refer, &outs); - } else { - math::SplitFunctor functor; - functor(dev_ctx, *in, shape_refer, axis, &outs); - } - } -}; template class SplitGradMaker : public framework::SingleGradOpMaker { diff --git a/paddle/pten/api/include/manual_api.h b/paddle/pten/api/include/manual_api.h index 3bd7e60154d06a6f5bf2381b8c62d14bc85c2212..942bbe970457211f7e74c95067c43b63a4059748 100644 --- a/paddle/pten/api/include/manual_api.h +++ b/paddle/pten/api/include/manual_api.h @@ -16,6 +16,8 @@ limitations under the License. */ #include "paddle/pten/api/include/tensor.h" #include "paddle/pten/common/backend.h" +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" /** * This file stores some special APIs that are implemented manually @@ -28,5 +30,11 @@ namespace experimental { // TODO(chenweihang): Replace backend by place when place is ready PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking); +// TODO(chentianyu03): Split API has extra logic to calculate the outputs size, +// api_gen do not support +PADDLE_API std::vector split(const Tensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis); + } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/manual_api.cc b/paddle/pten/api/lib/manual_api.cc index 1af5150b4aed475884fc16e01c67feaf020dce53..667bd177ee1f6232f44299502e348e23579cb49c 100644 --- a/paddle/pten/api/lib/manual_api.cc +++ b/paddle/pten/api/lib/manual_api.cc @@ -19,9 +19,12 @@ limitations under the License. */ #include "glog/logging.h" #include "paddle/pten/api/lib/api_registry.h" +#include "paddle/pten/api/lib/api_utils.h" +#include "paddle/pten/api/lib/data_transform.h" #include "paddle/pten/api/lib/kernel_dispatch.h" #include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/core/kernel_registry.h" +#include "paddle/pten/core/meta_tensor.h" #include "paddle/pten/infermeta/unary.h" PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT); @@ -75,6 +78,71 @@ PADDLE_API Tensor copy_to(const Tensor& x, Backend backend, bool blocking) { return out; } +PADDLE_API std::vector split(const Tensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis) { + Backend kernel_backend = Backend::UNDEFINED; + DataLayout kernel_layout = DataLayout::UNDEFINED; + DataType kernel_data_type = DataType::UNDEFINED; + + if (kernel_backend == Backend::UNDEFINED || + kernel_layout == DataLayout::UNDEFINED || + kernel_data_type == DataType::UNDEFINED) { + auto kernel_key_set = ParseKernelKeyByInputArgs(x); + auto kernel_key = kernel_key_set.GetHigestPriorityKernelKey(); + if (kernel_backend == Backend::UNDEFINED) { + kernel_backend = kernel_key.backend(); + } + if (kernel_layout == DataLayout::UNDEFINED) { + kernel_layout = kernel_key.layout(); + } + if (kernel_data_type == DataType::UNDEFINED) { + kernel_data_type = kernel_key.dtype(); + } + } + + auto kernel = pten::KernelFactory::Instance().SelectKernelOrThrowError( + "split", {kernel_backend, kernel_layout, kernel_data_type}); + VLOG(6) << "split API kernel key: [" << kernel_backend << ", " + << kernel_layout << ", " << kernel_data_type << "]"; + VLOG(6) << "split API kernel: " << kernel; + + auto* dev_ctx = GetDeviceContextByBackend(kernel_backend); + + auto dense_x = PrepareData(x, kernel.InputAt(0), {}); + + // Calculate the number of out tensors + size_t out_number; + if (num_or_sections.GetData().size() == 1) { + out_number = num_or_sections.GetData()[0]; + } else { + out_number = num_or_sections.GetData().size(); + } + + std::vector out; + auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out); + std::vector meta_outs; + for (size_t i = 0; i < out_number; ++i) { + meta_outs.push_back(dense_outs[i]); + } + + pten::SplitInferMeta( + MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs); + + using kernel_signature = void (*)(const platform::DeviceContext&, + const pten::DenseTensor&, + const pten::ScalarArray&, + const pten::Scalar&, + std::vector&); + auto* kernel_fn = kernel.GetVariadicKernelFn(); + (*kernel_fn)(*dev_ctx, + *dense_x, + pten::ScalarArray(num_or_sections), + pten::Scalar(axis), + dense_outs); + + return out; +} } // namespace experimental } // namespace paddle diff --git a/paddle/pten/api/lib/utils/tensor_utils.cc b/paddle/pten/api/lib/utils/tensor_utils.cc index c0d72452501c3d9cc659710002573655ab2458a8..230787c1b35cdc5bfe2d44ee5757b301404ea871 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.cc +++ b/paddle/pten/api/lib/utils/tensor_utils.cc @@ -36,45 +36,6 @@ std::unique_ptr MakePtenDenseTensor( return std::make_unique(src); } -pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src) { - PADDLE_ENFORCE_EQ(src.numel(), - 1, - paddle::platform::errors::InvalidArgument( - "The Scalar only supports Tensor with 1 element, " - "but now Tensor has %d element.", - src.numel())); - switch (src.type()) { - case paddle::framework::proto::VarType::FP32: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::FP64: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::FP16: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::BF16: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT32: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT64: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT16: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::INT8: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::UINT8: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::BOOL: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::COMPLEX64: - return {src.template data()[0]}; - case paddle::framework::proto::VarType::COMPLEX128: - return {src.template data()[0]}; - default: - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "Data type error. Don't support casting a %d LoDTensor to Scalar.", - src.type())); - } -} - pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU); if (variable.IsType()) { @@ -82,9 +43,9 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { if (!platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); - return MakePtenScalar(tmp_tensor); + return {tmp_tensor}; } else { - return MakePtenScalar(tensor); + return {tensor}; } } else { PADDLE_THROW(platform::errors::Unimplemented( @@ -95,17 +56,7 @@ pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable) { } pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src) { - if (src.type() == paddle::framework::proto::VarType::INT64) { - return {src.data(), src.numel()}; - } else if (src.type() == paddle::framework::proto::VarType::INT32) { - return {src.data(), src.numel()}; - } else { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "Data type error. When cast a LoDTensor to ScalarArray, " - "the data type of LoDTensor must be int32 or int64, " - "but now data type is %s.", - src.type())); - } + return {src}; } pten::ScalarArray MakePtenScalarArrayFromVar( @@ -128,6 +79,7 @@ pten::ScalarArray MakePtenScalarArrayFromVar( } } +// TODO(chentianyu03): Inplace with ScalarArray constructor pten::ScalarArray MakePtenScalarArrayFromVarList( const std::vector& variable_list) { if (variable_list.size() == 0) { @@ -135,45 +87,28 @@ pten::ScalarArray MakePtenScalarArrayFromVarList( } auto expected_place = pten::TransToPtenPlace(pten::Backend::CPU); - paddle::framework::proto::VarType::Type data_type; - auto* first_var = variable_list.front(); - if (first_var->IsType()) { - const auto& tensor = first_var->Get(); - data_type = tensor.type(); - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport casting input `%s` type to VectorTensor when call pt " - "kernel.", - framework::ToTypeName(first_var->Type()))); - } - std::vector vector_data; vector_data.reserve(variable_list.size()); - if (data_type == paddle::framework::proto::VarType::INT64) { - for (auto* var : variable_list) { - if (var->IsType()) { + for (auto* var : variable_list) { + paddle::framework::proto::VarType::Type data_type; + if (var->IsType()) { + const auto& tensor = var->Get(); + data_type = tensor.type(); + if (data_type == paddle::framework::proto::VarType::INT64) { const auto& tensor = var->Get(); - if (!platform::is_same_place(tensor.place(), expected_place)) { + if (tensor.IsInitialized() && + !platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); vector_data.push_back(*tmp_tensor.data()); } else { vector_data.push_back(*tensor.data()); } - } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport casting input `%s` type to VectorTensor when call pt " - "kernel.", - framework::ToTypeName(var->Type()))); - } - } - - } else if (data_type == paddle::framework::proto::VarType::INT32) { - for (auto* var : variable_list) { - if (var->IsType()) { + } else if (data_type == paddle::framework::proto::VarType::INT32) { const auto& tensor = var->Get(); - if (!platform::is_same_place(tensor.place(), expected_place)) { + if (tensor.IsInitialized() && + !platform::is_same_place(tensor.place(), expected_place)) { framework::LoDTensor tmp_tensor; framework::TensorCopySync(tensor, expected_place, &tmp_tensor); vector_data.push_back(*tmp_tensor.data()); @@ -181,21 +116,24 @@ pten::ScalarArray MakePtenScalarArrayFromVarList( vector_data.push_back(*tensor.data()); } } else { - PADDLE_THROW(platform::errors::Unimplemented( - "Unsupport casting input `%s` type to VectorTensor when call pt " - "kernel.", - framework::ToTypeName(var->Type()))); + PADDLE_THROW(pten::errors::InvalidArgument( + "Data type error. When cast a LoDTensor to VectorTensor, " + "the data type of LoDTensor must be int32 or int64, " + "but now data type is %s.", + data_type)); } + } else { + PADDLE_THROW(pten::errors::Unimplemented( + "Unsupport casting input `%s` type to VectorTensor when call pt " + "kernel.", + framework::ToTypeName(var->Type()))); } - } else { - PADDLE_THROW(paddle::platform::errors::InvalidArgument( - "Data type error. When cast a LoDTensor to VectorTensor, " - "the data type of LoDTensor must be int32 or int64, " - "but now data type is %s.", - data_type)); } - return {vector_data}; + pten::ScalarArray result{vector_data}; + result.setInitByTensor(true); + + return result; } void ResetTensorDtypeAndLayoutByArgDef(pten::TensorBase* dst, diff --git a/paddle/pten/api/lib/utils/tensor_utils.h b/paddle/pten/api/lib/utils/tensor_utils.h index 1e2d8b74db84941f970c0613fad4fa488f813053..cf1daf732ee96711fde3e5910899d972dd748ceb 100644 --- a/paddle/pten/api/lib/utils/tensor_utils.h +++ b/paddle/pten/api/lib/utils/tensor_utils.h @@ -33,8 +33,6 @@ namespace experimental { std::unique_ptr MakePtenDenseTensor( const paddle::framework::Tensor& src); -pten::Scalar MakePtenScalar(const paddle::framework::Tensor& src); - pten::ScalarArray MakePtenScalarArray(const paddle::framework::Tensor& src); pten::Scalar MakePtenScalarFromVar(const framework::Variable& variable); diff --git a/paddle/pten/common/scalar.h b/paddle/pten/common/scalar.h index 5c8fb04633088a0f9bc53877e1ab7bddf1f073ad..0ab880d6218f8778d479f166ab26db2c651ab6ac 100644 --- a/paddle/pten/common/scalar.h +++ b/paddle/pten/common/scalar.h @@ -25,6 +25,7 @@ namespace experimental { template class ScalarBase { public: + bool IsInitByTensor() const { return is_init_by_tensor_; } // Constructor support implicit ScalarBase(double val) : dtype_(DataType::FLOAT64) { // NOLINT data_.f64 = val; @@ -103,6 +104,7 @@ class ScalarBase { // The Tensor must have one dim ScalarBase(const T& tensor) : dtype_(tensor.dtype()) { // NOLINT + is_init_by_tensor_ = true; PD_CHECK( tensor.numel() == 1, "The Scalar only supports Tensor with 1 element, but now Tensor has `", @@ -194,6 +196,7 @@ class ScalarBase { friend void CopyScalar(const ScalarBase& src, ScalarBase* dst); private: + bool is_init_by_tensor_{false}; DataType dtype_; union data { bool b; diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h index 81013d8e5a11cdd6b44587bb2151b7be18895c27..dcc8ff6748b869ac25c882e66f1ff77bc94534a2 100644 --- a/paddle/pten/common/scalar_array.h +++ b/paddle/pten/common/scalar_array.h @@ -43,8 +43,13 @@ class ScalarArrayBase { AssignData(date_value, n); } + bool IsInitByTensor() const { return is_init_by_tensor_; } + + void setInitByTensor(bool val) { is_init_by_tensor_ = val; } + // The Tensor must have one dim ScalarArrayBase(const T& tensor) { // NOLINT + is_init_by_tensor_ = true; size_t n = tensor.numel(); array_.reserve(n); switch (tensor.dtype()) { @@ -66,41 +71,17 @@ class ScalarArrayBase { // The Tensor in vec must have only one element ScalarArrayBase(const std::vector& tensor_list) { // NOLINT - auto n = tensor_list.size(); - array_.reserve(n); - if (!tensor_list.empty()) { - DataType data_type = tensor_list[0].dtype(); + is_init_by_tensor_ = true; + + for (size_t i = 0; i < tensor_list.size(); ++i) { + DataType data_type = tensor_list[i].dtype(); switch (data_type) { - case DataType::INT32: { - for (size_t i = 0; i < n; ++i) { - PD_CHECK(tensor_list[i].dtype() == data_type, - "The data_type of tensors in the list isn't consistent." - "the first tensor is`", - data_type, - "` but `", - i, - "`th tensor is`", - tensor_list[i].dtype(), - "`."); - array_.push_back(*tensor_list[i].template data()); - } + case DataType::INT32: + array_.push_back(*tensor_list[i].template data()); break; - } - case DataType::INT64: { - for (size_t i = 0; i < n; ++i) { - PD_CHECK(tensor_list[i].dtype() == data_type, - "The data_type of tensors in the list isn't consistent." - "the first tensor is`", - data_type, - "` but `", - i, - "`th tensor is`", - tensor_list[i].dtype(), - "`."); - array_.push_back(*tensor_list[i].template data()); - } + case DataType::INT64: + array_.push_back(*tensor_list[i].template data()); break; - } default: PD_THROW( "Data type error. Currently, The data type of ScalarArrayBase " @@ -136,6 +117,7 @@ class ScalarArrayBase { // TODO(zhangyunfei) Replace std::vector with a more efficient container // structure. std::vector array_; + bool is_init_by_tensor_{false}; }; using ScalarArray = diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 5f3b0712b5863145f7340dfb6ae34a6809d9d635..ca59937399a226558c213fed5b43a2311a2f368a 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -315,4 +315,137 @@ void TransferLayoutInferMeta(const MetaTensor& x, out->set_layout(layout); } +void SplitInferMeta(const MetaTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis, + std::vector* out, + MetaConfig config) { + int axis_value = axis.to(); + int rank = x.dims().size(); + PADDLE_ENFORCE_EQ( + axis_value >= -rank && axis_value < rank, + true, + paddle::platform::errors::InvalidArgument( + "The axis is expected to be in range of [%d, %d), but got %d", + -rank, + rank, + axis_value)); + if (axis_value < 0) { + axis_value = axis_value + rank; + } + + auto input_axis_dim = x.dims().at(axis_value); + auto num_or_sections_data = num_or_sections.GetData(); + // step1: get formated sections + std::vector sections; + // num_or_sections is a number + if (num_or_sections_data.size() == 1) { + int num = num_or_sections_data.at(0); + + PADDLE_ENFORCE_EQ(input_axis_dim % num, + 0, + paddle::platform::errors::InvalidArgument( + "The input's size along the split dimension " + "must be evenly divisible by Attr(num_or_sections). " + "But received Attr(num_or_sections) " + "= %d, input(X)'s shape = [%s], Attr(dim) = %d.", + num, + x.dims(), + axis_value)); + + for (int i = 0; i < num; ++i) { + sections.push_back(input_axis_dim / num); + } + } else { + // num_or_sections is a sections + const int unknow_dim_val = -1; + int unknow_dim_idx = -1; + int num_of_unknow = 0; + int sum_of_section = 0; + + for (size_t i = 0; i < num_or_sections_data.size(); ++i) { + sections.push_back(num_or_sections_data[i]); + + if (num_or_sections_data[i] == unknow_dim_val) { + num_of_unknow++; + unknow_dim_idx = i; + } else { + sum_of_section += num_or_sections_data[i]; + } + } + + if (config.is_runtime) { + PADDLE_ENFORCE_LE(num_of_unknow, + 1, + paddle::platform::errors::InvalidArgument( + "Only one dimension value of Attr(num_or_sections) " + "in SplitOp can be -1. " + "But received Attr(num_or_sections) = [%s].", + pten::framework::make_ddim(num_or_sections_data))); + } + + if (unknow_dim_idx != -1) { + // for example, input shape = [4 ,5], axis = 1, sections = [2, 3, -1]. + // input_axis_dim = 5, sum_of_sections = 5. + // the following check will fail. + PADDLE_ENFORCE_LT( + sum_of_section, + input_axis_dim, + paddle::platform::errors::InvalidArgument( + "Sum of Attr(num_or_sections) other than unknown section " + "must be less than the input's " + "size " + "along the split dimension. But received Attr(num_or_sections) " + "= [%s], input(X)'s shape = [%s], Attr(dim) = %d.", + pten::framework::make_ddim(num_or_sections_data), + x.dims(), + axis_value)); + + if (config.is_runtime) { + sections[unknow_dim_idx] = input_axis_dim - sum_of_section; + } + } else { + PADDLE_ENFORCE_EQ( + sum_of_section, + input_axis_dim, + paddle::platform::errors::InvalidArgument( + "Sum of Attr(num_or_sections) must be equal to the input's " + "size " + "along the split dimension. But received Attr(num_or_sections)" + " = [%s], input(X)'s shape = [%s], Attr(dim) = %d.", + pten::framework::make_ddim(num_or_sections_data), + x.dims(), + axis_value)); + } + } + + // setp2: fill out dims + std::vector out_dims(sections.size(), x.dims()); + if (config.is_runtime || input_axis_dim > 0) { + for (size_t i = 0; i < sections.size(); ++i) { + out_dims[i][axis_value] = sections[i]; + } + } else { + for (size_t i = 0; i < sections.size(); ++i) { + out_dims[i][axis_value] = -1; + } + } + + for (size_t i = 0; i < sections.size(); ++i) { + if (axis_value != 0) { + // Only pass LoD when not spliting along the first dim. + (*out)[i].set_dtype(x.dtype()); + (*out)[i].set_dims(out_dims[i]); + (*out)[i].set_layout(x.layout()); + } else { + (*out)[i].set_dtype(x.dtype()); + (*out)[i].set_dims(out_dims[i]); + (*out)[i].set_layout(x.layout()); + (*out)[i].share_lod(x); + } + } + + return; +} + } // namespace pten diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index f1dc806b4e9caee980f0bd4b9d5085f375c55bd2..4c816c4adbc233e0442c2100f62ee8e62cc8f78c 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once // See Note [ Why still include the fluid headers? ] +#include "paddle/pten/common/scalar.h" #include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/meta_tensor.h" @@ -74,4 +75,9 @@ void TransferLayoutInferMeta(const MetaTensor& x, DataLayout layout, MetaTensor* out); +void SplitInferMeta(const MetaTensor& x_meta, + const ScalarArray& num_or_sections, + const Scalar& axis, + std::vector* out, + MetaConfig config = MetaConfig()); } // namespace pten diff --git a/paddle/pten/kernels/cpu/split_kernel.cc b/paddle/pten/kernels/cpu/split_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..78fcdcb155cf23b146c1a44ccc9651b6506d1d4d --- /dev/null +++ b/paddle/pten/kernels/cpu/split_kernel.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/split_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/pten/common/float16.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/cpu/concat_and_split.h" +namespace pten { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis_scalar, + std::vector outs) { + // need to infershape output + if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { + std::vector out_metas; + for (size_t i = 0; i < outs.size(); ++i) { + out_metas.push_back(outs[i]); + } + + pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true); + + for (size_t i = 0; i < out_metas.size(); ++i) { + outs[i]->Resize(out_metas[i].dims()); + } + } + + std::vector shape_refer; + for (size_t j = 0; j < outs.size(); ++j) { + dev_ctx.Alloc(outs[j]); + shape_refer.emplace_back(outs[j]); + } + + int axis = axis_scalar.to(); + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + paddle::operators::StridedMemcpyWithAxis0( + dev_ctx, x, shape_refer, &outs); + } else { + SplitImpl(dev_ctx, x, shape_refer, axis, &outs); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(split, + CPU, + ALL_LAYOUT, + pten::SplitKernel, + float, + double, + int64_t, + int, + bool, + pten::dtype::float16) {} diff --git a/paddle/pten/kernels/gpu/concat_and_split.h b/paddle/pten/kernels/gpu/concat_and_split.h index 47022666564df9ba5626f00ef15feccfd3e900d1..17b54bbbfdc549adfd06194e7851506341638879 100644 --- a/paddle/pten/kernels/gpu/concat_and_split.h +++ b/paddle/pten/kernels/gpu/concat_and_split.h @@ -134,12 +134,12 @@ __global__ void ConcatKernel_(const T** inputs_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t* out_cols, - int out_cols_size, - T** outputs_data) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t* out_cols, + int out_cols_size, + T** outputs_data) { int tid_x = blockIdx.x * blockDim.x + threadIdx.x; int curr_segment = 0; int curr_offset = out_cols[0]; @@ -184,21 +184,21 @@ __device__ void SplitKernelDetail(const T* input_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T** outputs_data) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T** outputs_data) { SplitKernelDetail(input_data, in_row, in_col, fixed_out_col, outputs_data); } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T* outputs_addr0, - T* outputs_addr1) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1) { T* outputs_data[2]; outputs_data[0] = outputs_addr0; outputs_data[1] = outputs_addr1; @@ -206,13 +206,13 @@ __global__ void SplitKernel(const T* input_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T* outputs_addr0, - T* outputs_addr1, - T* outputs_addr2) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1, + T* outputs_addr2) { T* outputs_data[3]; outputs_data[0] = outputs_addr0; outputs_data[1] = outputs_addr1; @@ -221,14 +221,14 @@ __global__ void SplitKernel(const T* input_data, } template -__global__ void SplitKernel(const T* input_data, - const int64_t in_row, - const int64_t in_col, - const int64_t fixed_out_col, - T* outputs_addr0, - T* outputs_addr1, - T* outputs_addr2, - T* outputs_addr3) { +__global__ void SplitKernel_(const T* input_data, + const int64_t in_row, + const int64_t in_col, + const int64_t fixed_out_col, + T* outputs_addr0, + T* outputs_addr1, + T* outputs_addr2, + T* outputs_addr3) { T* outputs_data[4]; outputs_data[0] = outputs_addr0; outputs_data[1] = outputs_addr1; @@ -497,7 +497,7 @@ void SplitImpl(const Context& context, if (has_same_shape) { if (o_num == 2) { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, @@ -505,7 +505,7 @@ void SplitImpl(const Context& context, outputs_data[0], outputs_data[1]); } else if (o_num == 3) { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, @@ -514,7 +514,7 @@ void SplitImpl(const Context& context, outputs_data[1], outputs_data[2]); } else if (o_num == 4) { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, @@ -524,7 +524,7 @@ void SplitImpl(const Context& context, outputs_data[2], outputs_data[3]); } else { - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, out0_col, dev_out_gpu_data); } } else { @@ -542,7 +542,7 @@ void SplitImpl(const Context& context, int64_t* dev_outs_col_data = reinterpret_cast(tmp_dev_ins_col_data->ptr()); - SplitKernel<<>>( + SplitKernel_<<>>( input.data(), in_row, in_col, diff --git a/paddle/pten/kernels/gpu/split_kernel.cu b/paddle/pten/kernels/gpu/split_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..46d18b75b611b7acad8155b0d6bbed7b015a23c9 --- /dev/null +++ b/paddle/pten/kernels/gpu/split_kernel.cu @@ -0,0 +1,72 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/kernels/split_kernel.h" + +#include "paddle/fluid/operators/strided_memcpy.h" +#include "paddle/pten/common/float16.h" +#include "paddle/pten/core/kernel_registry.h" + +#include "paddle/pten/kernels/gpu/concat_and_split.h" +namespace pten { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis_scalar, + std::vector outs) { + // need to infershape output + if (num_or_sections.IsInitByTensor() || axis_scalar.IsInitByTensor()) { + std::vector out_metas; + for (size_t i = 0; i < outs.size(); ++i) { + out_metas.push_back(outs[i]); + } + + pten::SplitInferMeta(x, num_or_sections, axis_scalar, &out_metas, true); + + for (size_t i = 0; i < out_metas.size(); ++i) { + outs[i]->Resize(out_metas[i].dims()); + } + } + + std::vector shape_refer; + for (size_t j = 0; j < outs.size(); ++j) { + dev_ctx.Alloc(outs[j]); + shape_refer.emplace_back(outs[j]); + } + + int axis = axis_scalar.to(); + // Sometimes direct copies will be faster, this maybe need deeply analysis. + if (axis == 0 && outs.size() < 10) { + paddle::operators::StridedMemcpyWithAxis0( + dev_ctx, x, shape_refer, &outs); + } else { + SplitImpl(dev_ctx, x, shape_refer, axis, &outs); + } +} + +} // namespace pten + +PT_REGISTER_KERNEL(split, + GPU, + ALL_LAYOUT, + pten::SplitKernel, + float, + double, + int64_t, + int, + bool, + pten::dtype::float16, + pten::dtype::bfloat16) {} diff --git a/paddle/pten/kernels/split_kernel.h b/paddle/pten/kernels/split_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..30ac4da7a4ca04f519c58bb9ac205388a4448f5a --- /dev/null +++ b/paddle/pten/kernels/split_kernel.h @@ -0,0 +1,70 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +#include "paddle/pten/common/scalar.h" +#include "paddle/pten/common/scalar_array.h" +#include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/empty_kernel.h" + +namespace pten { + +template +void SplitKernel(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis, + std::vector out); + +template +std::vector Split(const Context& dev_ctx, + const DenseTensor& x, + const ScalarArray& num_or_sections, + const Scalar& axis) { + size_t out_number; + if (num_or_sections.GetData().size() == 1) { + out_number = num_or_sections.GetData()[0]; + } else { + out_number = num_or_sections.GetData().size(); + } + + std::vector out_meta; + out_meta.reserve(out_number); + std::vector result; + result.reserve(out_number); + + for (size_t i = 0; i < out_number; ++i) { + auto dense_out = pten::Empty(dev_ctx); + MetaTensor tmp_meta(&dense_out); + + result.push_back(dense_out); + out_meta.push_back(&result.back()); + } + SplitInferMeta(x, num_or_sections, axis, &out_meta); + + std::vector outs; + outs.reserve(out_meta.size()); + for (size_t i = 0; i < out_meta.size(); ++i) { + outs.push_back(&result[i]); + } + + SplitKernel(dev_ctx, x, num_or_sections, axis, outs); + + return result; +} + +} // namespace pten diff --git a/paddle/pten/ops/compat/split_sig.cc b/paddle/pten/ops/compat/split_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec58af5e9e41d3f89e9422b9b0f5bd650ee8347d --- /dev/null +++ b/paddle/pten/ops/compat/split_sig.cc @@ -0,0 +1,49 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/pten/core/compat/op_utils.h" + +namespace pten { + +KernelSignature SplitOpArgumentMapping(const ArgumentMappingContext& ctx) { + // priority: num > SectionsTensorList > sections + // priority: AxisTensor > axis + if (paddle::any_cast(ctx.Attr("num")) > 0) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("split", {"X"}, {"num", "AxisTensor"}, {"Out"}); + } else { + return KernelSignature("split", {"X"}, {"num", "axis"}, {"Out"}); + } + } + + if (ctx.InputSize("SectionsTensorList") > 0) { + if (ctx.HasInput("AxisTensor")) { + return KernelSignature( + "split", {"X"}, {"SectionsTensorList", "AxisTensor"}, {"Out"}); + } else { + return KernelSignature( + "split", {"X"}, {"SectionsTensorList", "axis"}, {"Out"}); + } + } + + if (ctx.HasInput("AxisTensor")) { + return KernelSignature("split", {"X"}, {"sections", "AxisTensor"}, {"Out"}); + } else { + return KernelSignature("split", {"X"}, {"sections", "axis"}, {"Out"}); + } +} + +} // namespace pten + +PT_REGISTER_ARG_MAPPING_FN(split, pten::SplitOpArgumentMapping); diff --git a/paddle/pten/tests/api/CMakeLists.txt b/paddle/pten/tests/api/CMakeLists.txt index b8491ab7f5ea89c0aaaf72ae9ae55ba7ea435083..d875dbd4444ae664472663caa0ea5b2694ca8e4f 100644 --- a/paddle/pten/tests/api/CMakeLists.txt +++ b/paddle/pten/tests/api/CMakeLists.txt @@ -22,6 +22,6 @@ cc_test(test_scale_api SRCS test_scale_api.cc DEPS pten_tensor pten_api pten_api cc_test(test_scale_benchmark SRCS test_scale_benchmark.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_conj_api SRCS test_conj_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_concat_api SRCS test_concat_api.cc DEPS pten_tensor pten_api pten_api_utils) - +cc_test(test_split_api SRCS test_split_api.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_data_transform SRCS test_data_transform.cc DEPS pten_tensor pten_api pten_api_utils) cc_test(test_sparse_utils_api SRCS test_sparse_utils_api.cc DEPS pten_tensor pten_api pten_api_utils) diff --git a/paddle/pten/tests/api/test_split_api.cc b/paddle/pten/tests/api/test_split_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac139832aa0082ae23d1ebee05d68cb360690241 --- /dev/null +++ b/paddle/pten/tests/api/test_split_api.cc @@ -0,0 +1,83 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/pten/api/include/api.h" + +#include "paddle/pten/api/include/manual_api.h" +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" + +namespace paddle { +namespace tests { + +namespace framework = paddle::framework; +using DDim = pten::framework::DDim; + +// TODO(chentianyu03): Remove this test after the API is used in the dygraph +TEST(API, split) { + // 1. create tensor + const auto alloc = std::make_unique( + paddle::platform::CPUPlace()); + auto dense_x = std::make_shared( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + pten::framework::make_ddim({4, 10}), + pten::DataLayout::NCHW)); + auto* dense_x_data = + dense_x->mutable_data(paddle::platform::CPUPlace()); + + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + } + } + + paddle::experimental::Tensor x(dense_x); + + // 2. test API + auto out = paddle::experimental::split(x, {2, 2}, 0); + + // 3. check result + ASSERT_EQ(out.size(), static_cast(2)); + ASSERT_EQ(out[0].dims().size(), 2); + ASSERT_EQ(out[0].dims()[0], 2); + ASSERT_EQ(out[0].dims()[1], 10); + ASSERT_EQ(out[0].type(), pten::DataType::FLOAT32); + ASSERT_EQ(out[0].layout(), pten::DataLayout::NCHW); + + ASSERT_EQ(out[1].dims().size(), 2); + ASSERT_EQ(out[1].dims()[0], 2); + ASSERT_EQ(out[1].dims()[1], 10); + ASSERT_EQ(out[1].type(), pten::DataType::FLOAT32); + ASSERT_EQ(out[1].layout(), pten::DataLayout::NCHW); + + auto out_data_0 = std::dynamic_pointer_cast(out[0].impl()) + ->data(); + auto out_data_1 = std::dynamic_pointer_cast(out[1].impl()) + ->data(); + for (size_t i = 0; i < 4; ++i) { + if (i < 20) { + ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6); + } else { + ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6); + } + } +} + +} // namespace tests +} // namespace paddle diff --git a/paddle/pten/tests/kernels/CMakeLists.txt b/paddle/pten/tests/kernels/CMakeLists.txt index e2063241689f929e6d173bcb29dde849ca5a3f48..15a1cab5f0dd473498ebb23e564ce88400af9713 100644 --- a/paddle/pten/tests/kernels/CMakeLists.txt +++ b/paddle/pten/tests/kernels/CMakeLists.txt @@ -11,4 +11,5 @@ cc_test(test_reshape_dev_api SRCS test_reshape_dev_api.cc DEPS pten pten_api_uti cc_test(test_sum_dev_api SRCS test_sum_dev_api.cc DEPS pten pten_api_utils) cc_test(test_conj_dev_api SRCS test_conj_dev_api.cc DEPS pten pten_api_utils) cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS pten pten_api_utils) +cc_test(test_split_dev_api SRCS test_split_dev_api.cc DEPS pten pten_api_utils) cc_test(test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS pten pten_api_utils) diff --git a/paddle/pten/tests/kernels/test_split_dev_api.cc b/paddle/pten/tests/kernels/test_split_dev_api.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4e3619e11a3ad976938b3c78b303cbdcd04ce1b --- /dev/null +++ b/paddle/pten/tests/kernels/test_split_dev_api.cc @@ -0,0 +1,83 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/pten/kernels/split_kernel.h" + +#include "paddle/fluid/memory/allocation/allocator_facade.h" +#include "paddle/pten/api/include/manual_api.h" +#include "paddle/pten/api/lib/utils/allocator.h" +#include "paddle/pten/backends/cpu/cpu_context.h" +#include "paddle/pten/core/dense_tensor.h" +#include "paddle/pten/core/kernel_registry.h" +namespace pten { +namespace tests { + +namespace framework = paddle::framework; +using DDim = pten::framework::DDim; + +TEST(DEV_API, split) { + // 1. create tensor + const auto alloc = std::make_unique( + pten::CPUPlace()); + pten::DenseTensor dense_x( + alloc.get(), + pten::DenseTensorMeta(pten::DataType::FLOAT32, + pten::framework::make_ddim({4, 10}), + pten::DataLayout::NCHW)); + pten::CPUContext dev_ctx; + dev_ctx.SetAllocator(paddle::memory::allocation::AllocatorFacade::Instance() + .GetAllocator(paddle::platform::CPUPlace()) + .get()); + dev_ctx.Init(); + + auto* dense_x_data = dev_ctx.Alloc(&dense_x); + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < 10; ++j) { + dense_x_data[i * 10 + j] = (i * 10 + j) * 1.0; + } + } + + // 2. test API + auto out = pten::Split(dev_ctx, dense_x, {2, 2}, 0); + + // 3. check result + ASSERT_EQ(out.size(), static_cast(2)); + ASSERT_EQ(out[0].dims().size(), 2); + ASSERT_EQ(out[0].dims()[0], 2); + ASSERT_EQ(out[0].dims()[1], 10); + ASSERT_EQ(out[0].meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out[0].meta().layout, pten::DataLayout::NCHW); + + ASSERT_EQ(out[1].dims().size(), 2); + ASSERT_EQ(out[1].dims()[0], 2); + ASSERT_EQ(out[1].dims()[1], 10); + ASSERT_EQ(out[1].meta().dtype, pten::DataType::FLOAT32); + ASSERT_EQ(out[1].meta().layout, pten::DataLayout::NCHW); + + auto out_data_0 = out[0].data(); + auto out_data_1 = out[1].data(); + for (size_t i = 0; i < 4; ++i) { + if (i < 20) { + ASSERT_NEAR(dense_x_data[i], out_data_0[i], 1e-6); + } else { + ASSERT_NEAR(dense_x_data[i], out_data_1[i - 20], 1e-6); + } + } +} + +} // namespace tests +} // namespace pten