未验证 提交 a3c8abc7 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen] Reduce reshape kernel functions in pten (#38055)

* Reduce reshape kernel functions in pten

* delete notes

* fix bugs when compile
上级 19a833c8
......@@ -1872,6 +1872,10 @@ void OperatorWithKernel::BuildPtenKernelContext(
std::type_index(typeid(std::vector<int64_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int64_t>, attr_iter->second))));
} else if (std::type_index(attr_iter->second.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
pt_kernel_context_->EmplaceBackAttr(std::move(pten::ScalarArray(
BOOST_GET_CONST(std::vector<int32_t>, attr_iter->second))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
......
......@@ -358,6 +358,10 @@ static void BuildDygraphPtenKernelContext(
std::type_index(typeid(std::vector<int64_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int64_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(std::vector<int32_t>))) {
kernel_ctx->EmplaceBackAttr(std::move(
pten::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to VectorTensor when "
......
......@@ -337,8 +337,9 @@ class FlattenContiguousRangeOp : public framework::OperatorWithKernel {
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
if (ctx.HasOutput("XShape")) {
return framework::KernelSignature(
"flatten.mid", {"X"}, {"start_axis", "stop_axis"}, {"Out", "XShape"});
return framework::KernelSignature("flatten_with_xshape", {"X"},
{"start_axis", "stop_axis"},
{"Out", "XShape"});
} else {
return framework::KernelSignature("flatten", {"X"},
{"start_axis", "stop_axis"}, {"Out"});
......
......@@ -19,6 +19,7 @@ limitations under the License. */
// only can include the headers in paddle/pten/api dirs
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
namespace paddle {
......@@ -402,6 +403,7 @@ class ReshapeKernel {
auto *shape_tensor = ctx.HasInput("Shape")
? ctx.Input<framework::LoDTensor>("Shape")
: nullptr;
pten::ScalarArray pt_scalar_shape;
if (list_new_shape_tensor.size() > 0) {
// have shape tensor
std::vector<pten::DenseTensor> pt_vec_shape;
......@@ -417,22 +419,7 @@ class ReshapeKernel {
std::move(*(paddle::experimental::MakePtenDenseTensor(*tensor))));
}
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorDT(dev_ctx, *pt_x.get(), pt_vec_shape, pt_out);
}
#endif
pt_scalar_shape = pten::ScalarArray(pt_vec_shape);
} else if (shape_tensor) {
std::unique_ptr<pten::DenseTensor> pt_shape;
if (platform::is_gpu_place(shape_tensor->place()) ||
......@@ -443,44 +430,27 @@ class ReshapeKernel {
} else {
pt_shape = paddle::experimental::MakePtenDenseTensor(*shape_tensor);
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromDT(dev_ctx, *pt_x.get(), *pt_shape.get(), pt_out);
}
#endif
pt_scalar_shape = pten::ScalarArray(*pt_shape.get());
} else {
auto &shape_attr = ctx.Attr<std::vector<int>>("shape");
const std::vector<int64_t> shape_vec(shape_attr.begin(),
shape_attr.end());
pt_scalar_shape = pten::ScalarArray(shape_attr);
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::ReshapeFromVectorVal(dev_ctx, *pt_x.get(), shape_vec, pt_out);
pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
}
#endif
}
// non-inplace need move all result from pt_out to out, inplace need set
// result dims.
if (in != out) {
......@@ -553,16 +523,16 @@ class Reshape2Op : public ReshapeOp {
framework::KernelSignature GetExpectedPtenKernelArgs(
const framework::ExecutionContext &ctx) const override {
std::string shape;
auto multi_inputs = ctx.MultiInput<framework::Tensor>("ShapeTensor");
if (multi_inputs.size() > 0) {
return framework::KernelSignature("reshape_mulhost", {"X", "ShapeTensor"},
{}, {"Out"});
shape = "ShapeTensor";
} else if (ctx.HasInput("Shape")) {
return framework::KernelSignature("reshape_host", {"X", "Shape"}, {},
{"Out"});
shape = "Shape";
} else {
return framework::KernelSignature("reshape", {"X"}, {"shape"}, {"Out"});
shape = "shape";
}
return framework::KernelSignature("reshape", {"X"}, {shape}, {"Out"});
}
};
......
......@@ -83,7 +83,7 @@ using multiply_kernel = void (*)(const DeviceContext&,
using reshape_kernel = void (*)(const DeviceContext&,
const DenseTensor&,
const std::vector<int64_t>&,
const ScalarArray&,
DenseTensor*);
using scale_kernel = void (*)(const DeviceContext&,
......
......@@ -28,6 +28,10 @@ class ScalarArrayBase {
ScalarArrayBase(const std::vector<int64_t>& vec) : array_(vec) {} // NOLINT
ScalarArrayBase(const std::vector<int32_t>& vec) { // NOLINT
array_.insert(array_.begin(), vec.begin(), vec.end());
}
ScalarArrayBase(std::initializer_list<int64_t> array_list)
: array_(array_list) {}
......@@ -43,7 +47,7 @@ class ScalarArrayBase {
ScalarArrayBase(const T& tensor) { // NOLINT
size_t n = tensor.numel();
array_.reserve(n);
switch (tensor.type()) {
switch (tensor.dtype()) {
case DataType::INT32:
AssignData(tensor.template data<int32_t>(), n);
break;
......@@ -55,7 +59,7 @@ class ScalarArrayBase {
"Data type error. Currently, The data type of ScalarArrayBase "
"only supports Tensor with int32 and int64, "
"but now received `",
tensor.type(),
tensor.dtype(),
"`.");
}
}
......
......@@ -60,7 +60,7 @@ DenseTensor Reshape(const ContextT& dev_ctx,
std::make_shared<paddle::experimental::DefaultAllocator>(
dev_ctx.GetPlace());
pten::DenseTensor dense_out(allocator, out_meta);
ReshapeFromVectorVal(dev_ctx, x, shape, &dense_out);
Reshape(dev_ctx, x, ScalarArray(shape), &dense_out);
return dense_out;
}
......
......@@ -227,6 +227,11 @@ DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
return return_meta;
}
DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
const ScalarArray& shape) {
return InferMetaFromVecValue(x_meta, shape.GetData());
}
DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& axis,
bool keep_dim) {
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
// See Note [ Why still include the fluid headers? ]
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/tensor_meta.h"
namespace pten {
......@@ -50,6 +51,9 @@ DenseTensorMeta FullLikeInferMeta(const DenseTensorMeta& x_meta,
DenseTensorMeta InferMetaFromVecValue(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& shape);
DenseTensorMeta ReshapeInferMeta(const DenseTensorMeta& x_meta,
const ScalarArray& shape);
DenseTensorMeta ReduceInferMeta(const DenseTensorMeta& x_meta,
const std::vector<int64_t>& axis,
bool keep_dim);
......
......@@ -46,74 +46,27 @@ void FlattenWithXShape(const CPUContext& dev_ctx,
general::SetXShape(x, xshape);
}
void ReshapeFromVectorVal(const CPUContext& dev_ctx,
void Reshape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
const ScalarArray& shape,
DenseTensor* out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape);
auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
if (x.data() == out->data() && x.numel() == out->numel()) {
out->Resize(out_meta.dims);
return;
}
pten::Copy(dev_ctx, x, false, out);
out->Resize(out_meta.dims);
}
void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out) {
general::SetXShape(x, xshape);
ReshapeFromVectorVal(dev_ctx, x, shape, out);
}
void ReshapeFromDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
out->ResetLoD(x.lod());
}
void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
void ReshapeWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out) {
general::SetXShape(x, xshape);
ReshapeFromDT(dev_ctx, x, shape, out);
}
void ReshapeFromVectorDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int64_t> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
paddle::framework::make_ddim({1}),
paddle::platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(*tensor.data<int32_t>());
}
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* xshape,
DenseTensor* out) {
general::SetXShape(x, xshape);
ReshapeFromVectorDT(dev_ctx, x, shape, out);
Reshape(dev_ctx, x, shape, out);
}
template <typename T>
......@@ -130,8 +83,6 @@ void Cast(const CPUContext& dev_ctx,
} // namespace pten
// TODO(yuanrisheng): "flatten_contiguous_range" is compatible with old kernel
// architecture, kernel_name should be "flatten".
PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
......@@ -142,7 +93,7 @@ PT_REGISTER_KERNEL(flatten,
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(flatten_mid,
PT_REGISTER_KERNEL(flatten_with_xshape,
CPU,
ANY,
pten::FlattenWithXShape,
......@@ -171,33 +122,8 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
CPU,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CPU, ANY, pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CPU,
ANY,
pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CPU,
ANY,
pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
CPU,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
pten::ReshapeWithXShape) {}
......@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -38,36 +39,14 @@ void Cast(const CPUContext& dev_ctx,
DataType in_dtype,
DenseTensor* out);
void ReshapeFromDT(const CPUContext& dev_ctx,
void Reshape(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
const ScalarArray& shape,
DenseTensor* out);
void ReshapeFromVectorVal(const CPUContext& dev_ctx,
void ReshapeWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* out);
void ReshapeFromVectorDT(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out);
void ReshapeFromDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorValWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorDTWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out);
......
......@@ -46,74 +46,27 @@ void FlattenWithXShape(const CUDAContext& dev_ctx,
general::SetXShape(x, xshape);
}
void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
void Reshape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
const ScalarArray& shape,
DenseTensor* out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape);
auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
if (x.data() == out->data() && x.numel() == out->numel()) {
out->Resize(out_meta.dims);
return;
}
pten::Copy(dev_ctx, x, false, out);
out->Resize(out_meta.dims);
}
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out) {
general::SetXShape(x, xshape);
ReshapeFromVectorVal(dev_ctx, x, shape, out);
}
void ReshapeFromDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
out->ResetLoD(x.lod());
}
void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
void ReshapeWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out) {
general::SetXShape(x, xshape);
ReshapeFromDT(dev_ctx, x, shape, out);
}
void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int64_t> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
paddle::framework::make_ddim({1}),
paddle::platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(*tensor.data<int32_t>());
}
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* xshape,
DenseTensor* out) {
general::SetXShape(x, xshape);
ReshapeFromVectorDT(dev_ctx, x, shape, out);
Reshape(dev_ctx, x, shape, out);
}
template <typename T>
......@@ -142,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL(flatten_mid,
PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA,
ANY,
pten::FlattenWithXShape,
......@@ -179,33 +132,8 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mid,
CUDA,
ANY,
pten::ReshapeFromVectorValWithXShape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host, CUDA, ANY, pten::ReshapeFromDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_host_mid,
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CUDA,
ANY,
pten::ReshapeFromDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost,
CUDA,
ANY,
pten::ReshapeFromVectorDT) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_mulhost_mid,
CUDA,
ANY,
pten::ReshapeFromVectorDTWithXShape) {
kernel->InputAt(1).SetBackend(pten::Backend::CPU);
kernel->InputAt(1).SetDataType(paddle::experimental::DataType::INT32);
}
pten::ReshapeWithXShape) {}
......@@ -17,6 +17,7 @@
// CUDA and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -41,36 +42,14 @@ void Cast(const CUDAContext& dev_ctx,
DataType in_dtype,
DenseTensor* out);
void ReshapeFromDT(const CUDAContext& dev_ctx,
void Reshape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
const ScalarArray& shape,
DenseTensor* out);
void ReshapeFromVectorVal(const CUDAContext& dev_ctx,
void ReshapeWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* out);
void ReshapeFromVectorDT(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out);
void ReshapeFromDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorValWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* xshape,
DenseTensor* out);
void ReshapeFromVectorDTWithXShape(const CUDAContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out);
......
......@@ -51,46 +51,27 @@ void FlattenWithXShape(const XPUContext& dev_ctx,
xshape->ResetLoD(x.lod());
}
void ReshapeFromVectorVal(const XPUContext& dev_ctx,
void Reshape(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
const ScalarArray& shape,
DenseTensor* out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape);
if (&x == out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
if (x.data() == out->data() && x.numel() == out->numel()) {
out->Resize(out_meta.dims);
return;
}
pten::Copy(dev_ctx, x, false, out);
out->Resize(out_meta.dims);
out->ResetLoD(x.lod());
}
void ReshapeFromDT(const XPUContext& dev_ctx,
void ReshapeWithXShape(const XPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out) {
auto* shape_data = shape.data<int>();
auto vector_shape =
std::vector<int64_t>(shape_data, shape_data + shape.numel());
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
}
void ReshapeFromVectorDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
DenseTensor* out) {
std::vector<int64_t> vector_shape;
for (auto& tensor : shape) {
PADDLE_ENFORCE_EQ(
tensor.dims(),
paddle::framework::make_ddim({1}),
paddle::platform::errors::InvalidArgument(
"If the element type of 'shape' in ReshapeOp is Tensor, "
"the element's shape must be [1]. But received the element's shape "
"is [%s]",
tensor.dims()));
vector_shape.push_back(*tensor.data<int32_t>());
}
ReshapeFromVectorVal(dev_ctx, x, vector_shape, out);
general::SetXShape(x, xshape);
Reshape(dev_ctx, x, shape, out);
}
} // namespace pten
......@@ -107,7 +88,7 @@ PT_REGISTER_KERNEL(flatten,
int,
int64_t) {}
PT_REGISTER_KERNEL(flatten_mid,
PT_REGISTER_KERNEL(flatten_with_xshape,
XPU,
ANY,
pten::FlattenWithXShape,
......@@ -119,4 +100,4 @@ PT_REGISTER_KERNEL(flatten_mid,
int,
int64_t) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::ReshapeFromVectorVal) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape, XPU, ANY, pten::Reshape) {}
......@@ -16,6 +16,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
......@@ -33,19 +34,15 @@ void Flatten(const XPUContext& dev_ctx,
int stop_axis,
DenseTensor* out);
void ReshapeFromDT(const XPUContext& dev_ctx,
void Reshape(const XPUContext& dev_ctx,
const DenseTensor& x,
const DenseTensor& shape,
const ScalarArray& shape,
DenseTensor* out);
void ReshapeFromVectorVal(const XPUContext& dev_ctx,
void ReshapeWithXShape(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<int64_t>& shape,
DenseTensor* out);
void ReshapeFromVectorDT(const XPUContext& dev_ctx,
const DenseTensor& x,
const std::vector<DenseTensor>& shape,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out);
} // namespace pten
......
......@@ -103,10 +103,10 @@
invoke : full_like(x, 1, dtype, place, layout)
- api : reshape
args : (const Tensor& x, const std::vector<int64_t>& shape)
args : (const Tensor& x, const ScalarArray& shape)
output : Tensor
infer_meta :
func : InferMetaFromVecValue
func : ReshapeInferMeta
kernel :
func : reshape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册