Unverified · Commit a3c8abc7 authored by YuanRisheng, committed by GitHub

[PTen] Reduce reshape kernel functions in pten (#38055)

* Reduce reshape kernel functions in pten

* delete notes

* fix bugs when compiling
Parent 19a833c8
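In summary: the target shape of a reshape is now carried by a single pten::ScalarArray argument instead of three separate kernel entry points (int64 vector, host shape tensor, list of shape tensors). A before/after sketch of the CPU-facing declarations, condensed from the diffs below:

    // Before: one kernel entry point per shape source.
    void ReshapeFromVectorVal(const CPUContext& dev_ctx, const DenseTensor& x,
                              const std::vector<int64_t>& shape, DenseTensor* out);
    void ReshapeFromDT(const CPUContext& dev_ctx, const DenseTensor& x,
                       const DenseTensor& shape, DenseTensor* out);
    void ReshapeFromVectorDT(const CPUContext& dev_ctx, const DenseTensor& x,
                             const std::vector<DenseTensor>& shape, DenseTensor* out);

    // After: one entry point; ScalarArray absorbs all three shape sources.
    void Reshape(const CPUContext& dev_ctx, const DenseTensor& x,
                 const ScalarArray& shape, DenseTensor* out);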
@@ -1872,6 +1872,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
                std::type_index(typeid(std::vector<int64_t>))) {
       pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
           BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
+    } else if (std::type_index(attr_iter->second.type()) ==
+               std::type_index(typeid(std::vector<int32_t>))) {
+      pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
+          BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported cast op attribute `%s` to ScalarArray when "
......
@@ -358,6 +358,10 @@ static void BuildDygraphPtenKernelContext(
                std::type_index(typeid(std::vector<int64_t>))) {
       kernel_ctx->EmplaceBackAttr(std::move(
           pten::ScalarArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
+    } else if (std::type_index(attr.type()) ==
+               std::type_index(typeid(std::vector<int32_t>))) {
+      kernel_ctx->EmplaceBackAttr(std::move(
+          pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported cast op attribute `%s` to VectorTensor when "
......
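Note: the static-graph path above and this dygraph path gain the same branch, so a std::vector<int32_t> attribute can now be passed wherever a ScalarArray is expected. A minimal sketch of the conversion it relies on (the int32 constructor added to ScalarArrayBase further down):

    std::vector<int32_t> shape_attr = {2, -1, 4};
    pten::ScalarArray shape(shape_attr);  // each int32 is widened to int64
    // shape.GetData() then yields std::vector<int64_t>{2, -1, 4}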
@@ -337,8 +337,9 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext &ctx) const override {
     if (ctx.HasOutput("XShape")) {
-      return framework::KernelSignature(
-          "flatten.mid", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
+      return framework::KernelSignature("flatten_with_xshape", {"X"},
+                                        {"start_axis", "stop_axis"},
+                                        {"Out", "XShape"});
     } else {
       return framework::KernelSignature("flatten", {"X"},
                                         {"start_axis", "stop_axis"}, {"Out"});
......
@@ -19,6 +19,7 @@ limitations under the License. */
 // only can include the headers in paddle/pten/api dirs
 #include "paddle/pten/api/lib/utils/tensor_utils.h"
+#include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/include/core.h"
 #include "paddle/pten/include/manipulation.h"
 namespace paddle {
@@ -402,6 +403,7 @@ class ReshapeKernel {
     auto *shape_tensor = ctx.HasInput("Shape")
                              ? ctx.Input<framework::LoDTensor>("Shape")
                              : nullptr;
+    pten::ScalarArray pt_scalar_shape;
     if (list_new_shape_tensor.size() > 0) {
       // have shape tensor
       std::vector<pten::DenseTensor> pt_vec_shape;
@@ -417,22 +419,7 @@ class ReshapeKernel {
             std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor))));
         }
       }
-      if (platform::is_cpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
-        pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
-      }
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-        pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
-      }
-#endif
-#ifdef PADDLE_WITH_XPU
-      if (platform::is_xpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
-        pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
-      }
-#endif
+      pt_scalar_shape = pten::ScalarArray(pt_vec_shape);
     } else if (shape_tensor) {
       std::unique_ptr<pten::DenseTensor> pt_shape;
       if (platform::is_gpu_place(shape_tensor->place()) ||
@@ -443,44 +430,27 @@ class ReshapeKernel {
       } else {
         pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor);
       }
-      if (platform::is_cpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
-        pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
-      }
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-        pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
-      }
-#endif
-#ifdef PADDLE_WITH_XPU
-      if (platform::is_xpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
-        pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
-      }
-#endif
+      pt_scalar_shape = pten::ScalarArray(*pt_shape.get());
     } else {
       auto &shape_attr = ctx.Attr<std::vector<int>>("shape");
-      const std::vector<int64_t> shape_vec(shape_attr.begin(),
-                                           shape_attr.end());
-      if (platform::is_cpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
-        pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
-      }
-#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-      if (platform::is_gpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-        pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
-      }
-#endif
-#ifdef PADDLE_WITH_XPU
-      if (platform::is_xpu_place(ctx.GetPlace())) {
-        auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
-        pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
-      }
-#endif
+      pt_scalar_shape = pten::ScalarArray(shape_attr);
     }
+    if (platform::is_cpu_place(ctx.GetPlace())) {
+      auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
+      pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
+    }
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    if (platform::is_gpu_place(ctx.GetPlace())) {
+      auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
+      pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
+    }
+#endif
+#ifdef PADDLE_WITH_XPU
+    if (platform::is_xpu_place(ctx.GetPlace())) {
+      auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
+      pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
+    }
+#endif
     // non-inplace need move all result from pt_out to out, inplace need set
     // result dims.
     if (in != out) {
@@ -553,16 +523,16 @@ class Reshape2Op : public ReshapeOp {
   framework::KernelSignature GetExpectedPtenKernelArgs(
       const framework::ExecutionContext &ctx) const override {
+    std::string shape;
     auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
     if (multi_inputs.size() > 0) {
-      return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"},
-                                        {}, {"Out"});
+      shape = "ShapeTensor";
     } else if (ctx.HasInput("Shape")) {
-      return framework::KernelSignature("reshape_host", {"X", "Shape"}, {},
-                                        {"Out"});
+      shape = "Shape";
     } else {
-      return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
+      shape = "shape";
     }
+    return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"});
   }
 };
......
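With this change Reshape2Op always resolves to the single reshape kernel; only the attribute/input name recorded in the signature varies by shape source:

    // ShapeTensor list present: KernelSignature("reshape", {"X"}, {"ShapeTensor"}, {"Out"})
    // Shape input present:      KernelSignature("reshape", {"X"}, {"Shape"}, {"Out"})
    // otherwise:                KernelSignature("reshape", {"X"}, {"shape"}, {"Out"})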
@@ -83,7 +83,7 @@ using multiply_kernel = void (*)(const DeviceContext&,
 using reshape_kernel = void (*)(const DeviceContext&,
                                 const DenseTensor&,
-                                const std::vector<int64_t>&,
+                                const ScalarArray&,
                                 DenseTensor*);
 using scale_kernel = void (*)(const DeviceContext&,
......
@@ -28,6 +28,10 @@ class ScalarArrayBase {
   ScalarArrayBase(const std::vector<int64_t>& vec) : array_(vec) {}  // NOLINT
+  ScalarArrayBase(const std::vector<int32_t>& vec) {  // NOLINT
+    array_.insert(array_.begin(), vec.begin(), vec.end());
+  }
+
   ScalarArrayBase(std::initializer_list<int64_t> array_list)
       : array_(array_list) {}
@@ -43,7 +47,7 @@ class ScalarArrayBase {
   ScalarArrayBase(const T& tensor) {  // NOLINT
     size_t n = tensor.numel();
     array_.reserve(n);
-    switch (tensor.type()) {
+    switch (tensor.dtype()) {
       case DataType::INT32:
         AssignData(tensor.template data<int32_t>(), n);
         break;
@@ -55,7 +59,7 @@ class ScalarArrayBase {
         "Data type error. Currently, The data type of ScalarArrayBase "
         "only supports Tensor with int32 and int64, "
         "but now received `",
-        tensor.type(),
+        tensor.dtype(),
         "`.");
     }
   }
......
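For reference, the construction paths ScalarArray accepts after this change, a sketch using only the constructors visible in this header:

    pten::ScalarArray a(std::vector<int64_t>{8, 16});  // stored as-is
    pten::ScalarArray b(std::vector<int32_t>{8, 16});  // new: widened to int64
    pten::ScalarArray c({8, 16});                      // initializer_list<int64_t>
    // pten::ScalarArray d(shape_tensor);  // tensor overload, switches on dtype()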
@@ -60,7 +60,7 @@ DenseTensor Reshape(const ContextT& dev_ctx,
       std::make_shared<paddle::experimental::DefaultAllocator>(
           dev_ctx.GetPlace());
   pten::DenseTensor dense_out(allocator, out_meta);
-  ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out);
+  Reshape(dev_ctx, x, ScalarArray(shape), &dense_out);
   return dense_out;
 }
......
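The templated convenience API keeps its old surface and now wraps the vector in a ScalarArray before dispatching to the unified kernel. A hypothetical call, assuming dev_ctx and x are in scope and the overload still takes a std::vector<int64_t> shape:

    auto dense_out = pten::Reshape(dev_ctx, x, std::vector<int64_t>{3, -1});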
@@ -227,6 +227,11 @@ DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
   return return_meta;
 }
+DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
+                                 const ScalarArray& shape) {
+  return InferMetaFromVecValue(x_meta, shape.GetData());
+}
+
 DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
                                 const std::vector<int64_t>& axis,
                                 bool keep_dim) {
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once
 // See Note [ Why still include the fluid headers? ]
+#include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/tensor_meta.h"
 namespace pten {
@@ -50,6 +51,9 @@ DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta,
 DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
                                       const std::vector<int64_t>& shape);
+DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
+                                 const ScalarArray& shape);
+
 DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
                                 const std::vector<int64_t>& axis,
                                 bool keep_dim);
......
@@ -46,74 +46,27 @@ void FlattenWithXShape(const CPUContext& dev_ctx,
   general::SetXShape(x, xshape);
 }
-void ReshapeFromVectorVal(const CPUContext& dev_ctx,
-                          const DenseTensor& x,
-                          const std::vector<int64_t>& shape,
-                          DenseTensor* out) {
-  auto out_meta = InferMetaFromVecValue(x.meta(), shape);
+void Reshape(const CPUContext& dev_ctx,
+             const DenseTensor& x,
+             const ScalarArray& shape,
+             DenseTensor* out) {
+  auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
   if (x.data() == out->data() && x.numel() == out->numel()) {
     out->Resize(out_meta.dims);
     return;
   }
   pten::Copy(dev_ctx, x, false, out);
   out->Resize(out_meta.dims);
-}
-
-void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
-                                    const DenseTensor& x,
-                                    const std::vector<int64_t>& shape,
-                                    DenseTensor* xshape,
-                                    DenseTensor* out) {
-  general::SetXShape(x, xshape);
-  ReshapeFromVectorVal(dev_ctx, x, shape, out);
-}
-
-void ReshapeFromDT(const CPUContext& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& shape,
-                   DenseTensor* out) {
-  auto* shape_data = shape.data<int>();
-  auto vector_shape =
-      std::vector<int64_t>(shape_data, shape_data + shape.numel());
-  ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
   out->ResetLoD(x.lod());
 }
-void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& shape,
-                             DenseTensor* xshape,
-                             DenseTensor* out) {
-  general::SetXShape(x, xshape);
-  ReshapeFromDT(dev_ctx, x, shape, out);
-}
-
-void ReshapeFromVectorDT(const CPUContext& dev_ctx,
-                         const DenseTensor& x,
-                         const std::vector<DenseTensor>& shape,
-                         DenseTensor* out) {
-  std::vector<int64_t> vector_shape;
-  for (auto& tensor : shape) {
-    PADDLE_ENFORCE_EQ(
-        tensor.dims(),
-        paddle::framework::make_ddim({1}),
-        paddle::platform::errors::InvalidArgument(
-            "If the element type of 'shape' in ReshapeOp is Tensor, "
-            "the element's shape must be [1]. But received the element's shape "
-            "is [%s]",
-            tensor.dims()));
-    vector_shape.push_back(*tensor.data<int32_t>());
-  }
-  ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
-}
-
-void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx,
-                                   const DenseTensor& x,
-                                   const std::vector<DenseTensor>& shape,
-                                   DenseTensor* xshape,
-                                   DenseTensor* out) {
+void ReshapeWithXShape(const CPUContext& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out) {
   general::SetXShape(x, xshape);
-  ReshapeFromVectorDT(dev_ctx, x, shape, out);
+  Reshape(dev_ctx, x, shape, out);
 }
 template <typename T>
@@ -130,8 +83,6 @@ void Cast(const CPUContext& dev_ctx,
 }  // namespace pten
-// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
-// architecture, kernel_name should be "flatten".
 PT_REGISTER_KERNEL(flatten,
                    CPU,
                    ANY,
@@ -142,7 +93,7 @@ PT_REGISTER_KERNEL(flatten,
                    int8_t,
                    int,
                    int64_t) {}
-PT_REGISTER_KERNEL(flatten_mid,
+PT_REGISTER_KERNEL(flatten_with_xshape,
                    CPU,
                    ANY,
                    pten::FlattenWithXShape,
@@ -171,33 +122,8 @@ PT_REGISTER_KERNEL(cast,
   kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
 }
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
-                             CPU,
-                             ANY,
-                             pten::ReshapeFromVectorValWithXShape) {}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
-                             CPU,
-                             ANY,
-                             pten::ReshapeFromDTWithXShape) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
-                             CPU,
-                             ANY,
-                             pten::ReshapeFromVectorDT) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
-                             CPU,
-                             ANY,
-                             pten::ReshapeFromVectorDTWithXShape) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
+PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {}
+PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
+                             CPU,
+                             ANY,
+                             pten::ReshapeWithXShape) {}
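The six CPU reshape registrations collapse into these two, and the per-input SetBackend(CPU)/SetDataType(INT32) constraints disappear: shape tensors are converted to ScalarArray in the op layer, so the kernel no longer declares a tensor-typed shape input. A hypothetical call-site sketch (dev_ctx, x, and out assumed in scope):

    pten::Reshape(dev_ctx, x, pten::ScalarArray({6, -1}), &out);  // shape from attribute values
    // pten::Reshape(dev_ctx, x, pten::ScalarArray(shape_tensor), &out);  // shape from a tensor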
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
+#include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/kernel_registry.h"
@@ -38,37 +39,15 @@ void Cast(const CPUContext& dev_ctx,
           DataType in_dtype,
           DenseTensor* out);
-void ReshapeFromDT(const CPUContext& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& shape,
-                   DenseTensor* out);
-
-void ReshapeFromVectorVal(const CPUContext& dev_ctx,
-                          const DenseTensor& x,
-                          const std::vector<int64_t>& shape,
-                          DenseTensor* out);
-
-void ReshapeFromVectorDT(const CPUContext& dev_ctx,
-                         const DenseTensor& x,
-                         const std::vector<DenseTensor>& shape,
-                         DenseTensor* out);
-
-void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& shape,
-                             DenseTensor* xshape,
-                             DenseTensor* out);
-
-void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
-                                    const DenseTensor& x,
-                                    const std::vector<int64_t>& shape,
-                                    DenseTensor* xshape,
-                                    DenseTensor* out);
-
-void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx,
-                                   const DenseTensor& x,
-                                   const std::vector<DenseTensor>& shape,
-                                   DenseTensor* xshape,
-                                   DenseTensor* out);
+void Reshape(const CPUContext& dev_ctx,
+             const DenseTensor& x,
+             const ScalarArray& shape,
+             DenseTensor* out);
+
+void ReshapeWithXShape(const CPUContext& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out);
 }  // namespace pten
@@ -46,74 +46,27 @@ void FlattenWithXShape(const CUDAContext& dev_ctx,
   general::SetXShape(x, xshape);
 }
-void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
-                          const DenseTensor& x,
-                          const std::vector<int64_t>& shape,
-                          DenseTensor* out) {
-  auto out_meta = InferMetaFromVecValue(x.meta(), shape);
+void Reshape(const CUDAContext& dev_ctx,
+             const DenseTensor& x,
+             const ScalarArray& shape,
+             DenseTensor* out) {
+  auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
   if (x.data() == out->data() && x.numel() == out->numel()) {
     out->Resize(out_meta.dims);
     return;
   }
   pten::Copy(dev_ctx, x, false, out);
   out->Resize(out_meta.dims);
-}
-
-void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
-                                    const DenseTensor& x,
-                                    const std::vector<int64_t>& shape,
-                                    DenseTensor* xshape,
-                                    DenseTensor* out) {
-  general::SetXShape(x, xshape);
-  ReshapeFromVectorVal(dev_ctx, x, shape, out);
-}
-
-void ReshapeFromDT(const CUDAContext& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& shape,
-                   DenseTensor* out) {
-  auto* shape_data = shape.data<int>();
-  auto vector_shape =
-      std::vector<int64_t>(shape_data, shape_data + shape.numel());
-  ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
   out->ResetLoD(x.lod());
 }
-void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& shape,
-                             DenseTensor* xshape,
-                             DenseTensor* out) {
-  general::SetXShape(x, xshape);
-  ReshapeFromDT(dev_ctx, x, shape, out);
-}
-
-void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
-                         const DenseTensor& x,
-                         const std::vector<DenseTensor>& shape,
-                         DenseTensor* out) {
-  std::vector<int64_t> vector_shape;
-  for (auto& tensor : shape) {
-    PADDLE_ENFORCE_EQ(
-        tensor.dims(),
-        paddle::framework::make_ddim({1}),
-        paddle::platform::errors::InvalidArgument(
-            "If the element type of 'shape' in ReshapeOp is Tensor, "
-            "the element's shape must be [1]. But received the element's shape "
-            "is [%s]",
-            tensor.dims()));
-    vector_shape.push_back(*tensor.data<int32_t>());
-  }
-  ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
-}
-
-void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
-                                   const DenseTensor& x,
-                                   const std::vector<DenseTensor>& shape,
-                                   DenseTensor* xshape,
-                                   DenseTensor* out) {
+void ReshapeWithXShape(const CUDAContext& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out) {
   general::SetXShape(x, xshape);
-  ReshapeFromVectorDT(dev_ctx, x, shape, out);
+  Reshape(dev_ctx, x, shape, out);
 }
 template <typename T>
@@ -142,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
                    int8_t,
                    int,
                    int64_t) {}
-PT_REGISTER_KERNEL(flatten_mid,
+PT_REGISTER_KERNEL(flatten_with_xshape,
                    CUDA,
                    ANY,
                    pten::FlattenWithXShape,
@@ -179,33 +132,8 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
 PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
 #endif
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
-                             CUDA,
-                             ANY,
-                             pten::ReshapeFromVectorValWithXShape) {}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
-                             CUDA,
-                             ANY,
-                             pten::ReshapeFromDTWithXShape) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
-                             CUDA,
-                             ANY,
-                             pten::ReshapeFromVectorDT) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
-                             CUDA,
-                             ANY,
-                             pten::ReshapeFromVectorDTWithXShape) {
-  kernel->InputAt(1).SetBackend(pten::Backend::CPU);
-  kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
-}
+PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {}
+PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
+                             CUDA,
+                             ANY,
+                             pten::ReshapeWithXShape) {}
@@ -17,6 +17,7 @@
 // CUDA and HIP use same api
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+#include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/kernel_registry.h"
@@ -41,38 +42,16 @@ void Cast(const CUDAContext& dev_ctx,
           DataType in_dtype,
           DenseTensor* out);
-void ReshapeFromDT(const CUDAContext& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& shape,
-                   DenseTensor* out);
-
-void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
-                          const DenseTensor& x,
-                          const std::vector<int64_t>& shape,
-                          DenseTensor* out);
-
-void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
-                         const DenseTensor& x,
-                         const std::vector<DenseTensor>& shape,
-                         DenseTensor* out);
-
-void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
-                             const DenseTensor& x,
-                             const DenseTensor& shape,
-                             DenseTensor* xshape,
-                             DenseTensor* out);
-
-void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
-                                    const DenseTensor& x,
-                                    const std::vector<int64_t>& shape,
-                                    DenseTensor* xshape,
-                                    DenseTensor* out);
-
-void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
-                                   const DenseTensor& x,
-                                   const std::vector<DenseTensor>& shape,
-                                   DenseTensor* xshape,
-                                   DenseTensor* out);
+void Reshape(const CUDAContext& dev_ctx,
+             const DenseTensor& x,
+             const ScalarArray& shape,
+             DenseTensor* out);
+
+void ReshapeWithXShape(const CUDAContext& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out);
 }  // namespace pten
......
@@ -51,46 +51,27 @@ void FlattenWithXShape(const XPUContext& dev_ctx,
   xshape->ResetLoD(x.lod());
 }
-void ReshapeFromVectorVal(const XPUContext& dev_ctx,
-                          const DenseTensor& x,
-                          const std::vector<int64_t>& shape,
-                          DenseTensor* out) {
-  auto out_meta = InferMetaFromVecValue(x.meta(), shape);
-  if (&x == out) {
+void Reshape(const XPUContext& dev_ctx,
+             const DenseTensor& x,
+             const ScalarArray& shape,
+             DenseTensor* out) {
+  auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
+  if (x.data() == out->data() && x.numel() == out->numel()) {
     out->Resize(out_meta.dims);
     return;
   }
   pten::Copy(dev_ctx, x, false, out);
   out->Resize(out_meta.dims);
+  out->ResetLoD(x.lod());
 }
-void ReshapeFromDT(const XPUContext& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& shape,
-                   DenseTensor* out) {
-  auto* shape_data = shape.data<int>();
-  auto vector_shape =
-      std::vector<int64_t>(shape_data, shape_data + shape.numel());
-  ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
-}
-
-void ReshapeFromVectorDT(const XPUContext& dev_ctx,
-                         const DenseTensor& x,
-                         const std::vector<DenseTensor>& shape,
-                         DenseTensor* out) {
-  std::vector<int64_t> vector_shape;
-  for (auto& tensor : shape) {
-    PADDLE_ENFORCE_EQ(
-        tensor.dims(),
-        paddle::framework::make_ddim({1}),
-        paddle::platform::errors::InvalidArgument(
-            "If the element type of 'shape' in ReshapeOp is Tensor, "
-            "the element's shape must be [1]. But received the element's shape "
-            "is [%s]",
-            tensor.dims()));
-    vector_shape.push_back(*tensor.data<int32_t>());
-  }
-  ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
+void ReshapeWithXShape(const XPUContext& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out) {
+  general::SetXShape(x, xshape);
+  Reshape(dev_ctx, x, shape, out);
 }
 }  // namespace pten
@@ -107,7 +88,7 @@ PT_REGISTER_KERNEL(flatten,
                    int,
                    int64_t) {}
-PT_REGISTER_KERNEL(flatten_mid,
+PT_REGISTER_KERNEL(flatten_with_xshape,
                    XPU,
                    ANY,
                    pten::FlattenWithXShape,
@@ -119,4 +100,4 @@ PT_REGISTER_KERNEL(flatten_mid,
                    int,
                    int64_t) {}
-PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {}
+PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {}
@@ -16,6 +16,7 @@ limitations under the License. */
 #ifdef PADDLE_WITH_XPU
+#include "paddle/pten/common/scalar_array.h"
 #include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/kernel_registry.h"
@@ -33,20 +34,16 @@ void Flatten(const XPUContext& dev_ctx,
             int stop_axis,
             DenseTensor* out);
-void ReshapeFromDT(const XPUContext& dev_ctx,
-                   const DenseTensor& x,
-                   const DenseTensor& shape,
-                   DenseTensor* out);
-
-void ReshapeFromVectorVal(const XPUContext& dev_ctx,
-                          const DenseTensor& x,
-                          const std::vector<int64_t>& shape,
-                          DenseTensor* out);
-
-void ReshapeFromVectorDT(const XPUContext& dev_ctx,
-                         const DenseTensor& x,
-                         const std::vector<DenseTensor>& shape,
-                         DenseTensor* out);
+void Reshape(const XPUContext& dev_ctx,
+             const DenseTensor& x,
+             const ScalarArray& shape,
+             DenseTensor* out);
+
+void ReshapeWithXShape(const XPUContext& dev_ctx,
+                       const DenseTensor& x,
+                       const ScalarArray& shape,
+                       DenseTensor* xshape,
+                       DenseTensor* out);
 }  // namespace pten
......
@@ -103,10 +103,10 @@
     invoke : full_like(x, 1, dtype, place, layout)
 - api : reshape
-  args : (const Tensor& x, const std::vector<int64_t>& shape)
+  args : (const Tensor& x, const ScalarArray& shape)
   output : Tensor
   infer_meta :
-    func : InferMetaFromVecValue
+    func : ReshapeInferMeta
   kernel :
     func : reshape
......