Unverified · commit 49216134 · authored by YuanRisheng · committed by GitHub

[PTen]move reshape kernel according to new directory (#38432)

* move reshape

* fix compile bugs

* delete manipulation file

* fix compile bugs
Parent commit: 113c8b93
...@@ -28,13 +28,10 @@ get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS) ...@@ -28,13 +28,10 @@ get_property(pten_kernels GLOBAL PROPERTY PTEN_KERNELS)
# keep this message for debug, remove it later if needless # keep this message for debug, remove it later if needless
message(STATUS "All standard pten kernels: ${pten_kernels}") message(STATUS "All standard pten kernels: ${pten_kernels}")
set(PTEN_DEPS ${PTEN_DEPS} ${pten_kernels}) set(PTEN_DEPS ${PTEN_DEPS} ${pten_kernels})
set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu manipulation_cpu) set(PTEN_DEPS ${PTEN_DEPS} math_cpu linalg_cpu)
set(PTEN_DEPS ${PTEN_DEPS} nary unary binary) set(PTEN_DEPS ${PTEN_DEPS} nary unary binary)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
set(PTEN_DEPS ${PTEN_DEPS} math_gpu linalg_gpu manipulation_gpu) set(PTEN_DEPS ${PTEN_DEPS} math_gpu linalg_gpu)
endif()
if(WITH_XPU)
set(PTEN_DEPS ${PTEN_DEPS} manipulation_xpu)
endif() endif()
cc_library(pten SRCS all.cc DEPS ${PTEN_DEPS}) cc_library(pten SRCS all.cc DEPS ${PTEN_DEPS})
...@@ -21,15 +21,9 @@ limitations under the License. */ ...@@ -21,15 +21,9 @@ limitations under the License. */
// file name of the kernel, and this header file will be removed // file name of the kernel, and this header file will be removed
PT_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(matmul, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(reshape, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT); PT_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT); PT_DECLARE_KERNEL(matmul, GPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(reshape, GPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT); PT_DECLARE_KERNEL(mean, GPU, ALL_LAYOUT);
#endif #endif
#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(reshape, XPU, ALL_LAYOUT);
#endif
...@@ -18,10 +18,8 @@ ...@@ -18,10 +18,8 @@
#include "paddle/pten/api/lib/utils/storage.h" #include "paddle/pten/api/lib/utils/storage.h"
#include "paddle/pten/include/infermeta.h" #include "paddle/pten/include/infermeta.h"
#include "paddle/pten/kernels/cast_kernel.h" #include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/kernels/cpu/manipulation.h"
#include "paddle/pten/kernels/flatten_kernel.h" #include "paddle/pten/kernels/flatten_kernel.h"
#include "paddle/pten/kernels/gpu/manipulation.h" #include "paddle/pten/kernels/reshape_kernel.h"
#include "paddle/pten/kernels/xpu/manipulation.h"
namespace pten { namespace pten {
...@@ -62,7 +60,7 @@ DenseTensor Reshape(const ContextT& dev_ctx, ...@@ -62,7 +60,7 @@ DenseTensor Reshape(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>( pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()), dev_ctx.GetPlace()),
std::move(out_meta)); std::move(out_meta));
Reshape(dev_ctx, x, ScalarArray(shape), &dense_out); Reshape<ContextT>(dev_ctx, x, ScalarArray(shape), &dense_out);
return dense_out; return dense_out;
} }
......
...@@ -28,8 +28,6 @@ set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory) ...@@ -28,8 +28,6 @@ set(COMMON_KERNEL_DEPS dense_tensor kernel_context kernel_factory)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function)
# auto build kernel targets by cmake # auto build kernel targets by cmake
register_kernels(EXCLUDES flatten_kernel DEPS ${COMMON_KERNEL_DEPS}) register_kernels(DEPS ${COMMON_KERNEL_DEPS})
# TODO(chenweihang): auto parse compile deps by include headers later
kernel_library(flatten_kernel DEPS ${COMMON_KERNEL_DEPS} copy_kernel unary)
copy_if_different(${kernel_declare_file} ${kernel_declare_file_final}) copy_if_different(${kernel_declare_file} ${kernel_declare_file_final})
cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas pten_transpose_cpu cast_kernel) cc_library(math_cpu SRCS math.cc DEPS dense_tensor kernel_context kernel_factory eigen_function blas pten_transpose_cpu cast_kernel)
cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory) cc_library(linalg_cpu SRCS linalg.cc DEPS dense_tensor kernel_context kernel_factory)
cc_library(manipulation_cpu SRCS manipulation.cc DEPS dense_tensor kernel_context kernel_factory copy_kernel unary)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/cpu/manipulation.h"
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/hybird/general/manipulation.h"
namespace pten {
// CPU reshape kernel: gives `out` the dims inferred from `shape`,
// copying element data from `x` only when `out` does not already alias
// x's storage.
void Reshape(const CPUContext& dev_ctx,
             const DenseTensor& x,
             const ScalarArray& shape,
             DenseTensor* out) {
  const auto out_meta = InferMetaFromVecValue(x.meta(), shape.GetData());
  const bool shares_storage =
      x.data() == out->data() && x.numel() == out->numel();
  if (shares_storage) {
    // In-place case: same buffer, same element count — only the
    // metadata needs to change.
    out->Resize(out_meta.dims);
    return;
  }
  pten::Copy(dev_ctx, x, false, out);
  out->Resize(out_meta.dims);
  out->ResetLoD(x.lod());
}
// Reshape variant that additionally records x's pre-reshape dims into
// `xshape` (via general::SetXShape, which also copies x's LoD), then
// performs the reshape into `out`.
void ReshapeWithXShape(const CPUContext& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out) {
// Capture x's shape first: Reshape may resize out, and out can share
// storage/metadata with x in the in-place case.
general::SetXShape(x, xshape);
Reshape(dev_ctx, x, shape, out);
}
} // namespace pten
// Register both CPU reshape kernels with the pten kernel registry.
// The NO_TEMPLATE variant is used because Reshape/ReshapeWithXShape are
// plain functions (not templated on a data type); ALL_DTYPE marks them
// as dtype-agnostic.
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
...@@ -43,7 +43,7 @@ void FlattenWithXShape(const ContextT& dev_ctx, ...@@ -43,7 +43,7 @@ void FlattenWithXShape(const ContextT& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape) { DenseTensor* xshape) {
Flatten<T, ContextT>(dev_ctx, x, start_axis, stop_axis, out); Flatten<T, ContextT>(dev_ctx, x, start_axis, stop_axis, out);
functions::SetXShape(x, xshape); funcs::SetXShape(x, xshape);
} }
} // namespace pten } // namespace pten
......
...@@ -17,7 +17,7 @@ limitations under the License. */ ...@@ -17,7 +17,7 @@ limitations under the License. */
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
namespace pten { namespace pten {
namespace functions { namespace funcs {
inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) { inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) {
const auto& in_dims = x.meta().dims; const auto& in_dims = x.meta().dims;
...@@ -30,5 +30,5 @@ inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) { ...@@ -30,5 +30,5 @@ inline void SetXShape(const DenseTensor& x, DenseTensor* xshape) {
xshape->ResetLoD(x.meta().lod); xshape->ResetLoD(x.meta().lod);
} }
} // namespace functions } // namespace funcs
} // namespace pten } // namespace pten
if(WITH_GPU) if(WITH_GPU)
nv_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel) nv_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel copy_kernel)
nv_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) nv_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
nv_library(manipulation_gpu SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory copy_kernel unary)
elseif(WITH_ROCM) elseif(WITH_ROCM)
hip_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel) hip_library(math_gpu SRCS math.cu DEPS eigen_function dense_tensor convert_utils kernel_context kernel_factory pten_transpose_gpu cast_kernel copy_kernel)
hip_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory) hip_library(linalg_gpu SRCS linalg.cu DEPS eigen_function dense_tensor kernel_context kernel_factory)
hip_library(manipulation_gpu SRCS manipulation.cu DEPS dense_tensor kernel_context kernel_factory copy_kernel unary)
endif() endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
// CUDA and HIP use same api
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
// GPU reshape kernel: gives `out` the dims described by `shape`,
// copying data from x unless out already aliases x's storage.
void Reshape(const GPUContext& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out);
// Reshape variant that additionally records x's pre-reshape dims into
// `xshape` before reshaping into `out`.
void ReshapeWithXShape(const GPUContext& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out);
} // namespace pten
#endif
...@@ -12,15 +12,17 @@ ...@@ -12,15 +12,17 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/pten/kernels/gpu/manipulation.h" #include "paddle/pten/kernels/reshape_kernel.h"
#include "paddle/pten/backends/all_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/pten/infermeta/unary.h" #include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/copy_kernel.h" #include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/hybird/general/manipulation.h" #include "paddle/pten/kernels/funcs/common_shape.h"
namespace pten { namespace pten {
void Reshape(const GPUContext& dev_ctx, template <typename ContextT>
void Reshape(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* out) { DenseTensor* out) {
...@@ -34,18 +36,42 @@ void Reshape(const GPUContext& dev_ctx, ...@@ -34,18 +36,42 @@ void Reshape(const GPUContext& dev_ctx,
out->ResetLoD(x.lod()); out->ResetLoD(x.lod());
} }
void ReshapeWithXShape(const GPUContext& dev_ctx, template <typename ContextT>
void ReshapeWithXShape(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* xshape,
DenseTensor* out) { DenseTensor* out) {
general::SetXShape(x, xshape); funcs::SetXShape(x, xshape);
Reshape(dev_ctx, x, shape, out); Reshape(dev_ctx, x, shape, out);
} }
} // namespace pten } // namespace pten
PT_REGISTER_NO_TEMPLATE_KERNEL( PT_REGISTER_GENERAL_KERNEL(
reshape, GPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {} reshape, CPU, ALL_LAYOUT, pten::Reshape<pten::CPUContext>, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL( PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
reshape_with_xshape, GPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {} CPU,
ALL_LAYOUT,
pten::ReshapeWithXShape<pten::CPUContext>,
ALL_DTYPE) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_GENERAL_KERNEL(
reshape, GPU, ALL_LAYOUT, pten::Reshape<pten::GPUContext>, ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
GPU,
ALL_LAYOUT,
pten::ReshapeWithXShape<pten::GPUContext>,
ALL_DTYPE) {}
#endif
#ifdef PADDLE_WITH_XPU
PT_REGISTER_GENERAL_KERNEL(
reshape, XPU, ALL_LAYOUT, pten::Reshape<pten::XPUContext>, ALL_DTYPE) {}
PT_REGISTER_GENERAL_KERNEL(reshape_with_xshape,
XPU,
ALL_LAYOUT,
pten::ReshapeWithXShape<pten::XPUContext>,
ALL_DTYPE) {}
#endif
...@@ -14,19 +14,19 @@ limitations under the License. */ ...@@ -14,19 +14,19 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/common/scalar_array.h" #include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h" #include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten { namespace pten {
void Reshape(const CPUContext& dev_ctx, template <typename ContextT>
void Reshape(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* out); DenseTensor* out);
void ReshapeWithXShape(const CPUContext& dev_ctx, template <typename ContextT>
void ReshapeWithXShape(const ContextT& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* xshape,
......
cc_library(manipulation_xpu SRCS manipulation.cc DEPS dense_tensor kernel_context kernel_factory copy_kernel unary)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/xpu/manipulation.h"
#include "paddle/pten/infermeta/unary.h"
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/hybird/general/manipulation.h"
namespace pten {
// XPU reshape kernel: applies the dims inferred from `shape` to `out`.
// Element data is copied from `x` only when the two tensors live in
// distinct buffers.
void Reshape(const XPUContext& dev_ctx,
             const DenseTensor& x,
             const ScalarArray& shape,
             DenseTensor* out) {
  auto meta = InferMetaFromVecValue(x.meta(), shape.GetData());
  if (x.data() != out->data() || x.numel() != out->numel()) {
    // Distinct buffers: materialize the data into out first, then fix
    // up its metadata.
    pten::Copy(dev_ctx, x, false, out);
    out->Resize(meta.dims);
    out->ResetLoD(x.lod());
    return;
  }
  // Shared storage with matching element count: a metadata-only update
  // suffices.
  out->Resize(meta.dims);
}
// Reshape variant that additionally records x's pre-reshape dims into
// `xshape` (via general::SetXShape), then performs the reshape into
// `out`.
void ReshapeWithXShape(const XPUContext& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out) {
// Capture x's shape first: Reshape may resize out, and out can share
// storage/metadata with x in the in-place case.
general::SetXShape(x, xshape);
Reshape(dev_ctx, x, shape, out);
}
} // namespace pten
// Register the XPU reshape kernels. ALL_DTYPE marks the kernels as
// dtype-agnostic (they are plain functions, not templated on a type).
PT_REGISTER_NO_TEMPLATE_KERNEL(
    reshape, XPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
// Fix: ReshapeWithXShape was defined above but never registered for the
// XPU backend, unlike its CPU and GPU counterparts which both register
// `reshape_with_xshape`. Register it so the XPU backend exposes the
// same kernel set.
PT_REGISTER_NO_TEMPLATE_KERNEL(
    reshape_with_xshape, XPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
#include "paddle/pten/backends/xpu/xpu_context.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/kernel_registry.h"
namespace pten {
// XPU reshape kernel: gives `out` the dims described by `shape`,
// copying data from x unless out already aliases x's storage.
void Reshape(const XPUContext& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out);
// Reshape variant that additionally records x's pre-reshape dims into
// `xshape` before reshaping into `out`.
void ReshapeWithXShape(const XPUContext& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
DenseTensor* out);
} // namespace pten
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.