未验证 提交 3d3bc681 编写于 作者: Y YuanRisheng 提交者: GitHub

[PTen]Move manipulation mid to new directory and rename flatten/reshape kernel (#38730)

* move mid api and rename kernel

* use empty kernel
上级 ee813e34
...@@ -20,7 +20,7 @@ limitations under the License. */ ...@@ -20,7 +20,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h" #include "paddle/pten/kernels/cast_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/pooling.h" #include "paddle/fluid/operators/math/pooling.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h" #include "paddle/pten/kernels/flatten_kernel.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -134,8 +134,8 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> { ...@@ -134,8 +134,8 @@ class FlattenContiguousRangeKernel : public framework::OpKernel<T> {
auto pt_out = paddle::experimental::MakePtenDenseTensor(*out); auto pt_out = paddle::experimental::MakePtenDenseTensor(*out);
// call new kernel // call new kernel
pten::Flatten<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis, stop_axis, pten::FlattenKernel<T, DeviceContext>(dev_ctx, *pt_x.get(), start_axis,
pt_out.get()); stop_axis, pt_out.get());
} }
}; };
......
...@@ -21,7 +21,7 @@ limitations under the License. */ ...@@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/pten/api/lib/utils/tensor_utils.h" #include "paddle/pten/api/lib/utils/tensor_utils.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/include/manipulation.h" #include "paddle/pten/kernels/reshape_kernel.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class InferShapeContext; class InferShapeContext;
...@@ -438,18 +438,18 @@ class ReshapeKernel { ...@@ -438,18 +438,18 @@ class ReshapeKernel {
} }
if (platform::is_cpu_place(ctx.GetPlace())) { if (platform::is_cpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CPUDeviceContext>();
pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
} }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (platform::is_gpu_place(ctx.GetPlace())) { if (platform::is_gpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>(); auto &dev_ctx = ctx.device_context<platform::CUDADeviceContext>();
pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
} }
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
if (platform::is_xpu_place(ctx.GetPlace())) { if (platform::is_xpu_place(ctx.GetPlace())) {
auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>(); auto &dev_ctx = ctx.device_context<platform::XPUDeviceContext>();
pten::Reshape(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out); pten::ReshapeKernel(dev_ctx, *pt_x.get(), pt_scalar_shape, pt_out);
} }
#endif #endif
// non-inplace need move all result from pt_out to out, inplace need set // non-inplace need move all result from pt_out to out, inplace need set
......
...@@ -18,5 +18,4 @@ limitations under the License. */ ...@@ -18,5 +18,4 @@ limitations under the License. */
#include "paddle/pten/include/core.h" #include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/infermeta.h"
#include "paddle/pten/include/linalg.h" #include "paddle/pten/include/linalg.h"
#include "paddle/pten/include/manipulation.h"
#include "paddle/pten/include/math.h" #include "paddle/pten/include/math.h"
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// See Note: [ How do we organize the kernel directory ]
#include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/kernels/reshape_kernel.h"
namespace pten {
// Convenience API: infers the flattened shape of `x` over the axis range
// [start_axis, stop_axis], allocates the output tensor on the device of
// `dev_ctx`, runs the flatten kernel into it, and returns it by value.
template <typename T, typename ContextT>
DenseTensor Flatten(const ContextT& dev_ctx,
                    const DenseTensor& x,
                    int start_axis,
                    int stop_axis) {
  auto out_meta = FlattenInferMeta(x.meta(), start_axis, stop_axis);
  pten::DenseTensor dense_out(
      pten::make_intrusive<paddle::experimental::SharedStorage>(
          dev_ctx.GetPlace()),
      std::move(out_meta));
  // The device kernel was renamed Flatten -> FlattenKernel in this refactor
  // (see flatten_kernel.h and the kernel registrations); calling the old
  // 5-argument `Flatten` here no longer resolves to any declaration.
  FlattenKernel<T, ContextT>(dev_ctx, x, start_axis, stop_axis, &dense_out);
  return dense_out;
}
// Convenience API: infers the output meta from the requested `shape`,
// allocates the output tensor on the device of `dev_ctx`, runs the reshape
// kernel into it, and returns it by value.
template <typename T, typename ContextT>
DenseTensor Reshape(const ContextT& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<int64_t>& shape) {
  auto out_meta = InferMetaFromVecValue(x.meta(), shape);
  pten::DenseTensor dense_out(
      pten::make_intrusive<paddle::experimental::SharedStorage>(
          dev_ctx.GetPlace()),
      std::move(out_meta));
  // The device kernel was renamed Reshape -> ReshapeKernel in this refactor
  // (see reshape_kernel.h and the kernel registrations); the old name would
  // recurse into overload resolution against this wrapper and fail.
  ReshapeKernel<ContextT>(dev_ctx, x, ScalarArray(shape), &dense_out);
  return dense_out;
}
} // namespace pten
...@@ -22,11 +22,11 @@ ...@@ -22,11 +22,11 @@
namespace pten { namespace pten {
template <typename T, typename Context> template <typename T, typename Context>
void Flatten(const Context& dev_ctx, void FlattenKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis, int stop_axis,
DenseTensor* out) { DenseTensor* out) {
auto out_dims = out->dims(); auto out_dims = out->dims();
pten::Copy(dev_ctx, x, false, out); pten::Copy(dev_ctx, x, false, out);
out->Resize(out_dims); out->Resize(out_dims);
...@@ -42,7 +42,7 @@ void FlattenWithXShape(const Context& dev_ctx, ...@@ -42,7 +42,7 @@ void FlattenWithXShape(const Context& dev_ctx,
int stop_axis, int stop_axis,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape) { DenseTensor* xshape) {
Flatten<T, Context>(dev_ctx, x, start_axis, stop_axis, out); FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
funcs::SetXShape(x, xshape); funcs::SetXShape(x, xshape);
} }
...@@ -51,7 +51,7 @@ void FlattenWithXShape(const Context& dev_ctx, ...@@ -51,7 +51,7 @@ void FlattenWithXShape(const Context& dev_ctx,
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_CTX_KERNEL(flatten,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::Flatten, pten::FlattenKernel,
float, float,
double, double,
uint8_t, uint8_t,
...@@ -74,7 +74,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape, ...@@ -74,7 +74,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_CTX_KERNEL(flatten,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::Flatten, pten::FlattenKernel,
float, float,
paddle::platform::float16, paddle::platform::float16,
double, double,
...@@ -100,7 +100,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape, ...@@ -100,7 +100,7 @@ PT_REGISTER_CTX_KERNEL(flatten_with_xshape,
PT_REGISTER_CTX_KERNEL(flatten, PT_REGISTER_CTX_KERNEL(flatten,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
pten::Flatten, pten::FlattenKernel,
float, float,
paddle::platform::float16, paddle::platform::float16,
double, double,
......
...@@ -15,15 +15,17 @@ limitations under the License. */ ...@@ -15,15 +15,17 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten { namespace pten {
template <typename T, typename Context> template <typename T, typename Context>
void Flatten(const Context& dev_ctx, void FlattenKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis, int stop_axis,
DenseTensor* out); DenseTensor* out);
template <typename T, typename Context> template <typename T, typename Context>
void FlattenWithXShape(const Context& dev_ctx, void FlattenWithXShape(const Context& dev_ctx,
...@@ -33,4 +35,15 @@ void FlattenWithXShape(const Context& dev_ctx, ...@@ -33,4 +35,15 @@ void FlattenWithXShape(const Context& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape); DenseTensor* xshape);
// Allocating overload: builds an empty output tensor sized by
// FlattenInferMeta for the axis range [start_axis, stop_axis], fills it via
// FlattenKernel, and returns it by value.
template <typename T, typename Context>
DenseTensor Flatten(const Context& dev_ctx,
                    const DenseTensor& x,
                    int start_axis,
                    int stop_axis) {
  DenseTensor result = Empty<T, Context>(
      dev_ctx, FlattenInferMeta(x.meta(), start_axis, stop_axis));
  FlattenKernel<T, Context>(dev_ctx, x, start_axis, stop_axis, &result);
  return result;
}
} // namespace pten } // namespace pten
...@@ -22,10 +22,10 @@ ...@@ -22,10 +22,10 @@
namespace pten { namespace pten {
template <typename Context> template <typename Context>
void Reshape(const Context& dev_ctx, void ReshapeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* out) { DenseTensor* out) {
auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData()); auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
if (x.data() == out->data() && x.numel() == out->numel()) { if (x.data() == out->data() && x.numel() == out->numel()) {
out->Resize(out_meta.dims); out->Resize(out_meta.dims);
...@@ -43,13 +43,16 @@ void ReshapeWithXShape(const Context& dev_ctx, ...@@ -43,13 +43,16 @@ void ReshapeWithXShape(const Context& dev_ctx,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out) { DenseTensor* out) {
funcs::SetXShape(x, xshape); funcs::SetXShape(x, xshape);
Reshape(dev_ctx, x, shape, out); ReshapeKernel(dev_ctx, x, shape, out);
} }
} // namespace pten } // namespace pten
PT_REGISTER_GENERAL_KERNEL( PT_REGISTER_GENERAL_KERNEL(reshape,
reshape, CPU, ALL_LAYOUT, pten::Reshape<pten::CPUContext>, ALL_DTYPE) {} CPU,
ALL_LAYOUT,
pten::ReshapeKernel<pten::CPUContext>,
ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -57,8 +60,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, ...@@ -57,8 +60,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
ALL_DTYPE) {} ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_GENERAL_KERNEL( PT_REGISTER_GENERAL_KERNEL(reshape,
reshape, GPU, ALL_LAYOUT, pten::Reshape<pten::GPUContext>, ALL_DTYPE) {} GPU,
ALL_LAYOUT,
pten::ReshapeKernel<pten::GPUContext>,
ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
...@@ -67,8 +73,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, ...@@ -67,8 +73,11 @@ PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PT_REGISTER_GENERAL_KERNEL( PT_REGISTER_GENERAL_KERNEL(reshape,
reshape, XPU, ALL_LAYOUT, pten::Reshape<pten::XPUContext>, ALL_DTYPE) {} XPU,
ALL_LAYOUT,
pten::ReshapeKernel<pten::XPUContext>,
ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape, PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
XPU, XPU,
ALL_LAYOUT, ALL_LAYOUT,
......
...@@ -16,14 +16,16 @@ limitations under the License. */ ...@@ -16,14 +16,16 @@ limitations under the License. */
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/empty_kernel.h"
namespace pten { namespace pten {
template <typename Context> template <typename Context>
void Reshape(const Context& dev_ctx, void ReshapeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* out); DenseTensor* out);
template <typename Context> template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx, void ReshapeWithXShape(const Context& dev_ctx,
...@@ -32,4 +34,14 @@ void ReshapeWithXShape(const Context& dev_ctx, ...@@ -32,4 +34,14 @@ void ReshapeWithXShape(const Context& dev_ctx,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out); DenseTensor* out);
// Allocating overload: builds an empty output tensor sized by
// InferMetaFromVecValue for `shape`, fills it via ReshapeKernel, and
// returns it by value.
template <typename T, typename Context>
DenseTensor Reshape(const Context& dev_ctx,
                    const DenseTensor& x,
                    const std::vector<int64_t>& shape) {
  DenseTensor result =
      Empty<T, Context>(dev_ctx, InferMetaFromVecValue(x.meta(), shape));
  ReshapeKernel<Context>(dev_ctx, x, ScalarArray(shape), &result);
  return result;
}
} // namespace pten } // namespace pten
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/include/manipulation.h" #include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/common/data_type.h" #include "paddle/pten/common/data_type.h"
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/include/manipulation.h" #include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
......
...@@ -15,7 +15,7 @@ limitations under the License. */ ...@@ -15,7 +15,7 @@ limitations under the License. */
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <memory> #include <memory>
#include "paddle/pten/include/manipulation.h" #include "paddle/pten/kernels/reshape_kernel.h"
#include "paddle/pten/api/lib/utils/allocator.h" #include "paddle/pten/api/lib/utils/allocator.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册