From a3c8abc707fb871426670d309377884599bc580e Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Tue, 14 Dec 2021 11:33:32 +0800 Subject: [PATCH] [PTen] Reduce reshape kernel functions in pten (#38055) * Reduce reshape kernel functions in pten * delete notes * fix bugs when compile --- paddle/fluid/framework/operator.cc | 4 + paddle/fluid/imperative/prepared_operator.cc | 4 + paddle/fluid/operators/flatten_op.cc | 5 +- paddle/fluid/operators/reshape_op.cc | 76 ++++---------- paddle/pten/api/include/kernel_signature.h | 2 +- paddle/pten/common/scalar_array.h | 8 +- paddle/pten/include/manipulation.h | 2 +- paddle/pten/infermeta/unary.cc | 5 + paddle/pten/infermeta/unary.h | 4 + paddle/pten/kernels/cpu/manipulation.cc | 104 +++---------------- paddle/pten/kernels/cpu/manipulation.h | 41 ++------ paddle/pten/kernels/cuda/manipulation.cu | 102 +++--------------- paddle/pten/kernels/cuda/manipulation.h | 41 ++------ paddle/pten/kernels/xpu/manipulation.cc | 51 +++------ paddle/pten/kernels/xpu/manipulation.h | 23 ++-- python/paddle/utils/code_gen/api.yaml | 4 +- 16 files changed, 129 insertions(+), 347 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 716f0a85c17..265ce01d814 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1872,6 +1872,10 @@ void OperatorWithKernel::BuildPtenKernelContext( std::type_index(typeid(std::vector))) { pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( BOOST_GET_CONST(std::vector, attr_iter->second)))); + } else if (std::type_index(attr_iter->second.type()) == + std::type_index(typeid(std::vector))) { + pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( + BOOST_GET_CONST(std::vector, attr_iter->second)))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to ScalarArray when " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 54f46e49c4f..055e02b0cb2 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -358,6 +358,10 @@ static void BuildDygraphPtenKernelContext( std::type_index(typeid(std::vector))) { kernel_ctx->EmplaceBackAttr(std::move( pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr(std::move( + pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to VectorTensor when " diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index dcd31f1ded6..a1b8dd6bae4 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -337,8 +337,9 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { if (ctx.HasOutput("XShape")) { - return framework::KernelSignature( - "flatten.mid", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"}); + return framework::KernelSignature("flatten_with_xshape", {"X"}, + {"start_axis", "stop_axis"}, + {"Out", "XShape"}); } else { return framework::KernelSignature("flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"}); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 155eb1ebbe3..a796821729f 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -19,6 +19,7 @@ limitations under the License. */ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/include/core.h" #include "paddle/pten/include/manipulation.h" namespace paddle { @@ -402,6 +403,7 @@ class ReshapeKernel { auto *shape_tensor = ctx.HasInput("Shape") ? ctx.Input("Shape") : nullptr; + pten::ScalarArray pt_scalar_shape; if (list_new_shape_tensor.size() > 0) { // have shape tensor std::vector pt_vec_shape; @@ -417,22 +419,7 @@ class ReshapeKernel { std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor)))); } } - if (platform::is_cpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); - } -#endif -#ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); - } -#endif + pt_scalar_shape = pten::ScalarArray(pt_vec_shape); } else if (shape_tensor) { std::unique_ptr pt_shape; if (platform::is_gpu_place(shape_tensor->place()) || @@ -443,44 +430,27 @@ class ReshapeKernel { } else { pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor); } - - if (platform::is_cpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); - } -#endif -#ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); - } -#endif + pt_scalar_shape = pten::ScalarArray(*pt_shape.get()); } else { auto &shape_attr = ctx.Attr>("shape"); - const std::vector shape_vec(shape_attr.begin(), - shape_attr.end()); - if (platform::is_cpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); - } + pt_scalar_shape = pten::ScalarArray(shape_attr); + } + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); - } + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + } #endif #ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); - } -#endif + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); } +#endif // non-inplace need move all result from pt_out to out, inplace need set // result dims. if (in != out) { @@ -553,16 +523,16 @@ class Reshape2Op : public ReshapeOp { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { + std::string shape; auto multi_inputs = ctx.MultiInput("ShapeTensor"); if (multi_inputs.size() > 0) { - return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"}, - {}, {"Out"}); + shape = "ShapeTensor"; } else if (ctx.HasInput("Shape")) { - return framework::KernelSignature("reshape_host", {"X", "Shape"}, {}, - {"Out"}); + shape = "Shape"; } else { - return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"}); + shape = "shape"; } + return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"}); } }; diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index 1ff91f7e94a..7d92019d29e 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -83,7 +83,7 @@ using multiply_kernel = void (*)(const DeviceContext&, using reshape_kernel = void (*)(const DeviceContext&, const DenseTensor&, - const std::vector&, + const ScalarArray&, DenseTensor*); using scale_kernel = void (*)(const DeviceContext&, diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h index b4d21b98ca0..81013d8e5a1 100644 --- a/paddle/pten/common/scalar_array.h +++ b/paddle/pten/common/scalar_array.h @@ -28,6 +28,10 @@ class ScalarArrayBase { ScalarArrayBase(const std::vector& vec) : array_(vec) {} // NOLINT + ScalarArrayBase(const std::vector& vec) { // NOLINT + array_.insert(array_.begin(), vec.begin(), vec.end()); + } + ScalarArrayBase(std::initializer_list array_list) : array_(array_list) {} @@ -43,7 +47,7 @@ class ScalarArrayBase { ScalarArrayBase(const T& tensor) { // NOLINT size_t n = tensor.numel(); array_.reserve(n); - switch (tensor.type()) { + switch (tensor.dtype()) { case DataType::INT32: AssignData(tensor.template data(), n); break; @@ -55,7 +59,7 @@ class ScalarArrayBase { "Data type error. Currently, The data type of ScalarArrayBase " "only supports Tensor with int32 and int64, " "but now received `", - tensor.type(), + tensor.dtype(), "`."); } } diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h index ee68c6aa622..e694a89f700 100644 --- a/paddle/pten/include/manipulation.h +++ b/paddle/pten/include/manipulation.h @@ -60,7 +60,7 @@ DenseTensor Reshape(const ContextT& dev_ctx, std::make_shared( dev_ctx.GetPlace()); pten::DenseTensor dense_out(allocator, out_meta); - ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out); + Reshape(dev_ctx, x, ScalarArray(shape), &dense_out); return dense_out; } diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 4e861161ca9..4092e2842b9 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -227,6 +227,11 @@ DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta, return return_meta; } +DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta, + const ScalarArray& shape) { + return InferMetaFromVecValue(x_meta, shape.GetData()); +} + DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, const std::vector& axis, bool keep_dim) { diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 560a27759ad..408a77234f4 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once // See Note [ Why still include the fluid headers? ] +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/tensor_meta.h" namespace pten { @@ -50,6 +51,9 @@ DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta, DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta, const std::vector& shape); +DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta, + const ScalarArray& shape); + DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, const std::vector& axis, bool keep_dim); diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc index 61c6cb57a9f..9c34f842337 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -46,74 +46,27 @@ void FlattenWithXShape(const CPUContext& dev_ctx, general::SetXShape(x, xshape); } -void ReshapeFromVectorVal(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); +void Reshape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); return; } pten::Copy(dev_ctx, x, false, out); out->Resize(out_meta.dims); -} - -void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromVectorVal(dev_ctx, x, shape, out); -} - -void ReshapeFromDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out) { - auto* shape_data = shape.data(); - auto vector_shape = - std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); out->ResetLoD(x.lod()); } -void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromDT(dev_ctx, x, shape, out); -} - -void ReshapeFromVectorDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - std::vector vector_shape; - for (auto& tensor : shape) { - PADDLE_ENFORCE_EQ( - tensor.dims(), - paddle::framework::make_ddim({1}), - paddle::platform::errors::InvalidArgument( - "If the element type of 'shape' in ReshapeOp is Tensor, " - "the element's shape must be [1]. But received the element's shape " - "is [%s]", - tensor.dims())); - vector_shape.push_back(*tensor.data()); - } - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); -} - -void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { +void ReshapeWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out) { general::SetXShape(x, xshape); - ReshapeFromVectorDT(dev_ctx, x, shape, out); + Reshape(dev_ctx, x, shape, out); } template @@ -130,8 +83,6 @@ void Cast(const CPUContext& dev_ctx, } // namespace pten -// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel -// architecture, kernel_name should be "flatten". PT_REGISTER_KERNEL(flatten, CPU, ANY, @@ -142,7 +93,7 @@ PT_REGISTER_KERNEL(flatten, int8_t, int, int64_t) {} -PT_REGISTER_KERNEL(flatten_mid, +PT_REGISTER_KERNEL(flatten_with_xshape, CPU, ANY, pten::FlattenWithXShape, @@ -171,33 +122,8 @@ PT_REGISTER_KERNEL(cast, kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid, - CPU, - ANY, - pten::ReshapeFromVectorValWithXShape) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid, - CPU, - ANY, - pten::ReshapeFromDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost, - CPU, - ANY, - pten::ReshapeFromVectorDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid, +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, CPU, ANY, - pten::ReshapeFromVectorDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} + pten::ReshapeWithXShape) {} diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h index 36f9aaa85aa..cc583547875 100644 --- a/paddle/pten/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -38,37 +39,15 @@ void Cast(const CPUContext& dev_ctx, DataType in_dtype, DenseTensor* out); -void ReshapeFromDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out); - -void ReshapeFromVectorVal(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromVectorDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out); - -void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void Reshape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); -void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void ReshapeWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index e668d1b04d7..d3c0759698e 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -46,74 +46,27 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, general::SetXShape(x, xshape); } -void ReshapeFromVectorVal(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); +void Reshape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); return; } pten::Copy(dev_ctx, x, false, out); out->Resize(out_meta.dims); -} - -void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromVectorVal(dev_ctx, x, shape, out); -} - -void ReshapeFromDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out) { - auto* shape_data = shape.data(); - auto vector_shape = - std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); out->ResetLoD(x.lod()); } -void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromDT(dev_ctx, x, shape, out); -} - -void ReshapeFromVectorDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - std::vector vector_shape; - for (auto& tensor : shape) { - PADDLE_ENFORCE_EQ( - tensor.dims(), - paddle::framework::make_ddim({1}), - paddle::platform::errors::InvalidArgument( - "If the element type of 'shape' in ReshapeOp is Tensor, " - "the element's shape must be [1]. But received the element's shape " - "is [%s]", - tensor.dims())); - vector_shape.push_back(*tensor.data()); - } - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); -} - -void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { +void ReshapeWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out) { general::SetXShape(x, xshape); - ReshapeFromVectorDT(dev_ctx, x, shape, out); + Reshape(dev_ctx, x, shape, out); } template @@ -142,7 +95,7 @@ PT_REGISTER_KERNEL(flatten, int8_t, int, int64_t) {} -PT_REGISTER_KERNEL(flatten_mid, +PT_REGISTER_KERNEL(flatten_with_xshape, CUDA, ANY, pten::FlattenWithXShape, @@ -179,33 +132,8 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) #endif -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid, - CUDA, - ANY, - pten::ReshapeFromVectorValWithXShape) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid, - CUDA, - ANY, - pten::ReshapeFromDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost, - CUDA, - ANY, - pten::ReshapeFromVectorDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid, +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, CUDA, ANY, - pten::ReshapeFromVectorDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} + pten::ReshapeWithXShape) {} diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h index c0f2d8a1141..be935a045f9 100644 --- a/paddle/pten/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -17,6 +17,7 @@ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -41,38 +42,16 @@ void Cast(const CUDAContext& dev_ctx, DataType in_dtype, DenseTensor* out); -void ReshapeFromDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out); - -void ReshapeFromVectorVal(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromVectorDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out); - -void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void Reshape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); -void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void ReshapeWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index f361933cad4..cee3e5ceedb 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -51,46 +51,27 @@ void FlattenWithXShape(const XPUContext& dev_ctx, xshape->ResetLoD(x.lod()); } -void ReshapeFromVectorVal(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); - if (&x == out) { +void Reshape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); + if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); return; } pten::Copy(dev_ctx, x, false, out); out->Resize(out_meta.dims); + out->ResetLoD(x.lod()); } -void ReshapeFromDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out) { - auto* shape_data = shape.data(); - auto vector_shape = - std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); -} - -void ReshapeFromVectorDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - std::vector vector_shape; - for (auto& tensor : shape) { - PADDLE_ENFORCE_EQ( - tensor.dims(), - paddle::framework::make_ddim({1}), - paddle::platform::errors::InvalidArgument( - "If the element type of 'shape' in ReshapeOp is Tensor, " - "the element's shape must be [1]. But received the element's shape " - "is [%s]", - tensor.dims())); - vector_shape.push_back(*tensor.data()); - } - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +void ReshapeWithXShape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out) { + general::SetXShape(x, xshape); + Reshape(dev_ctx, x, shape, out); } } // namespace pten @@ -107,7 +88,7 @@ PT_REGISTER_KERNEL(flatten, int, int64_t) {} -PT_REGISTER_KERNEL(flatten_mid, +PT_REGISTER_KERNEL(flatten_with_xshape, XPU, ANY, pten::FlattenWithXShape, @@ -119,4 +100,4 @@ PT_REGISTER_KERNEL(flatten_mid, int, int64_t) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {} diff --git a/paddle/pten/kernels/xpu/manipulation.h b/paddle/pten/kernels/xpu/manipulation.h index b519a23a500..a9f57025e1e 100644 --- a/paddle/pten/kernels/xpu/manipulation.h +++ b/paddle/pten/kernels/xpu/manipulation.h @@ -16,6 +16,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -33,20 +34,16 @@ void Flatten(const XPUContext& dev_ctx, int stop_axis, DenseTensor* out); -void ReshapeFromDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out); - -void ReshapeFromVectorVal(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); +void Reshape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); -void ReshapeFromVectorDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); +void ReshapeWithXShape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out); } // namespace pten diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 0625000cb88..e0ea80feebe 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -103,10 +103,10 @@ invoke : full_like(x, 1, dtype, place, layout) - api : reshape - args : (const Tensor& x, const std::vector& shape) + args : (const Tensor& x, const ScalarArray& shape) output : Tensor infer_meta : - func : InferMetaFromVecValue + func : ReshapeInferMeta kernel : func : reshape -- GitLab