Unverified commit 3d3bc681, authored by YuanRisheng, committed by GitHub

[PTen]Move manipulation mid to new directory and rename flatten/reshape kernel (#38730)

* move mid api and rename kernel

* use empty kernel
Parent ee813e34
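
The hunks below rename the device-level kernel entry points (pten::Flatten becomes pten::FlattenKernel, pten::Reshape becomes pten::ReshapeKernel) and move the mid-level functions that allocate and return a DenseTensor from paddle/pten/include/manipulation.h into the corresponding kernel headers, where they now allocate via the empty kernel. As a minimal sketch of a call site after this change (not taken from the diff; the context type, element type, axis values, and target shape are illustrative assumptions):

#include <cstdint>
#include <vector>

#include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/kernels/reshape_kernel.h"

namespace {
// Illustrative only: dev_ctx, x, and out are assumed to be prepared by the
// caller (e.g. a CPU device context and a float DenseTensor).
template <typename Context>
void CallRenamedKernels(const Context& dev_ctx,
                        const pten::DenseTensor& x,
                        pten::DenseTensor* out) {
  // Previously: pten::Flatten<float, Context>(dev_ctx, x, 1, 2, out);
  pten::FlattenKernel<float, Context>(dev_ctx, x, /*start_axis=*/1,
                                      /*stop_axis=*/2, out);

  // Previously: pten::Reshape(dev_ctx, x, pten::ScalarArray(shape), out);
  std::vector<int64_t> shape = {2, 3};
  pten::ReshapeKernel<Context>(dev_ctx, x, pten::ScalarArray(shape), out);
}
}  // namespace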
......@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/kernels/cast_kernel.h"
namespace paddle {
namespace operators {
......
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/kernels/flatten_kernel.h"
namespace paddle {
namespace operators {
......@@ -134,8 +134,8 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel
-pten::Flatten<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis, stop_axis,
-pt_out.get());
+pten::FlattenKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis,
+stop_axis, pt_out.get());
}
};
......
......@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/kernels/reshape_kernel.h"
namespace paddle {
namespace framework {
class InferShapeContext;
......@@ -438,18 +438,18 @@ class ReshapeKernel {
}
if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
-pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
+pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
-pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
+pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
}
#endif
#ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
-pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
+pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
}
#endif
// non-inplace need move all result from pt_out to out, inplace need set
......
......@@ -18,5 +18,4 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/include/linalg.h"
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/include/math.h"
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-#pragma once
-// See Note: [ How do we organize the kernel directory ]
-#include "paddle/pten/api/lib/utils/storage.h"
-#include "paddle/pten/include/infermeta.h"
-#include "paddle/pten/kernels/cast_kernel.h"
-#include "paddle/pten/kernels/flatten_kernel.h"
-#include "paddle/pten/kernels/reshape_kernel.h"
-namespace pten {
-template <typename T, typename ContextT>
-DenseTensor Flatten(const ContextT& dev_ctx,
-const DenseTensor& x,
-int start_axis,
-int stop_axis) {
-auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis);
-pten::DenseTensor dense_out(
-pten::make_intrusive<paddle::experimental::SharedStorage>(
-dev_ctx.GetPlace()),
-std::move(out_meta));
-Flatten<T, ContextT>(dev_ctx, x, start_axis, stop_axis, &dense_out);
-return dense_out;
-}
-template <typename T, typename ContextT>
-DenseTensor Reshape(const ContextT& dev_ctx,
-const DenseTensor& x,
-const std::vector<int64_t>& shape) {
-auto out_meta = InferMetaFromVecValue(x.meta(), shape);
-pten::DenseTensor dense_out(
-pten::make_intrusive<paddle::experimental::SharedStorage>(
-dev_ctx.GetPlace()),
-std::move(out_meta));
-Reshape<ContextT>(dev_ctx, x, ScalarArray(shape), &dense_out);
-return dense_out;
-}
-} // namespace pten
......@@ -22,11 +22,11 @@
namespace pten {
template <typename T, typename Context>
-void Flatten(const Context& dev_ctx,
-const DenseTensor& x,
-int start_axis,
-int stop_axis,
-DenseTensor* out) {
+void FlattenKernel(const Context& dev_ctx,
+const DenseTensor& x,
+int start_axis,
+int stop_axis,
+DenseTensor* out) {
auto out_dims = out->dims();
pten::Copy(dev_ctx, x, false, out);
out->Resize(out_dims);
......@@ -42,7 +42,7 @@ void FlattenWithXShape(const Context& dev_ctx,
int stop_axis,
DenseTensor* out,
DenseTensor* xshape) {
-Flatten<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
+FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
funcs::SetXShape(x, xshape);
}
......@@ -51,7 +51,7 @@ void FlattenWithXShape(const Context& dev_ctx,
PT_REGISTER_CTX_KERNEL(flatten,
CPU,
ALL_LAYOUT,
-pten::Flatten,
+pten::FlattenKernel,
float,
double,
uint8_t,
......@@ -74,7 +74,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
PT_REGISTER_CTX_KERNEL(flatten,
GPU,
ALL_LAYOUT,
-pten::Flatten,
+pten::FlattenKernel,
float,
paddle::platform::float16,
double,
......@@ -100,7 +100,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
PT_REGISTER_CTX_KERNEL(flatten,
XPU,
ALL_LAYOUT,
-pten::Flatten,
+pten::FlattenKernel,
float,
paddle::platform::float16,
double,
......
......@@ -15,15 +15,17 @@ limitations under the License. */
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename T, typename Context>
-void Flatten(const Context& dev_ctx,
-const DenseTensor& x,
-int start_axis,
-int stop_axis,
-DenseTensor* out);
+void FlattenKernel(const Context& dev_ctx,
+const DenseTensor& x,
+int start_axis,
+int stop_axis,
+DenseTensor* out);
template <typename T, typename Context>
void FlattenWithXShape(const Context& dev_ctx,
......@@ -33,4 +35,15 @@ void FlattenWithXShape(const Context& dev_ctx,
DenseTensor* out,
DenseTensor* xshape);
+template <typename T, typename Context>
+DenseTensor Flatten(const Context& dev_ctx,
+const DenseTensor& x,
+int start_axis,
+int stop_axis) {
+auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis);
+auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &dense_out);
+return dense_out;
+}
} // namespace pten
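
For reference, a minimal sketch of calling the relocated mid-level Flatten defined in the header above; the helper name, element type, and axis values are illustrative assumptions, and the device context and input tensor are assumed to be prepared elsewhere:

#include "paddle/pten/kernels/flatten_kernel.h"

namespace {
// Hypothetical helper: flattens axes [1, 2] of `x`. The header-defined
// wrapper infers the output meta, allocates the result with Empty, and
// forwards to FlattenKernel.
template <typename Context>
pten::DenseTensor FlattenMiddleAxes(const Context& dev_ctx,
                                    const pten::DenseTensor& x) {
  return pten::Flatten<float, Context>(dev_ctx, x, /*start_axis=*/1,
                                       /*stop_axis=*/2);
}
}  // namespace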
......@@ -22,10 +22,10 @@
namespace pten {
template <typename Context>
-void Reshape(const Context& dev_ctx,
-const DenseTensor& x,
-const ScalarArray& shape,
-DenseTensor* out) {
+void ReshapeKernel(const Context& dev_ctx,
+const DenseTensor& x,
+const ScalarArray& shape,
+DenseTensor* out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
if (x.data() == out->data() && x.numel() == out->numel()) {
out->Resize(out_meta.dims);
......@@ -43,13 +43,16 @@ void ReshapeWithXShape(const Context& dev_ctx,
DenseTensor* xshape,
DenseTensor* out) {
funcs::SetXShape(x, xshape);
-Reshape(dev_ctx, x, shape, out);
+ReshapeKernel(dev_ctx, x, shape, out);
}
} // namespace pten
-PT_REGISTER_GENERAL_KERNEL(
-reshape, CPU, ALL_LAYOUT, pten::Reshape<pten::CPUContext>, ALL_DTYPE) {}
+PT_REGISTER_GENERAL_KERNEL(reshape,
+CPU,
+ALL_LAYOUT,
+pten::ReshapeKernel<pten::CPUContext>,
+ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
CPU,
ALL_LAYOUT,
......@@ -57,8 +60,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-PT_REGISTER_GENERAL_KERNEL(
-reshape, GPU, ALL_LAYOUT, pten::Reshape<pten::GPUContext>, ALL_DTYPE) {}
+PT_REGISTER_GENERAL_KERNEL(reshape,
+GPU,
+ALL_LAYOUT,
+pten::ReshapeKernel<pten::GPUContext>,
+ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
GPU,
ALL_LAYOUT,
......@@ -67,8 +73,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
#endif
#ifdef PADDLE_WITH_XPU
-PT_REGISTER_GENERAL_KERNEL(
-reshape, XPU, ALL_LAYOUT, pten::Reshape<pten::XPUContext>, ALL_DTYPE) {}
+PT_REGISTER_GENERAL_KERNEL(reshape,
+XPU,
+ALL_LAYOUT,
+pten::ReshapeKernel<pten::XPUContext>,
+ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
XPU,
ALL_LAYOUT,
......
......@@ -16,14 +16,16 @@ limitations under the License. */
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten {
template <typename Context>
-void Reshape(const Context& dev_ctx,
-const DenseTensor& x,
-const ScalarArray& shape,
-DenseTensor* out);
+void ReshapeKernel(const Context& dev_ctx,
+const DenseTensor& x,
+const ScalarArray& shape,
+DenseTensor* out);
template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
......@@ -32,4 +34,14 @@ void ReshapeWithXShape(const Context& dev_ctx,
DenseTensor* xshape,
DenseTensor* out);
+template <typename T, typename Context>
+DenseTensor Reshape(const Context& dev_ctx,
+const DenseTensor& x,
+const std::vector<int64_t>& shape) {
+auto out_meta = InferMetaFromVecValue(x.meta(), shape);
+auto dense_out = Empty<T, Context>(dev_ctx, std::move(out_meta));
+ReshapeKernel<Context>(dev_ctx, x, ScalarArray(shape), &dense_out);
+return dense_out;
+}
} // namespace pten
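
Likewise, a minimal sketch of the relocated mid-level Reshape above; the helper name and target shape are illustrative assumptions. The wrapper builds a ScalarArray from the vector, allocates the output with Empty, and forwards to ReshapeKernel:

#include <cstdint>
#include <vector>

#include "paddle/pten/kernels/reshape_kernel.h"

namespace {
// Hypothetical helper: reshapes `x` (assumed to hold six float elements)
// to [2, 3] via the header-defined wrapper shown above.
template <typename Context>
pten::DenseTensor ReshapeTo2x3(const Context& dev_ctx,
                               const pten::DenseTensor& x) {
  std::vector<int64_t> shape = {2, 3};
  return pten::Reshape<float, Context>(dev_ctx, x, shape);
}
}  // namespace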
......@@ -16,7 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/common/data_type.h"
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
......
......@@ -15,7 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/kernels/reshape_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h"
......