diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 716f0a85c171ec0c5052e66deec825879d5ca6f4..265ce01d814361818f508c4a46245f44e14652cf 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1872,6 +1872,10 @@ void OperatorWithKernel::BuildPtenKernelContext( std::type_index(typeid(std::vector))) { pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( BOOST_GET_CONST(std::vector, attr_iter->second)))); + } else if (std::type_index(attr_iter->second.type()) == + std::type_index(typeid(std::vector))) { + pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray( + BOOST_GET_CONST(std::vector, attr_iter->second)))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to ScalarArray when " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index 54f46e49c4f730904056aa546a0c1d6ce51c7791..055e02b0cb258def0e8820c73cf9058721272c77 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -358,6 +358,10 @@ static void BuildDygraphPtenKernelContext( std::type_index(typeid(std::vector))) { kernel_ctx->EmplaceBackAttr(std::move( pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); + } else if (std::type_index(attr.type()) == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr(std::move( + pten::ScalarArray(BOOST_GET_CONST(std::vector, attr)))); } else { PADDLE_THROW(platform::errors::Unimplemented( "Unsupported cast op attribute `%s` to VectorTensor when " diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc index dcd31f1ded6d464e930cdb4f4e9b19e7f1ca7061..a1b8dd6bae4945a7f1ea934792e8886278512b26 100644 --- a/paddle/fluid/operators/flatten_op.cc +++ b/paddle/fluid/operators/flatten_op.cc @@ -337,8 +337,9 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { if (ctx.HasOutput("XShape")) { - return framework::KernelSignature( - "flatten.mid", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"}); + return framework::KernelSignature("flatten_with_xshape", {"X"}, + {"start_axis", "stop_axis"}, + {"Out", "XShape"}); } else { return framework::KernelSignature("flatten", {"X"}, {"start_axis", "stop_axis"}, {"Out"}); diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 155eb1ebbe3db3416aa6f14e1e7f6e083bfb54e6..a796821729fb28006e20f1215ea1fa8faac8885f 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -19,6 +19,7 @@ limitations under the License. */ // only can include the headers in paddle/pten/api dirs #include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/include/core.h" #include "paddle/pten/include/manipulation.h" namespace paddle { @@ -402,6 +403,7 @@ class ReshapeKernel { auto *shape_tensor = ctx.HasInput("Shape") ? ctx.Input("Shape") : nullptr; + pten::ScalarArray pt_scalar_shape; if (list_new_shape_tensor.size() > 0) { // have shape tensor std::vector pt_vec_shape; @@ -417,22 +419,7 @@ class ReshapeKernel { std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor)))); } } - if (platform::is_cpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); - } -#endif -#ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); - } -#endif + pt_scalar_shape = pten::ScalarArray(pt_vec_shape); } else if (shape_tensor) { std::unique_ptr pt_shape; if (platform::is_gpu_place(shape_tensor->place()) || @@ -443,44 +430,27 @@ class ReshapeKernel { } else { pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor); } - - if (platform::is_cpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); - } -#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); - } -#endif -#ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); - } -#endif + pt_scalar_shape = pten::ScalarArray(*pt_shape.get()); } else { auto &shape_attr = ctx.Attr>("shape"); - const std::vector shape_vec(shape_attr.begin(), - shape_attr.end()); - if (platform::is_cpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); - } + pt_scalar_shape = pten::ScalarArray(shape_attr); + } + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + } #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) - if (platform::is_gpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); - } + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); + } #endif #ifdef PADDLE_WITH_XPU - if (platform::is_xpu_place(ctx.GetPlace())) { - auto &dev_ctx = ctx.device_context(); - pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); - } -#endif + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); } +#endif // non-inplace need move all result from pt_out to out, inplace need set // result dims. if (in != out) { @@ -553,16 +523,16 @@ class Reshape2Op : public ReshapeOp { framework::KernelSignature GetExpectedPtenKernelArgs( const framework::ExecutionContext &ctx) const override { + std::string shape; auto multi_inputs = ctx.MultiInput("ShapeTensor"); if (multi_inputs.size() > 0) { - return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"}, - {}, {"Out"}); + shape = "ShapeTensor"; } else if (ctx.HasInput("Shape")) { - return framework::KernelSignature("reshape_host", {"X", "Shape"}, {}, - {"Out"}); + shape = "Shape"; } else { - return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"}); + shape = "shape"; } + return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"}); } }; diff --git a/paddle/pten/api/include/kernel_signature.h b/paddle/pten/api/include/kernel_signature.h index 1ff91f7e94a20a4807ff4f5f1ebd3f89225f59f1..7d92019d29e5383e5619d748ddc3c64ce9fa1e66 100644 --- a/paddle/pten/api/include/kernel_signature.h +++ b/paddle/pten/api/include/kernel_signature.h @@ -83,7 +83,7 @@ using multiply_kernel = void (*)(const DeviceContext&, using reshape_kernel = void (*)(const DeviceContext&, const DenseTensor&, - const std::vector&, + const ScalarArray&, DenseTensor*); using scale_kernel = void (*)(const DeviceContext&, diff --git a/paddle/pten/common/scalar_array.h b/paddle/pten/common/scalar_array.h index b4d21b98ca0680679b329cf24202f0d93e8ce7c5..81013d8e5a11cdd6b44587bb2151b7be18895c27 100644 --- a/paddle/pten/common/scalar_array.h +++ b/paddle/pten/common/scalar_array.h @@ -28,6 +28,10 @@ class ScalarArrayBase { ScalarArrayBase(const std::vector& vec) : array_(vec) {} // NOLINT + ScalarArrayBase(const std::vector& vec) { // NOLINT + array_.insert(array_.begin(), vec.begin(), vec.end()); + } + ScalarArrayBase(std::initializer_list array_list) : array_(array_list) {} @@ -43,7 +47,7 @@ class ScalarArrayBase { ScalarArrayBase(const T& tensor) { // NOLINT size_t n = tensor.numel(); array_.reserve(n); - switch (tensor.type()) { + switch (tensor.dtype()) { case DataType::INT32: AssignData(tensor.template data(), n); break; @@ -55,7 +59,7 @@ class ScalarArrayBase { "Data type error. Currently, The data type of ScalarArrayBase " "only supports Tensor with int32 and int64, " "but now received `", - tensor.type(), + tensor.dtype(), "`."); } } diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h index ee68c6aa62257ca619b5d39e2196837f476492c4..e694a89f700cf5c68ce631f5671b285f23d298e2 100644 --- a/paddle/pten/include/manipulation.h +++ b/paddle/pten/include/manipulation.h @@ -60,7 +60,7 @@ DenseTensor Reshape(const ContextT& dev_ctx, std::make_shared( dev_ctx.GetPlace()); pten::DenseTensor dense_out(allocator, out_meta); - ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out); + Reshape(dev_ctx, x, ScalarArray(shape), &dense_out); return dense_out; } diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 4e861161ca98bfe58d48fccfaf472758e53fd487..4092e2842b9752cba95ab7b82b74733171a1db06 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -227,6 +227,11 @@ DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta, return return_meta; } +DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta, + const ScalarArray& shape) { + return InferMetaFromVecValue(x_meta, shape.GetData()); +} + DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, const std::vector& axis, bool keep_dim) { diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 560a27759ad28d36e980786620629cfafdccb1c0..408a77234f4b62d08933c9ac7ad33ca9e95814e4 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -15,6 +15,7 @@ limitations under the License. */ #pragma once // See Note [ Why still include the fluid headers? ] +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/tensor_meta.h" namespace pten { @@ -50,6 +51,9 @@ DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta, DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta, const std::vector& shape); +DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta, + const ScalarArray& shape); + DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta, const std::vector& axis, bool keep_dim); diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc index 61c6cb57a9f780010a5f83dd874cb790fb9879dd..9c34f84233731c0b895e12e236e995440fce5014 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -46,74 +46,27 @@ void FlattenWithXShape(const CPUContext& dev_ctx, general::SetXShape(x, xshape); } -void ReshapeFromVectorVal(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); +void Reshape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); return; } pten::Copy(dev_ctx, x, false, out); out->Resize(out_meta.dims); -} - -void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromVectorVal(dev_ctx, x, shape, out); -} - -void ReshapeFromDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out) { - auto* shape_data = shape.data(); - auto vector_shape = - std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); out->ResetLoD(x.lod()); } -void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromDT(dev_ctx, x, shape, out); -} - -void ReshapeFromVectorDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - std::vector vector_shape; - for (auto& tensor : shape) { - PADDLE_ENFORCE_EQ( - tensor.dims(), - paddle::framework::make_ddim({1}), - paddle::platform::errors::InvalidArgument( - "If the element type of 'shape' in ReshapeOp is Tensor, " - "the element's shape must be [1]. But received the element's shape " - "is [%s]", - tensor.dims())); - vector_shape.push_back(*tensor.data()); - } - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); -} - -void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { +void ReshapeWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out) { general::SetXShape(x, xshape); - ReshapeFromVectorDT(dev_ctx, x, shape, out); + Reshape(dev_ctx, x, shape, out); } template @@ -130,8 +83,6 @@ void Cast(const CPUContext& dev_ctx, } // namespace pten -// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel -// architecture, kernel_name should be "flatten". PT_REGISTER_KERNEL(flatten, CPU, ANY, @@ -142,7 +93,7 @@ PT_REGISTER_KERNEL(flatten, int8_t, int, int64_t) {} -PT_REGISTER_KERNEL(flatten_mid, +PT_REGISTER_KERNEL(flatten_with_xshape, CPU, ANY, pten::FlattenWithXShape, @@ -171,33 +122,8 @@ PT_REGISTER_KERNEL(cast, kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED); } -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid, - CPU, - ANY, - pten::ReshapeFromVectorValWithXShape) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid, - CPU, - ANY, - pten::ReshapeFromDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost, - CPU, - ANY, - pten::ReshapeFromVectorDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid, +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, CPU, ANY, - pten::ReshapeFromVectorDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} + pten::ReshapeWithXShape) {} diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h index 36f9aaa85aa5e3b9b8cc62f843d046d0ee3824e8..cc58354787585e9283e6d9000bbc1114980e31c8 100644 --- a/paddle/pten/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -14,6 +14,7 @@ limitations under the License. */ #pragma once +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -38,37 +39,15 @@ void Cast(const CPUContext& dev_ctx, DataType in_dtype, DenseTensor* out); -void ReshapeFromDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out); - -void ReshapeFromVectorVal(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromVectorDT(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out); - -void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void Reshape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); -void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void ReshapeWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/cuda/manipulation.cu b/paddle/pten/kernels/cuda/manipulation.cu index e668d1b04d7238e6bbcebc7710abb61f885b1659..d3c0759698eea9ebfb06c1347937713a81dedeb2 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -46,74 +46,27 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, general::SetXShape(x, xshape); } -void ReshapeFromVectorVal(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); +void Reshape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); return; } pten::Copy(dev_ctx, x, false, out); out->Resize(out_meta.dims); -} - -void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromVectorVal(dev_ctx, x, shape, out); -} - -void ReshapeFromDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out) { - auto* shape_data = shape.data(); - auto vector_shape = - std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); out->ResetLoD(x.lod()); } -void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out) { - general::SetXShape(x, xshape); - ReshapeFromDT(dev_ctx, x, shape, out); -} - -void ReshapeFromVectorDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - std::vector vector_shape; - for (auto& tensor : shape) { - PADDLE_ENFORCE_EQ( - tensor.dims(), - paddle::framework::make_ddim({1}), - paddle::platform::errors::InvalidArgument( - "If the element type of 'shape' in ReshapeOp is Tensor, " - "the element's shape must be [1]. But received the element's shape " - "is [%s]", - tensor.dims())); - vector_shape.push_back(*tensor.data()); - } - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); -} - -void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out) { +void ReshapeWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out) { general::SetXShape(x, xshape); - ReshapeFromVectorDT(dev_ctx, x, shape, out); + Reshape(dev_ctx, x, shape, out); } template @@ -142,7 +95,7 @@ PT_REGISTER_KERNEL(flatten, int8_t, int, int64_t) {} -PT_REGISTER_KERNEL(flatten_mid, +PT_REGISTER_KERNEL(flatten_with_xshape, CUDA, ANY, pten::FlattenWithXShape, @@ -179,33 +132,8 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16) PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast) #endif -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid, - CUDA, - ANY, - pten::ReshapeFromVectorValWithXShape) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid, - CUDA, - ANY, - pten::ReshapeFromDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost, - CUDA, - ANY, - pten::ReshapeFromVectorDT) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid, +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape, CUDA, ANY, - pten::ReshapeFromVectorDTWithXShape) { - kernel->InputAt(1).SetBackend(pten::Backend::CPU); - kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32); -} + pten::ReshapeWithXShape) {} diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h index c0f2d8a11414e6665ae7e1e74a78d1e106603e94..be935a045f9384350aea2cef4fe929ef0932f775 100644 --- a/paddle/pten/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -17,6 +17,7 @@ // CUDA and HIP use same api #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -41,38 +42,16 @@ void Cast(const CUDAContext& dev_ctx, DataType in_dtype, DenseTensor* out); -void ReshapeFromDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out); - -void ReshapeFromVectorVal(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromVectorDT(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); - -void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* xshape, - DenseTensor* out); - -void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void Reshape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); -void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* xshape, - DenseTensor* out); +void ReshapeWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out); } // namespace pten diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index f361933cad45a5e703b106588a8b3c5514269e78..cee3e5ceedb6dadea24ba6cc96110d2d582b95cd 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -51,46 +51,27 @@ void FlattenWithXShape(const XPUContext& dev_ctx, xshape->ResetLoD(x.lod()); } -void ReshapeFromVectorVal(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - auto out_meta = InferMetaFromVecValue(x.meta(), shape); - if (&x == out) { +void Reshape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out) { + auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); + if (x.data() == out->data() && x.numel() == out->numel()) { out->Resize(out_meta.dims); return; } pten::Copy(dev_ctx, x, false, out); out->Resize(out_meta.dims); + out->ResetLoD(x.lod()); } -void ReshapeFromDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out) { - auto* shape_data = shape.data(); - auto vector_shape = - std::vector(shape_data, shape_data + shape.numel()); - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); -} - -void ReshapeFromVectorDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out) { - std::vector vector_shape; - for (auto& tensor : shape) { - PADDLE_ENFORCE_EQ( - tensor.dims(), - paddle::framework::make_ddim({1}), - paddle::platform::errors::InvalidArgument( - "If the element type of 'shape' in ReshapeOp is Tensor, " - "the element's shape must be [1]. But received the element's shape " - "is [%s]", - tensor.dims())); - vector_shape.push_back(*tensor.data()); - } - ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +void ReshapeWithXShape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out) { + general::SetXShape(x, xshape); + Reshape(dev_ctx, x, shape, out); } } // namespace pten @@ -107,7 +88,7 @@ PT_REGISTER_KERNEL(flatten, int, int64_t) {} -PT_REGISTER_KERNEL(flatten_mid, +PT_REGISTER_KERNEL(flatten_with_xshape, XPU, ANY, pten::FlattenWithXShape, @@ -119,4 +100,4 @@ PT_REGISTER_KERNEL(flatten_mid, int, int64_t) {} -PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {} +PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {} diff --git a/paddle/pten/kernels/xpu/manipulation.h b/paddle/pten/kernels/xpu/manipulation.h index b519a23a50038ec1d5405b2b0664dc51ee0a9b02..a9f57025e1e2c3e2cc215497589dfded17092d94 100644 --- a/paddle/pten/kernels/xpu/manipulation.h +++ b/paddle/pten/kernels/xpu/manipulation.h @@ -16,6 +16,7 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU +#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/kernel_registry.h" @@ -33,20 +34,16 @@ void Flatten(const XPUContext& dev_ctx, int stop_axis, DenseTensor* out); -void ReshapeFromDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const DenseTensor& shape, - DenseTensor* out); - -void ReshapeFromVectorVal(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); +void Reshape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* out); -void ReshapeFromVectorDT(const XPUContext& dev_ctx, - const DenseTensor& x, - const std::vector& shape, - DenseTensor* out); +void ReshapeWithXShape(const XPUContext& dev_ctx, + const DenseTensor& x, + const ScalarArray& shape, + DenseTensor* xshape, + DenseTensor* out); } // namespace pten diff --git a/python/paddle/utils/code_gen/api.yaml b/python/paddle/utils/code_gen/api.yaml index 0625000cb888350a9f0c8bbe67f908800bd4fa00..e0ea80feebeba59ab271f17871c4c144e7904c83 100644 --- a/python/paddle/utils/code_gen/api.yaml +++ b/python/paddle/utils/code_gen/api.yaml @@ -103,10 +103,10 @@ invoke : full_like(x, 1, dtype, place, layout) - api : reshape - args : (const Tensor& x, const std::vector& shape) + args : (const Tensor& x, const ScalarArray& shape) output : Tensor infer_meta : - func : InferMetaFromVecValue + func : ReshapeInferMeta kernel : func : reshape