From 895692e3c56eb9927ed652c1ecb27dc94e1bf297 Mon Sep 17 00:00:00 2001 From: YuanRisheng Date: Sun, 14 Nov 2021 11:54:43 +0800 Subject: [PATCH] [PTen]Reshape Kernel Refactor (#37164) * reshape kernel refactor * fix compile bugs when run ci * support xpu for reshape * fix bugs when run unittest in kunlun ci * fix compile bugs when run kunlun * perfect code according to suggestion --- paddle/fluid/framework/operator.cc | 4 + paddle/fluid/imperative/prepared_operator.cc | 3 + paddle/fluid/operators/reshape_op.cc | 170 +++++++++++++----- paddle/pten/core/kernel_registry.h | 72 +++----- paddle/pten/core/kernel_utils.h | 1 + paddle/pten/include/manipulation.h | 13 ++ paddle/pten/infermeta/unary.cc | 138 ++++++++++++++ paddle/pten/infermeta/unary.h | 2 + paddle/pten/kernels/cpu/manipulation.cc | 88 ++++++++- paddle/pten/kernels/cpu/manipulation.h | 35 +++- paddle/pten/kernels/cpu/utils.cc | 12 +- paddle/pten/kernels/cuda/manipulation.cu | 89 ++++++++- paddle/pten/kernels/cuda/manipulation.h | 33 ++++ paddle/pten/kernels/cuda/utils.cu | 50 ++++-- paddle/pten/kernels/cuda/utils.h | 5 +- .../kernels/functions/general/manipulation.h | 34 ++++ paddle/pten/kernels/xpu/manipulation.cc | 49 +++++ paddle/pten/kernels/xpu/manipulation.h | 15 ++ paddle/pten/tests/api/test_matmul_api.cc | 6 +- 19 files changed, 684 insertions(+), 135 deletions(-) create mode 100644 paddle/pten/kernels/functions/general/manipulation.h diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index e0a80d3c79..2eb054be49 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1883,6 +1883,10 @@ void OperatorWithKernel::BuildPtenKernelContext( pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { pt_kernel_context_->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + pt_kernel_context_->EmplaceBackAttr( + BOOST_GET_CONST(std::vector, attr)); } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` when construct " diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc index c9e211809a..658272b7c0 100644 --- a/paddle/fluid/imperative/prepared_operator.cc +++ b/paddle/fluid/imperative/prepared_operator.cc @@ -372,6 +372,9 @@ static void BuildDygraphPtenKernelContext( kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(float, attr)); } else if (attr_defs[i].type_index == std::type_index(typeid(bool))) { kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(bool, attr)); + } else if (attr_defs[i].type_index == + std::type_index(typeid(std::vector))) { + kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector, attr)); } else { PADDLE_THROW(platform::errors::Unimplemented( "unsupported cast op attribute `%s` when construct " diff --git a/paddle/fluid/operators/reshape_op.cc b/paddle/fluid/operators/reshape_op.cc index 6f244b1a4c..e3104fa065 100644 --- a/paddle/fluid/operators/reshape_op.cc +++ b/paddle/fluid/operators/reshape_op.cc @@ -15,7 +15,12 @@ limitations under the License. 
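The operator.cc and prepared_operator.cc hunks above teach both the static-graph and the dygraph attribute translators to forward a vector-typed shape attribute into a pten KernelContext. The template arguments did not survive extraction here; a sketch of the added branch in its dygraph form (the static-graph variant differs only in the context member name), assuming std::vector<int64_t> as the element type to match the reshape kernel signatures registered later in this patch:

} else if (attr_defs[i].type_index ==
           std::type_index(typeid(std::vector<int64_t>))) {
  // Element type is an assumption; forward the whole vector by value.
  kernel_ctx->EmplaceBackAttr(BOOST_GET_CONST(std::vector<int64_t>, attr));
}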
*/ #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/pten_utils.h" +// only can include the headers in paddle/pten/api dirs +#include "paddle/pten/api/lib/utils/tensor_utils.h" +#include "paddle/pten/include/core.h" +#include "paddle/pten/include/manipulation.h" namespace paddle { namespace framework { class InferShapeContext; @@ -248,13 +253,6 @@ class ReshapeOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - //#ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // framework::DataLayout::kMKLDNN, - // framework::LibraryType::kMKLDNN); - // } - //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } @@ -366,13 +364,6 @@ class ReshapeGradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); - //#ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // framework::DataLayout::kMKLDNN, - // framework::LibraryType::kMKLDNN); - // } - //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -382,42 +373,117 @@ class ReshapeKernel { void operator()(const framework::ExecutionContext &ctx) const { auto *out = ctx.Output("Out"); auto *in = ctx.Input("X"); - - framework::DDim out_dims = out->dims(); + // framework::DDim out_dims = out->dims(); + auto pt_x = paddle::experimental::MakePtenDenseTensor(*in); + + // we can't MakePtenDenseTensor by out, because reshape will realloc memory + // and this will throw error(can't realloc shared memory) in current + // DenseTensor + // design. So, codes below create a tmp densetensor for output. + // TODO(YuanRisheng) we can use MakePtenDenseTensor after #36916 merge. + const auto alloc = std::make_shared( + paddle::platform::CPUPlace()); + pten::DenseTensorMeta meta{pten::TransToPtenDataType(in->type()), + in->dims(), + pten::TransToPtenDataLayout(in->layout())}; + auto pt_out_tmp = + std::make_shared(alloc, std::move(meta)); + pten::DenseTensor *pt_out = nullptr; + if (in == out) { + pt_out = pt_x.get(); + } else { + pt_out = pt_out_tmp.get(); + } auto list_new_shape_tensor = ctx.MultiInput("ShapeTensor"); + auto *shape_tensor = ctx.HasInput("Shape") + ? 
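// When no "Shape" input is attached, this ternary resolves to nullptr and
// the kernel falls through to the "shape" attribute branch further below.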
ctx.Input("Shape") + : nullptr; if (list_new_shape_tensor.size() > 0) { // have shape tensor - auto new_shape = get_new_shape(list_new_shape_tensor); - out_dims = ReshapeOp::ValidateShape(new_shape, in->dims()); + std::vector pt_vec_shape; + for (auto &tensor : list_new_shape_tensor) { + if (platform::is_gpu_place(tensor->place()) || + platform::is_xpu_place(tensor->place())) { + framework::Tensor temp; + TensorCopySync(*tensor, platform::CPUPlace(), &temp); + pt_vec_shape.push_back( + std::move(*(paddle::experimental::MakePtenDenseTensor(temp)))); + } else { + pt_vec_shape.push_back( + std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor)))); + } + } + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); + } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out); + } +#endif + } else if (shape_tensor) { + std::unique_ptr pt_shape; + if (platform::is_gpu_place(shape_tensor->place()) || + platform::is_xpu_place(shape_tensor->place())) { + framework::Tensor temp; + TensorCopySync(*shape_tensor, platform::CPUPlace(), &temp); + pt_shape = paddle::experimental::MakePtenDenseTensor(temp); + } else { + pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor); + } + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); + } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out); + } +#endif } else { - auto *shape_tensor = ctx.HasInput("Shape") - ? ctx.Input("Shape") - : nullptr; - - if (shape_tensor) { - auto *shape_data = shape_tensor->data(); - framework::Tensor cpu_shape_tensor; - if (platform::is_gpu_place(shape_tensor->place()) || - platform::is_xpu_place(shape_tensor->place())) { - TensorCopySync(*shape_tensor, platform::CPUPlace(), - &cpu_shape_tensor); - shape_data = cpu_shape_tensor.data(); - } - auto shape = - std::vector(shape_data, shape_data + shape_tensor->numel()); - out_dims = ReshapeOp::ValidateShape(shape, in->dims()); + auto &shape_vec = ctx.Attr>("shape"); + if (platform::is_cpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); + } +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + if (platform::is_gpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); } +#endif +#ifdef PADDLE_WITH_XPU + if (platform::is_xpu_place(ctx.GetPlace())) { + auto &dev_ctx = ctx.device_context(); + pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out); + } +#endif + } + // non-inplace need move all result from pt_out to out, inplace need set + // result dims. 
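// The branch below finishes both paths: out-of-place moves pt_out's
// storage into the fluid output tensor, while in-place (in == out, where
// pt_out aliases pt_x) only needs the output dims refreshed.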
+ if (in != out) { + paddle::experimental::MovesStorage(pt_out, static_cast(out)); + } else { + out->Resize(pt_out->dims()); } - - out->Resize(out_dims); - out->mutable_data(ctx.GetPlace(), in->type()); - framework::TensorCopy( - *in, ctx.GetPlace(), - ctx.template device_context(), out); - out->Resize(out_dims); } }; @@ -479,6 +545,21 @@ class Reshape2Op : public ReshapeOp { ReshapeOp::InferShape(ctx); } + + framework::KernelSignature GetExpectedPtenKernelArgs( + const framework::ExecutionContext &ctx) const override { + auto multi_inputs = ctx.MultiInput("ShapeTensor"); + if (multi_inputs.size() > 0) { + return framework::KernelSignature( + "reshape2.mulhost.mid", {"X", "ShapeTensor"}, {}, {"XShape", "Out"}); + } else if (ctx.HasInput("Shape")) { + return framework::KernelSignature("reshape2.host.mid", {"X", "Shape"}, {}, + {"XShape", "Out"}); + } else { + return framework::KernelSignature("reshape2.mid", {"X"}, {"shape"}, + {"XShape", "Out"}); + } + } }; class Reshape2OpMaker : public ReshapeOpMaker { @@ -557,13 +638,6 @@ class Reshape2GradOp : public framework::OperatorWithKernel { auto input_data_type = framework::OperatorWithKernel::IndicateVarDataType( ctx, framework::GradVarName("Out")); - //#ifdef PADDLE_WITH_MKLDNN - // if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { - // return framework::OpKernelType(input_data_type, ctx.GetPlace(), - // framework::DataLayout::kMKLDNN, - // framework::LibraryType::kMKLDNN); - // } - //#endif return framework::OpKernelType(input_data_type, ctx.GetPlace()); } diff --git a/paddle/pten/core/kernel_registry.h b/paddle/pten/core/kernel_registry.h index c2b97148aa..cd6fa80906 100644 --- a/paddle/pten/core/kernel_registry.h +++ b/paddle/pten/core/kernel_registry.h @@ -114,34 +114,16 @@ struct KernelRegistrar { KernelArgsParseFn args_parse_fn, KernelArgsDefFn args_def_fn, KernelFn kernel_fn) { - if (layout == DataLayout::ANY) { - for (size_t layout_iter = static_cast(DataLayout::NHWC); - layout_iter != static_cast(DataLayout::NUM_DATA_LAYOUTS); - layout_iter++) { - for (size_t dtype = static_cast(DataType::BOOL); - dtype != static_cast(DataType::NUM_DATA_TYPES); - dtype++) { - ConstructKernel(kernel_name_cstr, - backend, - static_cast(layout_iter), - static_cast(dtype), - args_parse_fn, - args_def_fn, - kernel_fn); - } - } - } else { - for (size_t dtype = static_cast(DataType::BOOL); - dtype != static_cast(DataType::NUM_DATA_TYPES); - dtype++) { - ConstructKernel(kernel_name_cstr, - backend, - layout, - static_cast(dtype), - args_parse_fn, - args_def_fn, - kernel_fn); - } + for (size_t dtype = static_cast(DataType::BOOL); + dtype != static_cast(DataType::NUM_DATA_TYPES); + dtype++) { + ConstructKernel(kernel_name_cstr, + backend, + layout, + static_cast(dtype), + args_parse_fn, + args_def_fn, + kernel_fn); } } @@ -158,7 +140,6 @@ struct KernelRegistrar { Kernel kernel(kernel_fn); args_parse_fn(kernel_key, kernel.mutable_args_def()); args_def_fn(&kernel); - KernelFactory::Instance().InsertCompatibleOpType(kernel_name.name()); KernelFactory::Instance().kernels()[kernel_name][kernel_key] = kernel; } @@ -838,21 +819,22 @@ struct KernelRegistrar { _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ kernel_name, PT_ID, backend, layout, meta_kernel_fn) -#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ - kernel_name, func_id, backend, layout, meta_kernel_fn) \ - PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ - PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ - "PT_REGISTER_KERNEL must be called in global namespace."); \ - decltype(meta_kernel_fn) meta_kernel_fn; \ - static void 
PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ - func_id)(::pten::Kernel*); \ - static const ::pten::KernelRegistrar __reg_pt_op_kernel_##func_id( \ - kernel_name, \ - BACKEND(backend), \ - DATALAYOUT(layout), \ - ::pten::KernelArgsParseFunctor::Parse, \ - &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ - PT_KERNEL(meta_kernel_fn)); \ - void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ +#define _PT_REGISTER_KERNEL_WITH_NO_TYPE( \ + kernel_name, func_id, backend, layout, meta_kernel_fn) \ + PT_STATIC_ASSERT_GLOBAL_NAMESPACE( \ + PT_CONCATENATE(pt_op_kernel_ns_check_, func_id), \ + "PT_REGISTER_KERNEL must be called in global namespace."); \ + decltype(meta_kernel_fn) meta_kernel_fn; \ + static void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ + func_id)(::pten::Kernel*); \ + static const ::pten::KernelRegistrar PT_CONCATENATE(__reg_pt_op_kernel_, \ + func_id)( \ + kernel_name, \ + BACKEND(backend), \ + DATALAYOUT(layout), \ + ::pten::KernelArgsParseFunctor::Parse, \ + &PT_CONCATENATE(__PT_KERNEL_args_def_FN_, func_id), \ + PT_KERNEL(meta_kernel_fn)); \ + void PT_CONCATENATE(__PT_KERNEL_args_def_FN_, \ func_id)(::pten::Kernel * kernel) } // namespace pten diff --git a/paddle/pten/core/kernel_utils.h b/paddle/pten/core/kernel_utils.h index 23143c0624..c464519cb9 100644 --- a/paddle/pten/core/kernel_utils.h +++ b/paddle/pten/core/kernel_utils.h @@ -208,6 +208,7 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(paddle::platform::float16); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const Scalar&); PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); + PT_SPECIALIZE_KernelCallHelper_FOR_ATTRIBUTE(const std::vector&); /* Output Helpers */ diff --git a/paddle/pten/include/manipulation.h b/paddle/pten/include/manipulation.h index e10f296dbd..d779b772b0 100644 --- a/paddle/pten/include/manipulation.h +++ b/paddle/pten/include/manipulation.h @@ -37,4 +37,17 @@ DenseTensor Flatten(const ContextT& dev_ctx, return dense_out; } +template +DenseTensor Reshape(const ContextT& dev_ctx, + const DenseTensor& x, + const std::vector& shape) { + auto out_meta = InferShapeFromVecValue(x.meta(), shape); + const auto allocator = + std::make_shared( + dev_ctx.GetPlace()); + pten::DenseTensor dense_out(allocator, out_meta); + ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out); + return dense_out; +} + } // namespace pten diff --git a/paddle/pten/infermeta/unary.cc b/paddle/pten/infermeta/unary.cc index 74b7fdf706..e2f9e5fccc 100644 --- a/paddle/pten/infermeta/unary.cc +++ b/paddle/pten/infermeta/unary.cc @@ -82,4 +82,142 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, layout == DataLayout::UNDEFINED ? x_meta.layout : layout}; } +static paddle::framework::DDim ValidateShape( + const std::vector shape, const paddle::framework::DDim& in_dims) { + const int64_t in_size = paddle::framework::product(in_dims); + auto in_dims_vec = paddle::framework::vectorize(in_dims); + bool all_positive = std::all_of(in_dims_vec.cbegin(), + in_dims_vec.cend(), + [](int64_t i) { return i > 0; }); + // only one dimension can be set to -1, whose size will be automatically + // infered. + const int64_t unk_dim_val = -1; + const int64_t copy_dim_val = 0; + + std::vector output_shape(shape.size(), 0); + int64_t capacity = 1; + int unk_dim_idx = -1; + for (size_t i = 0; i < shape.size(); ++i) { + if (shape[i] == unk_dim_val) { + PADDLE_ENFORCE_EQ( + unk_dim_idx, + -1, + paddle::platform::errors::InvalidArgument( + "Only one dimension value of 'shape' in ReshapeOp can " + "be -1. 
But received shape = [%s], shape[%d] is also -1.", + paddle::framework::make_ddim(shape), + i)); + unk_dim_idx = i; + } else if (shape[i] == copy_dim_val) { + PADDLE_ENFORCE_LT( + static_cast(i), + in_dims.size(), + paddle::platform::errors::InvalidArgument( + "The index of 0 in `shape` must be less than " + "the input tensor X's dimensions. " + "But received shape = [%s], shape[%d] = 0, X's shape = [%s], " + "X's dimensions = %d.", + paddle::framework::make_ddim(shape), + i, + in_dims, + in_dims.size())); + } else { + PADDLE_ENFORCE_GT( + shape[i], + 0, + paddle::platform::errors::InvalidArgument( + "Each dimension value of 'shape' in ReshapeOp must not " + "be negative except one unknown dimension. " + "But received shape = [%s], shape[%d] = %d.", + paddle::framework::make_ddim(shape), + i, + shape[i])); + } + + // NOTE all non-zero values will be converted to True (include negative + // value) + capacity *= (shape[i] ? shape[i] : in_dims[i]); + output_shape[i] = (shape[i] ? static_cast(shape[i]) : in_dims[i]); + } + + if (unk_dim_idx != -1) { + if (all_positive) { + // in_size < 0 and is un-determinate in compile time, skip the check, + // for example, in_dims = [-1, 8, 1, 1], shape = [-1, 3, 8], + // capacity = -24, in_size = -8, output_shape[0] = 0 + // the following check will fail. + output_shape[unk_dim_idx] = -in_size / capacity; + PADDLE_ENFORCE_EQ( + output_shape[unk_dim_idx] * capacity, + -in_size, + paddle::platform::errors::InvalidArgument( + "The 'shape' attribute in ReshapeOp is invalid. " + "The input tensor X'size must be divisible by known " + "capacity of 'shape'. " + "But received X's shape = [%s], X's size = %d, " + "'shape' is [%s], known capacity of 'shape' is %d.", + in_dims, + in_size, + paddle::framework::make_ddim(shape), + capacity)); + } else { + output_shape[unk_dim_idx] = -1; + } + } else { + if (all_positive) { + PADDLE_ENFORCE_EQ( + capacity, + in_size, + paddle::platform::errors::InvalidArgument( + "The 'shape' in ReshapeOp is invalid. " + "The input tensor X'size must be equal to the capacity of " + "'shape'. " + "But received X's shape = [%s], X's size = %d, 'shape' is " + "[%s], the capacity of 'shape' is %d.", + in_dims, + in_size, + paddle::framework::make_ddim(shape), + capacity)); + } + } + + // support reshape with zero-input(input tensor with product(shape) == 0) + // by now we require that if the input tensor is zero shape, the target + // shape of output must be zero + if (in_size == 0) { + PADDLE_ENFORCE_LE( + capacity, + in_size, + paddle::platform::errors::InvalidArgument( + "The 'shape' in ReshapeOp is invalid. " + "The input tensor X's shape = [%s], X's capacity = %d." + "But the target shape of Out is [%s], the " + "capacity of 'Out' is %d.", + in_dims, + in_size, + paddle::framework::make_ddim(shape), + capacity)); + } + + return paddle::framework::make_ddim(output_shape); +} + +DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, + const std::vector& shape) { + PADDLE_ENFORCE_EQ(!shape.empty(), + true, + paddle::platform::errors::InvalidArgument( + "The parameter 'shape' in ReshapeOp must be set. " + "But received 'shape' is empty.")); + auto x_dims = x_meta.dims; + auto out_dims = ValidateShape(shape, x_dims); + DenseTensorMeta return_meta(x_meta.type, out_dims, x_meta.layout); + if (x_dims[0] == return_meta.dims[0]) { + // Only pass LoD when the first dimension of output and Input(X) + // are the same. 
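// e.g. reshaping [4, 6] -> [4, 2, 3] keeps dim0 == 4, so LoD offsets
// indexing the first dimension stay valid; [4, 6] -> [8, 3] would not.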
+ return_meta.lod = x_meta.lod; + } + return return_meta; +} + } // namespace pten diff --git a/paddle/pten/infermeta/unary.h b/paddle/pten/infermeta/unary.h index 05f1910ba0..cf88f0060e 100644 --- a/paddle/pten/infermeta/unary.h +++ b/paddle/pten/infermeta/unary.h @@ -45,4 +45,6 @@ DenseTensorMeta FullLikeInferShape(const DenseTensorMeta& x_meta, DataType dtype, DataLayout layout); +DenseTensorMeta InferShapeFromVecValue(const DenseTensorMeta& x_meta, + const std::vector& shape); } // namespace pten diff --git a/paddle/pten/kernels/cpu/manipulation.cc b/paddle/pten/kernels/cpu/manipulation.cc index 8a6ba954d8..fcff674f36 100644 --- a/paddle/pten/kernels/cpu/manipulation.cc +++ b/paddle/pten/kernels/cpu/manipulation.cc @@ -15,6 +15,7 @@ #include "paddle/pten/kernels/cpu/manipulation.h" #include "paddle/pten/infermeta/unary.h" #include "paddle/pten/kernels/cpu/utils.h" +#include "paddle/pten/kernels/functions/general/manipulation.h" namespace pten { @@ -40,14 +41,75 @@ void FlattenWithXShape(const CPUContext& dev_ctx, DenseTensor* out, DenseTensor* xshape) { Flatten(dev_ctx, x, start_axis, stop_axis, out); - const auto& in_dims = x.meta().dims; - std::vector xshape_dims(in_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < in_dims.size(); ++i) { - xshape_dims[i + 1] = in_dims[i]; + general::SetXShape(x, xshape); +} + +void ReshapeFromVectorVal(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out) { + auto out_meta = InferShapeFromVecValue(x.meta(), shape); + if (&x == out) { + out->Resize(out_meta.dims); + return; + } + pten::Copy(dev_ctx, x, out); + out->Resize(out_meta.dims); +} + +void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out) { + ReshapeFromVectorVal(dev_ctx, x, shape, out); + general::SetXShape(x, xshape); +} + +void ReshapeFromDT(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* out) { + auto* shape_data = shape.data(); + auto vector_shape = std::vector(shape_data, shape_data + shape.numel()); + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +} + +void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* xshape, + DenseTensor* out) { + ReshapeFromDT(dev_ctx, x, shape, out); + general::SetXShape(x, xshape); +} + +void ReshapeFromVectorDT(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out) { + std::vector vector_shape; + for (auto& tensor : shape) { + PADDLE_ENFORCE_EQ( + tensor.dims(), + paddle::framework::make_ddim({1}), + paddle::platform::errors::InvalidArgument( + "If the element type of 'shape' in ReshapeOp is Tensor, " + "the element's shape must be [1]. 
But received the element's shape " + "is [%s]", + tensor.dims())); + vector_shape.push_back(*tensor.data()); } - xshape->Resize(paddle::framework::make_ddim(xshape_dims)); - xshape->set_lod(x.lod()); + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +} + +void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out) { + ReshapeFromVectorDT(dev_ctx, x, shape, out); + general::SetXShape(x, xshape); } } // namespace pten @@ -78,3 +140,15 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int8_t, int, int64_t) {} + +// TODO(yuanrisheng): "reshape2" is compatible with old kernel +// architecture, kernel_name should be "reshape". +PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2", + CPU, + ANY, + pten::ReshapeFromVectorVal) {} + +PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2.mid", + CPU, + ANY, + pten::ReshapeFromVectorValWithXShape) {} diff --git a/paddle/pten/kernels/cpu/manipulation.h b/paddle/pten/kernels/cpu/manipulation.h index 22dfb0d8fc..c074774945 100644 --- a/paddle/pten/kernels/cpu/manipulation.h +++ b/paddle/pten/kernels/cpu/manipulation.h @@ -15,8 +15,6 @@ limitations under the License. */ #pragma once #include "paddle/pten/core/dense_tensor.h" -#include "paddle/pten/core/kernel_registry.h" - // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/platform/device_context.h" @@ -31,4 +29,37 @@ void Flatten(const CPUContext& dev_ctx, int stop_axis, DenseTensor* out); +void ReshapeFromDT(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* out); + +void ReshapeFromVectorVal(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out); + +void ReshapeFromVectorDT(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out); + +void ReshapeFromDTWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* xshape, + DenseTensor* out); + +void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out); + +void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out); + } // namespace pten diff --git a/paddle/pten/kernels/cpu/utils.cc b/paddle/pten/kernels/cpu/utils.cc index 1f9d675dea..3e0bfccb1e 100644 --- a/paddle/pten/kernels/cpu/utils.cc +++ b/paddle/pten/kernels/cpu/utils.cc @@ -21,21 +21,23 @@ namespace pten { void Copy(const CPUContext& dev_ctx, const DenseTensor& src, DenseTensor* dst) { auto* src_ptr = src.data(); - auto* dst_ptr = dst->mutable_data(); const auto& src_place = src.place(); const auto& dst_place = dst->place(); + VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " + << dst_place; + + dst->Resize(src.dims()); + auto* dst_ptr = dst->mutable_data(); + if (src_ptr == dst_ptr && src_place == dst_place) { VLOG(3) << "Skip copy the same data async from " << src_place << " to " << dst_place; return; } VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; - - VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " - << dst_place; - dst->Resize(src.dims()); CHECK(dst->layout() == src.layout()); + auto size = src.numel() * paddle::framework::SizeOfType( TransToProtoVarType(src.data_type())); diff --git a/paddle/pten/kernels/cuda/manipulation.cu 
b/paddle/pten/kernels/cuda/manipulation.cu index 9e8dff8c26..47451226c7 100644 --- a/paddle/pten/kernels/cuda/manipulation.cu +++ b/paddle/pten/kernels/cuda/manipulation.cu @@ -15,6 +15,7 @@ #include "paddle/pten/infermeta/unary.h" #include "paddle/pten/kernels/cuda/manipulation.h" #include "paddle/pten/kernels/cuda/utils.h" +#include "paddle/pten/kernels/functions/general/manipulation.h" namespace pten { @@ -25,7 +26,7 @@ void Flatten(const CUDAContext& dev_ctx, int stop_axis, DenseTensor* out) { auto out_dims = out->dims(); - pten::Copy(dev_ctx, x, out); + pten::Copy(dev_ctx, x, false, out); out->Resize(out_dims); } @@ -40,14 +41,76 @@ void FlattenWithXShape(const CUDAContext& dev_ctx, DenseTensor* out, DenseTensor* xshape) { Flatten(dev_ctx, x, start_axis, stop_axis, out); - const auto& in_dims = x.meta().dims; - std::vector xshape_dims(in_dims.size() + 1); - xshape_dims[0] = 0; - for (int i = 0; i < in_dims.size(); ++i) { - xshape_dims[i + 1] = in_dims[i]; + general::SetXShape(x, xshape); +} + +void ReshapeFromVectorVal(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out) { + auto out_meta = InferShapeFromVecValue(x.meta(), shape); + if (&x == out) { + LOG(INFO) << "out_meta dims:" << out_meta.dims; + out->Resize(out_meta.dims); + return; } - xshape->Resize(paddle::framework::make_ddim(xshape_dims)); - xshape->set_lod(x.lod()); + pten::Copy(dev_ctx, x, false, out); + out->Resize(out_meta.dims); +} + +void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out) { + ReshapeFromVectorVal(dev_ctx, x, shape, out); + general::SetXShape(x, xshape); +} + +void ReshapeFromDT(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* out) { + auto* shape_data = shape.data(); + auto vector_shape = std::vector(shape_data, shape_data + shape.numel()); + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +} + +void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* xshape, + DenseTensor* out) { + ReshapeFromDT(dev_ctx, x, shape, out); + general::SetXShape(x, xshape); +} + +void ReshapeFromVectorDT(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out) { + std::vector vector_shape; + for (auto& tensor : shape) { + PADDLE_ENFORCE_EQ( + tensor.dims(), + paddle::framework::make_ddim({1}), + paddle::platform::errors::InvalidArgument( + "If the element type of 'shape' in ReshapeOp is Tensor, " + "the element's shape must be [1]. 
But received the element's shape " + "is [%s]", + tensor.dims())); + vector_shape.push_back(*tensor.data()); + } + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +} + +void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out) { + ReshapeFromVectorDT(dev_ctx, x, shape, out); + general::SetXShape(x, xshape); } } // namespace pten @@ -80,3 +143,13 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int8_t, int, int64_t) {} + +PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2", + CUDA, + ANY, + pten::ReshapeFromVectorVal) {} + +PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2.mid", + CUDA, + ANY, + pten::ReshapeFromVectorValWithXShape) {} diff --git a/paddle/pten/kernels/cuda/manipulation.h b/paddle/pten/kernels/cuda/manipulation.h index ac1cb0324f..6a071d6e49 100644 --- a/paddle/pten/kernels/cuda/manipulation.h +++ b/paddle/pten/kernels/cuda/manipulation.h @@ -33,6 +33,39 @@ void Flatten(const CUDAContext& dev_ctx, int stop_axis, DenseTensor* out); +void ReshapeFromDT(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* out); + +void ReshapeFromVectorVal(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out); + +void ReshapeFromVectorDT(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out); + +void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* xshape, + DenseTensor* out); + +void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out); + +void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* xshape, + DenseTensor* out); + } // namespace pten #endif diff --git a/paddle/pten/kernels/cuda/utils.cu b/paddle/pten/kernels/cuda/utils.cu index e81e00a587..c3940b42ca 100644 --- a/paddle/pten/kernels/cuda/utils.cu +++ b/paddle/pten/kernels/cuda/utils.cu @@ -22,23 +22,32 @@ namespace pten { void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, + bool is_sync, DenseTensor* dst) { auto* src_ptr = src.data(); - auto* dst_ptr = dst->mutable_data(); const auto& src_place = src.place(); const auto& dst_place = dst->place(); + if (src_place == dst_place && paddle::platform::is_cpu_place(src_place)) { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The src and dst tensor are all CPU tensor, you should call copy " + "function in CPU mode.")); + } + + VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " + << dst_place; + + dst->Resize(src.dims()); + auto* dst_ptr = dst->mutable_data(); + if (src_ptr == dst_ptr && src_place == dst_place) { VLOG(3) << "Skip copy the same data async from " << src_place << " to " << dst_place; return; } VLOG(4) << "src:" << src_ptr << ", dst:" << dst_ptr; - - VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to " - << dst_place; - dst->Resize(src.dims()); CHECK(dst->layout() == src.layout()); + auto size = src.numel() * paddle::framework::SizeOfType( TransToProtoVarType(src.data_type())); @@ -88,8 +97,10 @@ void Copy(const CUDAContext& dev_ctx, src_gpu_place, ctx_gpu_place)); auto stream = - reinterpret_cast(dev_ctx) - .stream(); + is_sync ? 
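// paddle::memory::Copy falls back to a blocking copy when given a null
// stream, so is_sync == true forces synchronous semantics here.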
nullptr + : reinterpret_cast( + dev_ctx) + .stream(); paddle::memory::Copy( dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } else if (paddle::platform::is_cpu_place(src_place) && // NOLINT @@ -114,8 +125,10 @@ void Copy(const CUDAContext& dev_ctx, dst_gpu_place, ctx_gpu_place)); auto stream = - reinterpret_cast(dev_ctx) - .stream(); + is_sync ? nullptr + : reinterpret_cast( + dev_ctx) + .stream(); paddle::memory::Copy( dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size, stream); } else if (paddle::platform::is_gpu_place(src_place) && // NOLINT @@ -142,8 +155,10 @@ void Copy(const CUDAContext& dev_ctx, src_gpu_place.device, ctx_gpu_place.device)); auto stream = - reinterpret_cast(dev_ctx) - .stream(); + is_sync ? nullptr + : reinterpret_cast( + dev_ctx) + .stream(); paddle::memory::Copy( dst_cuda_pinned_place, dst_ptr, src_gpu_place, src_ptr, size, stream); } else if (paddle::platform::is_cuda_pinned_place(src_place) && // NOLINT @@ -170,8 +185,10 @@ void Copy(const CUDAContext& dev_ctx, dst_gpu_place.device, ctx_gpu_place.device)); auto stream = - reinterpret_cast(dev_ctx) - .stream(); + is_sync ? nullptr + : reinterpret_cast( + dev_ctx) + .stream(); paddle::memory::Copy( dst_gpu_place, dst_ptr, src_cuda_pinned_place, src_ptr, size, stream); } else if (paddle::platform::is_gpu_place(src_place) && // NOLINT @@ -188,8 +205,10 @@ void Copy(const CUDAContext& dev_ctx, "Context place error, excepted GPUPlace, but actually %s.", ctx_place)); auto stream = - reinterpret_cast(dev_ctx) - .stream(); + is_sync ? nullptr + : reinterpret_cast( + dev_ctx) + .stream(); if (paddle::platform::is_same_place(src_place, dst_place)) { paddle::memory::Copy( dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size, stream); @@ -213,7 +232,6 @@ void Copy(const CUDAContext& dev_ctx, } } } - } // namespace pten // TODO(chenweihang): replace by better impl diff --git a/paddle/pten/kernels/cuda/utils.h b/paddle/pten/kernels/cuda/utils.h index 0d79f04f2e..cc24628ee3 100644 --- a/paddle/pten/kernels/cuda/utils.h +++ b/paddle/pten/kernels/cuda/utils.h @@ -26,7 +26,10 @@ namespace pten { using CUDAContext = paddle::platform::CUDADeviceContext; -void Copy(const CUDAContext& dev_ctx, const DenseTensor& src, DenseTensor* dst); +void Copy(const CUDAContext& dev_ctx, + const DenseTensor& src, + bool is_sync, + DenseTensor* dst); } // namespace pten diff --git a/paddle/pten/kernels/functions/general/manipulation.h b/paddle/pten/kernels/functions/general/manipulation.h new file mode 100644 index 0000000000..b05691ed7c --- /dev/null +++ b/paddle/pten/kernels/functions/general/manipulation.h @@ -0,0 +1,34 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
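The new header that follows factors the XShape bookkeeping out of the per-backend flatten and reshape kernels: for an input with dims [d0, d1, ...] it gives xshape the dims [0, d0, d1, ...] plus the input's LoD, which is the breadcrumb the corresponding grad ops use to recover the original shape.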
*/ + +#pragma once + +#include "paddle/pten/core/dense_tensor.h" + +namespace pten { +namespace general { + +inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) { + const auto& in_dims = x.meta().dims; + std::vector xshape_dims(in_dims.size() + 1); + xshape_dims[0] = 0; + for (int i = 0; i < in_dims.size(); ++i) { + xshape_dims[i + 1] = in_dims[i]; + } + xshape->Resize(paddle::framework::make_ddim(xshape_dims)); + xshape->set_lod(x.meta().lod); +} + +} // namespace general +} // namespace pten diff --git a/paddle/pten/kernels/xpu/manipulation.cc b/paddle/pten/kernels/xpu/manipulation.cc index 4313520c8b..6ce143e5e3 100644 --- a/paddle/pten/kernels/xpu/manipulation.cc +++ b/paddle/pten/kernels/xpu/manipulation.cc @@ -14,6 +14,7 @@ #include "paddle/pten/kernels/xpu/manipulation.h" #include "paddle/pten/infermeta/unary.h" +#include "paddle/pten/kernels/functions/general/manipulation.h" #include "paddle/pten/kernels/xpu/utils.h" namespace pten { @@ -50,6 +51,47 @@ void FlattenWithXShape(const XPUContext& dev_ctx, xshape->set_lod(x.lod()); } +void ReshapeFromVectorVal(const XPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out) { + auto out_meta = InferShapeFromVecValue(x.meta(), shape); + if (&x == out) { + out->Resize(out_meta.dims); + return; + } + pten::Copy(dev_ctx, x, out); + out->Resize(out_meta.dims); +} + +void ReshapeFromDT(const XPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* out) { + auto* shape_data = shape.data(); + auto vector_shape = std::vector(shape_data, shape_data + shape.numel()); + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +} + +void ReshapeFromVectorDT(const XPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out) { + std::vector vector_shape; + for (auto& tensor : shape) { + PADDLE_ENFORCE_EQ( + tensor.dims(), + paddle::framework::make_ddim({1}), + paddle::platform::errors::InvalidArgument( + "If the element type of 'shape' in ReshapeOp is Tensor, " + "the element's shape must be [1]. But received the element's shape " + "is [%s]", + tensor.dims())); + vector_shape.push_back(*tensor.data()); + } + ReshapeFromVectorVal(dev_ctx, x, vector_shape, out); +} + } // namespace pten // TODO(chenweihang): replace by better impl @@ -80,3 +122,10 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid", int8_t, int, int64_t) {} + +// TODO(yuanrisheng): "reshape2" is compatible with old kernel +// architecture, kernel_name should be "reshape". 
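// As with the CPU and CUDA registrations above, the WITH_NO_TYPE variant
// binds this single untyped function for every dtype: reshape only moves
// bytes and never inspects element values, and the simplified registrar
// loop (DataType::BOOL .. NUM_DATA_TYPES) registers it uniformly.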
+PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2", + XPU, + ANY, + pten::ReshapeFromVectorVal) {} diff --git a/paddle/pten/kernels/xpu/manipulation.h b/paddle/pten/kernels/xpu/manipulation.h index 02947759b4..61a9536f8c 100644 --- a/paddle/pten/kernels/xpu/manipulation.h +++ b/paddle/pten/kernels/xpu/manipulation.h @@ -33,6 +33,21 @@ void Flatten(const XPUContext& dev_ctx, int stop_axis, DenseTensor* out); +void ReshapeFromDT(const XPUContext& dev_ctx, + const DenseTensor& x, + const DenseTensor& shape, + DenseTensor* out); + +void ReshapeFromVectorVal(const XPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out); + +void ReshapeFromVectorDT(const XPUContext& dev_ctx, + const DenseTensor& x, + const std::vector& shape, + DenseTensor* out); + } // namespace pten #endif diff --git a/paddle/pten/tests/api/test_matmul_api.cc b/paddle/pten/tests/api/test_matmul_api.cc index 83a70c905b..b0b649c690 100644 --- a/paddle/pten/tests/api/test_matmul_api.cc +++ b/paddle/pten/tests/api/test_matmul_api.cc @@ -125,8 +125,8 @@ TEST(API, matmul_cuda) { auto place = paddle::platform::CUDAPlace(); auto* dev_ctx = pool.GetByPlace(place); - pten::Copy(*dev_ctx, *ref_x.get(), dense_x.get()); - pten::Copy(*dev_ctx, *ref_y.get(), dense_y.get()); + pten::Copy(*dev_ctx, *ref_x.get(), false, dense_x.get()); + pten::Copy(*dev_ctx, *ref_y.get(), false, dense_y.get()); paddle::experimental::Tensor x(dense_x); paddle::experimental::Tensor y(dense_y); @@ -150,7 +150,7 @@ TEST(API, matmul_cuda) { pten::DenseTensorMeta( pten::DataType::FLOAT32, out.shape(), pten::DataLayout::NCHW)); - pten::Copy(*dev_ctx, *dense_out.get(), ref_out.get()); + pten::Copy(*dev_ctx, *dense_out.get(), false, ref_out.get()); for (size_t i = 0; i < 9; i++) { ASSERT_NEAR(sum[i], ref_out->data()[i], 1e-6f); -- GitLab
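Taken together, reshape is now reachable at three levels: the fluid ReshapeKernel shim, the per-backend ReshapeFrom* kernels, and the device-agnostic pten::Reshape wrapper added to paddle/pten/include/manipulation.h. A minimal caller-side sketch of that wrapper, assuming it keeps Flatten's <T, ContextT> template form and a std::vector<int64_t> shape parameter (both stripped by extraction above):

#include <cstdint>
#include <vector>

#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/include/manipulation.h"

pten::DenseTensor ReshapeDemo(const paddle::platform::CPUDeviceContext& dev_ctx,
                              const pten::DenseTensor& x) {
  // Target shape follows ValidateShape's rules: 0 copies the matching
  // input extent and a single -1 is inferred from the remaining capacity.
  const std::vector<int64_t> shape = {0, -1};
  // The wrapper runs InferShapeFromVecValue for the output meta, allocates
  // a DenseTensor on dev_ctx's place, and fills it via ReshapeFromVectorVal.
  return pten::Reshape<float>(dev_ctx, x, shape);
}

The wrapper composes the same pieces the operator path above uses, so apart from who owns the output allocation the two routes should agree.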