Unverified commit dba694f4, authored by Leo Chen, committed by GitHub

[phi] move unbind to phi (#39789)

* move unbind to phi

* revert infer shape

* add header file

* move concat_and_split to phi
Parent 1a1a2ce8
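
The entire change is one mechanical substitution, applied per device context. Condensed from the hunks below (a sketch; dev_ctx, inputs, x, shape_refer, out, and outs are placeholder names, not identifiers from the diff):

    // Before: free function templates from paddle/phi/kernels/{cpu,gpu}/concat_and_split.h
    phi::ConcatImpl<T, Context>(dev_ctx, inputs, axis, &out);
    phi::SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);

    // After: dtype-instantiated functors from paddle/phi/kernels/funcs/concat_and_split_functor.h
    phi::funcs::ConcatFunctor<Context, T> concat;
    concat(dev_ctx, inputs, axis, &out);
    phi::funcs::SplitFunctor<Context, T> split;
    split(dev_ctx, x, shape_refer, axis, &outs);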
......@@ -6,9 +6,9 @@ endif()
# please add new math_library in alphabetical order
if (WITH_ASCEND_CL)
-    math_library(concat_and_split DEPS npu_op_runner)
+    math_library(concat_and_split DEPS concat_and_split_functor npu_op_runner)
else()
-    math_library(concat_and_split)
+    math_library(concat_and_split DEPS concat_and_split_functor)
endif()
math_library(context_project DEPS im2col math_function)
math_library(cross_entropy)
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/kernels/cpu/concat_and_split.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#endif
......@@ -46,9 +46,8 @@ class ConcatFunctor<platform::CPUDeviceContext, T> {
void operator()(const platform::CPUDeviceContext& context,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
-    std::vector<phi::DenseTensor> pt_input{input.begin(), input.end()};
-    phi::ConcatImpl<T, platform::CPUDeviceContext>(context, pt_input, axis,
-                                                   output);
+    phi::funcs::ConcatFunctor<phi::CPUContext, T> functor;
+    functor(context, input, axis, output);
}
};
......@@ -63,11 +62,8 @@ class SplitFunctor<platform::CPUDeviceContext, T> {
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
const int axis, std::vector<framework::Tensor*>* outputs) {
-    std::vector<const phi::DenseTensor*> pt_ref_inputs{ref_inputs.begin(),
-                                                       ref_inputs.end()};
-    std::vector<phi::DenseTensor*> pt_outputs{outputs->begin(), outputs->end()};
-    phi::SplitImpl<T, platform::CPUDeviceContext>(context, input, pt_ref_inputs,
-                                                  axis, &pt_outputs);
+    phi::funcs::SplitFunctor<phi::CPUContext, T> functor;
+    functor(context, input, ref_inputs, axis, outputs);
}
};
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/phi/kernels/gpu/concat_and_split.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace paddle {
namespace operators {
namespace math {
......@@ -29,10 +29,8 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const std::vector<framework::Tensor>& input, int axis,
framework::Tensor* output) {
-    std::vector<phi::DenseTensor> pt_input{input.begin(), input.end()};
-    phi::ConcatImpl<T, platform::CUDADeviceContext>(context, pt_input, axis,
-                                                    output);
+    phi::funcs::ConcatFunctor<phi::GPUContext, T> functor;
+    functor(context, input, axis, output);
}
};
......@@ -43,16 +41,12 @@ class ConcatFunctor<platform::CUDADeviceContext, T> {
template <typename T>
class SplitFunctor<platform::CUDADeviceContext, T> {
public:
-  SplitFunctor();
void operator()(const platform::CUDADeviceContext& context,
const framework::Tensor& input,
const std::vector<const framework::Tensor*>& ref_inputs,
int axis, std::vector<framework::Tensor*>* outputs) {
-    std::vector<const phi::DenseTensor*> pt_ref_inputs{ref_inputs.begin(),
-                                                       ref_inputs.end()};
-    std::vector<phi::DenseTensor*> pt_outputs{outputs->begin(), outputs->end()};
-    phi::SplitImpl<T, platform::CUDADeviceContext>(
-        context, input, pt_ref_inputs, axis, &pt_outputs);
+    phi::funcs::SplitFunctor<phi::GPUContext, T> functor;
+    functor(context, input, ref_inputs, axis, outputs);
}
};
......
......@@ -64,17 +64,3 @@ class SplitFunctor {
} // namespace math
} // namespace operators
} // namespace paddle
-#define FOR_ALL_TYPES(macro) \
-  macro(int);                \
-  macro(float);              \
-  macro(double);             \
-  macro(bool);               \
-  macro(int64_t);            \
-  macro(int16_t);            \
-  macro(uint8_t);            \
-  macro(int8_t);             \
-  macro(::paddle::platform::float16);  \
-  macro(::paddle::platform::bfloat16); \
-  macro(::paddle::platform::complex<float>);  \
-  macro(::paddle::platform::complex<double>);
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/fluid/operators/unbind_op.h"
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
......@@ -79,11 +82,3 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(unbind, ops::UnbindOp, ops::UnbindOpMaker,
ops::UnbindGradMaker<paddle::framework::OpDesc>,
ops::UnbindGradMaker<paddle::imperative::OpBase>);
-namespace plat = paddle::platform;
-REGISTER_OP_CPU_KERNEL(
-    unbind, ops::UnbindOpKernel<plat::CPUDeviceContext, double>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, float>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, int64_t>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, int>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, plat::float16>,
-    ops::UnbindOpKernel<plat::CPUDeviceContext, plat::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unbind_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
unbind, ops::UnbindOpKernel<plat::CUDADeviceContext, double>,
ops::UnbindOpKernel<plat::CUDADeviceContext, float>,
ops::UnbindOpKernel<plat::CUDADeviceContext, int64_t>,
ops::UnbindOpKernel<plat::CUDADeviceContext, int>,
ops::UnbindOpKernel<plat::CUDADeviceContext, plat::float16>,
ops::UnbindOpKernel<plat::CUDADeviceContext, plat::bfloat16>);
......@@ -34,27 +34,6 @@ static inline framework::DDim UnbindOutsDims(const framework::DDim in_dims,
}
return phi::make_ddim(out_dims);
}
-template <typename DeviceContext, typename T>
-class UnbindOpKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* in = ctx.Input<framework::Tensor>("X");
-    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
-    int axis = ctx.Attr<int>("axis");
-    auto in_dims = in->dims();
-    axis = axis < 0 ? in_dims.size() + axis : axis;
-    std::vector<const framework::Tensor*> shape_refer;
-    for (size_t j = 0; j < outs.size(); ++j) {
-      outs[j]->mutable_data<T>(ctx.GetPlace());
-      shape_refer.emplace_back(outs[j]);
-    }
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    math::SplitFunctor<DeviceContext, T> functor;
-    functor(dev_ctx, *in, shape_refer, axis, &outs);
-  }
-};
template <typename T>
class UnbindGradMaker : public framework::SingleGradOpMaker<T> {
......
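The kernel removed above reappears almost verbatim as phi::UnbindKernel in unbind_kernel_impl.h near the end of this diff; the only substantive change is how outputs are allocated. Side by side (an illustrative extract from the two hunks, not a new API):

    outs[j]->mutable_data<T>(ctx.GetPlace());  // fluid: allocate through the Tensor
    ctx.template Alloc<T>(outs[j]);            // phi: allocate through the device context

Both allocate output j's storage for dtype T on the current place.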
......@@ -485,6 +485,25 @@ void SplitInferMeta(const MetaTensor& x,
}
}
+void UnbindInferMeta(const MetaTensor& x,
+                     int axis,
+                     std::vector<MetaTensor>* outs) {
+  auto in_dims = x.dims();
+  std::vector<int> out_dim;
+  axis = axis < 0 ? in_dims.size() + axis : axis;
+  for (int i = 0; i < in_dims.size(); ++i) {
+    if (i != axis) out_dim.push_back(in_dims[i]);
+  }
+  auto out_dims = phi::make_ddim(out_dim);
+  for (size_t i = 0; i < outs->size(); ++i) {
+    (*outs)[i].set_dtype(x.dtype());
+    (*outs)[i].set_dims(out_dims);
+    (*outs)[i].set_layout(x.layout());
+    (*outs)[i].share_lod(x);
+  }
+}
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out) {
int dim1 = axis1;
......
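A worked example of the UnbindInferMeta shape rule added above (illustrative values, not from the commit):

    // x.dims() = [2, 3, 4], axis = -2  ->  axis wraps to 3 + (-2) = 1.
    // Dropping dimension 1 gives out_dims = [2, 4]; every MetaTensor in *outs
    // receives dims [2, 4] plus x's dtype, layout, and shared LoD.
    // outs->size() itself is fixed by the op's output list: one output per
    // slice along the unbound axis, i.e. 3 here.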
......@@ -90,6 +90,9 @@ void SplitInferMeta(const MetaTensor& x_meta,
std::vector<MetaTensor>* out,
MetaConfig config = MetaConfig());
+void UnbindInferMeta(const MetaTensor& x,
+                     int axis,
+                     std::vector<MetaTensor>* outs);
void TraceInferMeta(
const MetaTensor& x, int offset, int axis1, int axis2, MetaTensor* out);
......
......@@ -10,7 +10,7 @@ add_subdirectory(funcs)
set_property(GLOBAL PROPERTY PTEN_KERNELS "")
set(COMMON_KERNEL_DEPS dense_tensor sparse_coo_tensor sparse_csr_tensor kernel_context kernel_factory arg_map_context convert_utils lod_utils)
-set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col)
+set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} eigen_function blas math_function im2col concat_and_split_functor)
# remove this dep after removing fluid deps on tensor creation
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} pten_api_utils)
set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} infermeta)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template <typename T, typename Context>
void ConcatImpl(const Context& context,
const std::vector<DenseTensor>& input,
int axis,
DenseTensor* output) {
// TODO(zcd): Add input data validity checking
size_t num = input.size();
int64_t rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int64_t out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (size_t i = 0; i < num; ++i) {
int64_t t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto cpu_place = context.GetPlace();
// computation
auto output_data = output->data<T>();
int64_t col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int64_t col_len = input_cols[j];
auto input_data = input[j].data<T>();
for (int64_t k = 0; k < out_rows; ++k) {
paddle::memory::Copy(cpu_place,
output_data + k * out_cols + col_idx,
cpu_place,
input_data + k * col_len,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template <typename T, typename Context>
void SplitImpl(const Context& context,
const DenseTensor& input,
const std::vector<const DenseTensor*>& ref_inputs,
const int axis,
std::vector<DenseTensor*>* outputs) {
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if (input.numel() == 0) {
return;
}
// TODO(zcd): Add input data validity checking
size_t num = outputs->size();
int input_rows = 1;
auto dim_0 = ref_inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}
int input_cols = 0;
std::vector<int64_t> output_cols(outputs->size());
for (size_t i = 0; i < num; ++i) {
int t_cols = ref_inputs[i]->numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto cpu_place = context.GetPlace();
// computation
for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int col_len = output_cols[j];
auto* out_tensor = outputs->at(j);
if (out_tensor != nullptr) {
T* dst_ptr = out_tensor->data<T>() + k * col_len;
paddle::memory::Copy(cpu_place,
dst_ptr,
cpu_place,
src_ptr + col_idx,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
}
} // namespace phi
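Both loops in the deleted header above (and their verbatim successors in concat_and_split_functor.cc below) work by viewing every tensor as a logical [rows, cols] matrix, which is what lets a single memory::Copy move each row strip. A sketch of the flattening (illustrative dims):

    // dims = [d0, d1, d2], axis = 1:
    //   rows = d0                  // product of the dims before axis
    //   cols = numel / rows        // = d1 * d2
    // Concat copies, for each of out_rows rows, one contiguous strip of
    // col_len elements from each input into the output at column col_idx;
    // Split runs the same indexing in reverse, skipping null outputs.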
......@@ -22,7 +22,7 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/lod_utils.h"
#include "paddle/phi/kernels/cpu/concat_and_split.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
namespace phi {
......@@ -104,7 +104,8 @@ void ConcatKernel(const Context& dev_ctx,
continue;
}
}
-    ConcatImpl<T, Context>(dev_ctx, inputs, axis, out);
+    phi::funcs::ConcatFunctor<Context, T> functor;
+    functor(dev_ctx, inputs, axis, out);
}
}
......
......@@ -19,7 +19,7 @@
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/phi/kernels/cpu/concat_and_split.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace phi {
template <typename T, typename Context>
......@@ -54,7 +54,8 @@ void SplitKernel(const Context& dev_ctx,
paddle::operators::StridedMemcpyWithAxis0<T>(
dev_ctx, x, shape_refer, &outs);
} else {
-    SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
+    phi::funcs::SplitFunctor<Context, T> functor;
+    functor(dev_ctx, x, shape_refer, axis, &outs);
}
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unbind_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
PD_REGISTER_KERNEL(unbind,
CPU,
ALL_LAYOUT,
phi::UnbindKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
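The PD_REGISTER_KERNEL list above matches the six dtypes of the deleted REGISTER_OP_CPU_KERNEL registration one for one; only the spellings of the platform types change:

    // plat::float16  -> phi::dtype::float16
    // plat::bfloat16 -> phi::dtype::bfloat16
    // float, double, int, int64_t are unchanged.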
......@@ -3,3 +3,4 @@ add_subdirectory(blas)
add_subdirectory(lapack)
math_library(math_function DEPS blas dense_tensor tensor)
+math_library(concat_and_split_functor DEPS dense_tensor)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace phi {
namespace funcs {
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
struct ConcatFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& context,
const std::vector<phi::DenseTensor>& input,
int axis,
phi::DenseTensor* output) {
// TODO(zcd): Add input data validity checking
size_t num = input.size();
int64_t rows = 1;
auto dim_0 = input[0].dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int64_t out_rows = rows, out_cols = 0;
std::vector<int64_t> input_cols(input.size());
for (size_t i = 0; i < num; ++i) {
int64_t t_cols = input[i].numel() / rows;
out_cols += t_cols;
input_cols[i] = t_cols;
}
auto cpu_place = context.GetPlace();
// computation
auto output_data = output->data<T>();
int64_t col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int64_t col_len = input_cols[j];
auto input_data = input[j].data<T>();
for (int64_t k = 0; k < out_rows; ++k) {
paddle::memory::Copy(cpu_place,
output_data + k * out_cols + col_idx,
cpu_place,
input_data + k * col_len,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
};
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
*/
template <typename T>
struct SplitFunctor<phi::CPUContext, T> {
public:
void operator()(const phi::CPUContext& context,
const phi::DenseTensor& input,
const std::vector<const phi::DenseTensor*>& ref_inputs,
int axis,
std::vector<phi::DenseTensor*>* outputs) {
// NOTE(zhiqiu): split a tensor of shape [0,3,4] at axis=1, result in 3
// tensors of shape [0,1,4]
if (input.numel() == 0) {
return;
}
// TODO(zcd): Add input data validity checking
size_t num = outputs->size();
int input_rows = 1;
auto dim_0 = ref_inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
input_rows *= dim_0[i];
}
int input_cols = 0;
std::vector<int64_t> output_cols(outputs->size());
for (size_t i = 0; i < num; ++i) {
int t_cols = ref_inputs[i]->numel() / input_rows;
input_cols += t_cols;
output_cols[i] = t_cols;
}
auto cpu_place = context.GetPlace();
// computation
for (int k = 0; k < input_rows; ++k) {
const T* src_ptr = input.data<T>() + k * input_cols;
int col_idx = 0;
for (size_t j = 0; j < num; ++j) {
int col_len = output_cols[j];
auto* out_tensor = outputs->at(j);
if (out_tensor != nullptr) {
T* dst_ptr = out_tensor->data<T>() + k * col_len;
paddle::memory::Copy(cpu_place,
dst_ptr,
cpu_place,
src_ptr + col_idx,
sizeof(T) * col_len);
}
col_idx += col_len;
}
}
}
};
#define DEFINE_FUNCTOR(type) \
template class ConcatFunctor<phi::CPUContext, type>; \
template class SplitFunctor<phi::CPUContext, type>;
FOR_ALL_TYPES(DEFINE_FUNCTOR);
} // namespace funcs
} // namespace phi
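DEFINE_FUNCTOR together with FOR_ALL_TYPES is the explicit-instantiation idiom: the header declares the functor templates without bodies, and this translation unit instantiates them once per dtype, so kernels that include the header link against these definitions. For a single type the macro pair expands to (illustration):

    template class ConcatFunctor<phi::CPUContext, float>;
    template class SplitFunctor<phi::CPUContext, float>;

and likewise for the other eleven types listed in FOR_ALL_TYPES.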
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace funcs {
/*
* \brief Concatenate the input tensors along the dimension axis.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input[0] = [[1,2],[3,4]]
* Input[1] = [[5,6]]
* axis = 0
*
* Output = [[1,2],
* [3,4],
* [5,6]]
*/
template <typename Context, typename T>
struct ConcatFunctor {
void operator()(const Context& context,
const std::vector<phi::DenseTensor>& input,
int axis,
phi::DenseTensor* output);
};
/*
* \brief Split the input tensors along the dimension axis into outputs.
* TODO(zcd): maybe it needs to be more detailed.
* Examples:
* Input = [[1,2],
* [3,4],
* [5,6]]
* axis = 0
*
* Output[0] = [[1,2],[3,4]]
* Output[1] = [[5,6]]
*/
template <typename Context, typename T>
class SplitFunctor {
public:
void operator()(const Context& context,
const phi::DenseTensor& input,
const std::vector<const phi::DenseTensor*>& ref_inputs,
int axis,
std::vector<phi::DenseTensor*>* outputs);
};
} // namespace funcs
} // namespace phi
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(phi::dtype::float16); \
macro(phi::dtype::bfloat16); \
macro(phi::dtype::complex<float>); \
macro(phi::dtype::complex<double>);
......@@ -22,8 +22,8 @@
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/lod_utils.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"
#include "paddle/phi/kernels/gpu/concat_and_split.h"
namespace phi {
......@@ -104,7 +104,8 @@ void ConcatKernel(const Context& dev_ctx,
continue;
}
}
-    ConcatImpl<T, Context>(dev_ctx, inputs, axis, out);
+    phi::funcs::ConcatFunctor<Context, T> functor;
+    functor(dev_ctx, inputs, axis, out);
}
}
......
......@@ -18,7 +18,7 @@
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/gpu/concat_and_split.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
namespace phi {
template <typename T, typename Context>
......@@ -53,7 +53,8 @@ void SplitKernel(const Context& dev_ctx,
paddle::operators::StridedMemcpyWithAxis0<T>(
dev_ctx, x, shape_refer, &outs);
} else {
-    SplitImpl<T, Context>(dev_ctx, x, shape_refer, axis, &outs);
+    phi::funcs::SplitFunctor<Context, T> functor;
+    functor(dev_ctx, x, shape_refer, axis, &outs);
}
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/unbind_kernel_impl.h"
#include "paddle/phi/kernels/unbind_kernel.h"
PD_REGISTER_KERNEL(unbind,
GPU,
ALL_LAYOUT,
phi::UnbindKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/unbind_kernel.h"
namespace phi {
template <typename T, typename Context>
void UnbindKernel(const Context& ctx,
const DenseTensor& x,
int axis,
std::vector<DenseTensor*> outs) {
auto x_dims = x.dims();
axis = axis < 0 ? x_dims.size() + axis : axis;
std::vector<const DenseTensor*> shape_refer;
for (size_t j = 0; j < outs.size(); ++j) {
ctx.template Alloc<T>(outs[j]);
shape_refer.emplace_back(outs[j]);
}
phi::funcs::SplitFunctor<Context, T> functor;
functor(ctx, x, shape_refer, axis, &outs);
}
} // namespace phi
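UnbindKernel is thus a thin adapter over SplitFunctor: it wraps a negative axis, allocates each output, and uses the outputs themselves as the shape reference for the split. Illustrative semantics (example values, not from the commit):

    // x.dims() = [3, 4], axis = 0, outs.size() = 3
    //   -> each outs[k] is allocated with dims [4] (set earlier by
    //      UnbindInferMeta) and filled with row k of x.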
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/*
* All tensors' dimension should be the same and the values of
* each dimension must be the same, except the axis dimension.
*/
template <typename T, typename Context>
void UnbindKernel(const Context& ctx,
const DenseTensor& x,
int axis,
std::vector<DenseTensor*> outs);
} // namespace phi