Unverified commit b8236b7b authored by: H hong, committed by: GitHub

Move slice to phi (#40736)

* move slice to pten

* merge develop; test=develop

* fix slice bug;

* update

* update

* fix error

* update

* fix bug

* polish code

* polish code

* polish code

* try to fix windows bug

* add gpu compile flag;

* try to fix

* remove template;

* polish code;

* fix npu bug;

* fix npu bug

* fix npu bug; test=develop

* fix slice bug;

* remove unneeded dependency
Parent 0ad2e192
......@@ -81,6 +81,8 @@ PD_DECLARE_KERNEL(sum, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sum_grad, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(max_raw, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(sgd, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice, GPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(slice_grad, GPU, ALL_LAYOUT);
DECLARE_double(eager_delete_tensor_gb);
......
......@@ -42,9 +42,12 @@ void SetValueCompute(const framework::ExecutionContext& ctx,
auto dtype = framework::TransToProtoVarType(in->dtype());
auto in_dims = in->dims();
CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, starts, ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, *starts, *ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
phi::funcs::CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, starts, ends,
&steps);
auto slice_dims =
phi::funcs::GetSliceDims(in_dims, axes, *starts, *ends, &steps);
auto decrease_slice_dims =
phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
......@@ -282,10 +285,10 @@ void SliceCompute(const framework::ExecutionContext& ctx,
}
}
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims =
GetSliceDims<int64_t>(in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = GetDecreasedDims(slice_dims, decrease_axis);
phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims = phi::funcs::GetSliceDims<int64_t>(in_dims, axes, starts, ends,
nullptr, nullptr);
out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
......
......@@ -22,9 +22,11 @@
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/assign_value_op.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace paddle {
namespace operators {
......
......@@ -15,6 +15,8 @@ limitations under the License. */
#include "paddle/fluid/operators/set_value_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace paddle {
namespace operators {
......@@ -51,9 +53,11 @@ class SetValueNPUKernel : public framework::OpKernel<T> {
}
auto in_dims = in->dims();
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
auto slice_dims = GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims = GetDecreasedDims(slice_dims, decrease_axes);
phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends, &steps);
auto slice_dims =
phi::funcs::GetSliceDims(in_dims, axes, starts, ends, &steps);
auto decrease_slice_dims =
phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace paddle {
namespace operators {
......@@ -101,15 +102,17 @@ class SliceOp : public framework::OperatorWithKernel {
"The size of ends must be equal to the size of axes."));
}
CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends, nullptr,
&infer_flags);
phi::funcs::CheckAndUpdateSliceAttrs<int>(in_dims, axes, &starts, &ends,
nullptr, &infer_flags);
auto slice_dims =
GetSliceDims<int>(in_dims, axes, starts, ends, nullptr, &infer_flags);
auto slice_dims = phi::funcs::GetSliceDims<int>(in_dims, axes, starts, ends,
nullptr, &infer_flags);
if (ctx->IsRuntime()) {
out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, &infer_flags);
out_dims = phi::funcs::GetDecreasedDims<int>(slice_dims, decrease_axis,
&infer_flags);
} else {
out_dims = GetDecreasedDims<int>(slice_dims, decrease_axis, nullptr);
out_dims =
phi::funcs::GetDecreasedDims<int>(slice_dims, decrease_axis, nullptr);
}
ctx->SetOutputDim("Out", out_dims);
......
......@@ -18,7 +18,6 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/eigen/eigen_function.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -81,38 +80,6 @@ template <typename DeviceContext, typename T>
class SliceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
bool is_tensor_array = input_var->IsType<LoDTensorArray>();
int rank = is_tensor_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
break;
case 2:
SliceCompute<2>(ctx);
break;
case 3:
SliceCompute<3>(ctx);
break;
case 4:
SliceCompute<4>(ctx);
break;
case 5:
SliceCompute<5>(ctx);
break;
case 6:
SliceCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& ctx) const {
const Variable* input_var = ctx.InputVar("Input");
Variable* out_var = ctx.OutputVar("Out");
bool input_is_array = input_var->IsType<LoDTensorArray>();
......@@ -156,68 +123,6 @@ class SliceKernel : public framework::OpKernel<T> {
if (input_is_array) {
DealTensorArray(ctx, starts, ends, out_is_array);
return;
} else {
auto in = ctx.Input<Tensor>("Input");
auto out = ctx.Output<Tensor>("Out");
auto in_dims = in->dims();
auto out_dims = out->dims();
auto slice_dims = out_dims;
// 2.1 Infer output dims
for (size_t i = 0; i < axes.size(); ++i) {
// when start == -1 && end == start+1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret =
std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = in_dims[axes[i]];
}
}
}
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims =
GetSliceDims<int64_t>(in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = GetDecreasedDims(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = slice_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
offsets[axes[i]] = starts[i];
}
out->Resize(slice_dims);
out->mutable_data<T>(ctx.GetPlace());
auto in_t = framework::EigenTensor<T, D>::From(*in, in_dims);
auto out_t = framework::EigenTensor<T, D>::From(*out, slice_dims);
auto& eigen_place =
*ctx.template device_context<DeviceContext>().eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
// if element number less than INT_MAX, change the type of index to int
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, framework::To32BitIndex(out_t),
framework::To32BitIndex(in_t), offsets_32bit, extents_32bit);
} else {
EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
}
out->Resize(out_dims);
}
}
};
......@@ -226,38 +131,6 @@ template <typename DeviceContext, typename T>
class SliceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Variable* input_var = ctx.InputVar("Input");
bool is_array = input_var->IsType<LoDTensorArray>();
size_t rank = is_array ? 1 : ctx.Input<Tensor>("Input")->dims().size();
switch (rank) {
case 1:
SliceCompute<1>(ctx);
break;
case 2:
SliceCompute<2>(ctx);
break;
case 3:
SliceCompute<3>(ctx);
break;
case 4:
SliceCompute<4>(ctx);
break;
case 5:
SliceCompute<5>(ctx);
break;
case 6:
SliceCompute<6>(ctx);
break;
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
private:
template <size_t D>
void SliceCompute(const framework::ExecutionContext& ctx) const {
auto axes = ctx.Attr<std::vector<int>>("axes");
auto starts_int = ctx.Attr<std::vector<int>>("starts");
auto ends_int = ctx.Attr<std::vector<int>>("ends");
......@@ -323,226 +196,9 @@ class SliceGradKernel : public framework::OpKernel<T> {
}
return;
}
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* d_input = ctx.Output<Tensor>(framework::GradVarName("Input"));
d_input->mutable_data<T>(ctx.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease
std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = phi::make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = phi::make_ddim(origin_out_shape);
}
}
auto offsets = Eigen::array<int64_t, D>();
auto extents = Eigen::array<int64_t, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axis] = start;
}
Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
}
EigenPaddingCompute(ctx, d_input, in_dims, d_out, out_dims, paddings);
}
template <size_t D>
void EigenPaddingCompute(
const framework::ExecutionContext& context, Tensor* d_input,
const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
if (D <= 3) {
// if dimension less than 3, cannot reduce dimension
LaunchEigenPadding(context, d_input, in_dims, d_out, out_dims, paddings);
} else { // else we can reduce dimension
// count not-zero padding number, and record the dimension
int need_pad_num = 0, pad_dim = -1;
for (size_t i = 0; i < D; i++) {
if (paddings[i].first != 0 || paddings[i].second != 0) {
need_pad_num++;
pad_dim = i;
}
}
if (need_pad_num == 1) {
// only need padding one dimension, we can reduce dimension.
// only the padding dimension is available for us.
// How to reduce dimension(5 to 3 for example):
// before(D=5):
// in_dims: [x1, x2, x3, x4, x5]
// padding.first: [0, 0, a, 0, 0]
// padding.second: [0, 0, b, 0, 0]
// | |
// V V
// after(D=3):
// reshaped_in_dims: [x1*x2, x3, x4*x5]
// reshaped_padding.first: [0, a, 0]
// reshaped_padding.second: [0, b, 0]
if (pad_dim == D - 1) {
// only last dimension need padding,
// reshape the dimension of tensor in 2: [preceding, padding]
std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
// first dimension is the accumulate of preceding dimension
for (int i = 0; i < pad_dim; i++) {
in_tore_shape[0] *= in_dims[i];
out_tore_shape[0] *= out_dims[i];
}
// second dimension is the padding dimension
in_tore_shape[1] = in_dims[pad_dim];
out_tore_shape[1] = out_dims[pad_dim];
// convert array from std::vector to DDim
DDim reshaped_in_dims = phi::make_ddim(in_tore_shape);
DDim reshaped_out_dims = phi::make_ddim(out_tore_shape);
// after reshape: the first dimension do not need padding,
// set padding[0] zero
reshaped_padding[0].first = reshaped_padding[0].second = 0;
// the second dimension is the previous padding dimension
reshaped_padding[1].first = paddings[pad_dim].first;
reshaped_padding[1].second = paddings[pad_dim].second;
LaunchEigenPadding(context, d_input, reshaped_in_dims, d_out,
reshaped_out_dims, reshaped_padding);
} else if (pad_dim == 0) {
// only first dimension need padding,
// reshape the dimension of tensor in 2: [padding, succeeding]
// similar to (D - 1)
std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
// first dimension is the padding dimension
in_tore_shape[0] = in_dims[pad_dim];
out_tore_shape[0] = out_dims[pad_dim];
// sencond dimension is the accumulate of succeeding dimension
for (size_t i = pad_dim + 1; i < D; i++) {
in_tore_shape[1] *= in_dims[i];
out_tore_shape[1] *= out_dims[i];
}
// convert array from std::vector to DDim
DDim reshaped_in_dims = phi::make_ddim(in_tore_shape);
DDim reshaped_out_dims = phi::make_ddim(out_tore_shape);
// after reshape:
// the first dimension is the previous padding dimension
reshaped_padding[0].first = paddings[pad_dim].first;
reshaped_padding[0].second = paddings[pad_dim].second;
// the second dimension do not need padding, set padding[1] zero
reshaped_padding[1].first = reshaped_padding[1].second = 0;
LaunchEigenPadding(context, d_input, reshaped_in_dims, d_out,
reshaped_out_dims, reshaped_padding);
} else {
// other dimension need padding
// reshape the dimension of tensor in 3:
// [preceding, padding, succeeding]
std::vector<int64_t> in_tore_shape(3, 1), out_tore_shape(3, 1);
Eigen::array<std::pair<int64_t, int64_t>, 3> reshaped_padding;
// first dimension is the accumulate of preceding dimension
for (int i = 0; i < pad_dim; i++) {
in_tore_shape[0] *= in_dims[i];
out_tore_shape[0] *= out_dims[i];
}
// second dimension is the padding dimension
in_tore_shape[1] = in_dims[pad_dim];
out_tore_shape[1] = out_dims[pad_dim];
// third dimension is the accumulate of succeeding dimension
for (size_t i = pad_dim + 1; i < D; i++) {
in_tore_shape[2] *= in_dims[i];
out_tore_shape[2] *= out_dims[i];
}
// convert array from std::vector to DDim
DDim reshaped_in_dims = phi::make_ddim(in_tore_shape);
DDim reshaped_out_dims = phi::make_ddim(out_tore_shape);
// after reshape:
// the first dimension do not need padding, set padding[0] zero
reshaped_padding[0].first = reshaped_padding[2].second = 0;
// the second dimension is the previous padding dimension
reshaped_padding[1].first = paddings[pad_dim].first;
reshaped_padding[1].second = paddings[pad_dim].second;
// the third dimension do not need padding, set padding[2] zero
reshaped_padding[2].first = reshaped_padding[2].second = 0;
LaunchEigenPadding(context, d_input, reshaped_in_dims, d_out,
reshaped_out_dims, reshaped_padding);
}
} else {
// need padding at many dimension, cannot reduce dimension
LaunchEigenPadding(context, d_input, in_dims, d_out, out_dims,
paddings);
}
}
}
template <size_t D>
void LaunchEigenPadding(
const framework::ExecutionContext& context, Tensor* d_input,
const DDim& in_dims, const Tensor* d_out, const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) const {
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
auto d_in_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input, in_dims);
auto d_out_t =
framework::EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims);
if (d_input->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.pad:
// if element number less than INT_MAX, change the type of index to int
Eigen::array<std::pair<int, int>, D> paddings_32bit;
for (size_t i = 0; i < D; i++) {
paddings_32bit[i] =
std::make_pair(paddings[i].first, paddings[i].second);
}
EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place, framework::To32BitIndex(d_in_t),
framework::To32BitIndex(d_out_t), paddings_32bit, static_cast<T>(0));
} else {
EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place, d_in_t, d_out_t, paddings, static_cast<T>(0));
}
}
private:
};
} // namespace operators
} // namespace paddle
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/slice_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace paddle {
namespace operators {
......@@ -109,10 +111,10 @@ class SliceNPUKernel : public framework::OpKernel<T> {
}
}
CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims =
GetSliceDims<int>(in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = GetDecreasedDims(slice_dims, decrease_axis);
phi::funcs::CheckAndUpdateSliceAttrs(in_dims, axes, &starts, &ends);
slice_dims = phi::funcs::GetSliceDims<int>(in_dims, axes, starts, ends,
nullptr, nullptr);
out_dims = phi::funcs::GetDecreasedDims(slice_dims, decrease_axis);
out->Resize(out_dims);
}
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/slice_grad_kernel.h"
#include "paddle/phi/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(slice_grad,
CPU,
ALL_LAYOUT,
phi::SliceGradRawKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/impl/slice_kernel_impl.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(slice,
CPU,
ALL_LAYOUT,
phi::SliceRawKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
......@@ -163,11 +163,11 @@ struct EigenPad {
const InType& in,
const Array& padding,
const T value);
static void Eval(const EigenDevice& dev,
OutType32BitIndex out,
const InType32BitIndex& in,
const Array32Bit& padding,
const T value);
static void Eval32(const EigenDevice& dev,
OutType32BitIndex out,
const InType32BitIndex& in,
const Array32Bit& padding,
const T value);
};
template <typename EigenDevice, typename T>
......
......@@ -41,11 +41,11 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> {
out.device(dev) = in.pad(padding, value);
}
static void Eval(const Eigen::DefaultDevice& dev,
OutType32BitIndex out,
const InType32BitIndex& in,
const Array32Bit& padding,
const T value) {
static void Eval32(const Eigen::DefaultDevice& dev,
OutType32BitIndex out,
const InType32BitIndex& in,
const Array32Bit& padding,
const T value) {
out.device(dev) = in.pad(padding, value);
}
};
......@@ -56,7 +56,8 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> {
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 3>; \
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 4>; \
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 5>; \
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>
template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>;
INSTANTIATION(EigenPad, bool);
INSTANTIATION(EigenPad, int);
INSTANTIATION(EigenPad, int64_t);
......
......@@ -42,11 +42,11 @@ struct EigenPad<Eigen::GpuDevice, T, Rank> {
out.device(dev) = in.pad(padding, value);
}
static void Eval(const Eigen::GpuDevice& dev,
OutType32BitIndex out,
const InType32BitIndex& in,
const Array32Bit& padding,
const T value) {
static void Eval32(const Eigen::GpuDevice& dev,
OutType32BitIndex out,
const InType32BitIndex& in,
const Array32Bit& padding,
const T value) {
out.device(dev) = in.pad(padding, value);
}
};
......
......@@ -13,16 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/framework/operator.h>
#include <paddle/phi/core/ddim.h>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
namespace phi {
namespace funcs {
template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
const std::vector<T>& axes,
std::vector<T>* starts,
std::vector<T>* ends,
......@@ -31,11 +31,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument(
axis,
in_dims.size(),
phi::errors::InvalidArgument(
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.",
i, axis, in_dims.size()));
i,
axis,
in_dims.size()));
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
......@@ -46,8 +49,10 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
step,
0,
phi::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
start = std::max(start, static_cast<T>(0));
......@@ -60,11 +65,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GE(
end, start,
platform::errors::InvalidArgument(
end,
start,
phi::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
end,
start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
......@@ -72,11 +79,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GE(
start, end,
platform::errors::InvalidArgument(
start,
end,
phi::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
start,
end));
}
(*starts)[i] = start;
......@@ -89,13 +98,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
}
template <typename T = int64_t>
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
framework::DDim slice_dims(in_dims);
inline phi::DDim GetSliceDims(const phi::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
phi::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
......@@ -118,18 +127,19 @@ inline framework::DDim GetSliceDims(const framework::DDim in_dims,
}
template <typename T = int64_t>
inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
inline DDim GetDecreasedDims(const DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
DDim decreased_dims(slice_dims);
std::vector<uint8_t> decrease_flag(slice_dims.size(), 0);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
T axis = decrease_axes[i];
decrease_flag[axis] = 1;
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(decreased_dims[axis], 1,
platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(decreased_dims[axis],
1,
phi::errors::InvalidArgument(
"Decrease dim should be 1, but now received %d",
decreased_dims[axis]));
}
......@@ -153,5 +163,5 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
return decreased_dims;
}
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
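For intuition, here is a minimal standalone sketch (not the Paddle implementation) of the arithmetic that CheckAndUpdateSliceAttrs and GetSliceDims perform for the common step == 1 case; the shapes and slice bounds below are invented for illustration, and the real helpers additionally handle steps, infer_flags and error checking:
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>
int main() {
  std::vector<int64_t> in_dims{4, 5, 6};
  std::vector<int64_t> axes{1}, starts{-4}, ends{100};
  std::vector<int64_t> slice_dims = in_dims;
  for (size_t i = 0; i < axes.size(); ++i) {
    int64_t dim = in_dims[axes[i]];
    // negative indices count from the end; out-of-range values are clamped
    int64_t start = starts[i] < 0 ? starts[i] + dim : starts[i];
    int64_t end = ends[i] < 0 ? ends[i] + dim : ends[i];
    start = std::min(std::max<int64_t>(start, 0), dim);
    end = std::min(std::max<int64_t>(end, 0), dim);
    slice_dims[axes[i]] = end - start;  // step == 1 only
  }
  std::cout << slice_dims[0] << " " << slice_dims[1] << " " << slice_dims[2] << "\n";  // prints: 4 4 6
  return 0;
}
GetDecreasedDims then removes from the resulting shape every axis listed in decrease_axes, after checking that its extent is 1.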
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/slice_grad_kernel.h"
#include "paddle/phi/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(slice_grad,
GPU,
ALL_LAYOUT,
phi::SliceGradRawKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/impl/slice_kernel_impl.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(slice,
GPU,
ALL_LAYOUT,
phi::SliceRawKernel,
bool,
int,
int64_t,
float,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -24,8 +24,7 @@
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
......@@ -85,12 +84,12 @@ void SetValueImpl(const Context& dev_ctx,
std::vector<int64_t> starts_local = starts.GetData();
std::vector<int64_t> ends_local = ends.GetData();
std::vector<int64_t> steps_local = steps.GetData();
paddle::operators::CheckAndUpdateSliceAttrs(
phi::funcs::CheckAndUpdateSliceAttrs(
in_dims, axes, &starts_local, &ends_local, &steps_local);
auto slice_dims = paddle::operators::GetSliceDims(
auto slice_dims = phi::funcs::GetSliceDims(
in_dims, axes, starts_local, ends_local, &steps_local);
auto decrease_slice_dims =
paddle::operators::GetDecreasedDims(slice_dims, decrease_axes);
phi::funcs::GetDecreasedDims(slice_dims, decrease_axes);
auto slice_dims_for_assign = decrease_slice_dims;
if (!none_axes.empty()) {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/slice_grad_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
template <typename T, typename Context, size_t D>
void LaunchEigenPadding(
const Context& context,
DenseTensor* d_input,
const DDim& in_dims,
const DenseTensor* d_out,
const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
auto& place = *context.eigen_device();
auto d_in_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input, in_dims);
auto d_out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims);
if (d_input->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.pad:
// if element number less than INT_MAX, change the type of index to int
Eigen::array<std::pair<int, int>, D> paddings_32bit;
for (size_t i = 0; i < D; i++) {
paddings_32bit[i] = std::make_pair(paddings[i].first, paddings[i].second);
}
funcs::EigenPad<std::decay_t<decltype(place)>, T, D>::Eval32(
place,
To32BitIndex(d_in_t),
To32BitIndex(d_out_t),
paddings_32bit,
static_cast<T>(0));
} else {
funcs::EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place, d_in_t, d_out_t, paddings, static_cast<T>(0));
}
}
template <typename T, typename Context, size_t D>
void EigenPaddingCompute(
const Context& context,
DenseTensor* d_input,
const DDim& in_dims,
const DenseTensor* d_out,
const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
if (D <= 3) {
// if dimension less than 3, cannot reduce dimension
LaunchEigenPadding<T, Context, D>(
context, d_input, in_dims, d_out, out_dims, paddings);
} else { // else we can reduce dimension
// count not-zero padding number, and record the dimension
int need_pad_num = 0, pad_dim = -1;
for (size_t i = 0; i < D; i++) {
if (paddings[i].first != 0 || paddings[i].second != 0) {
need_pad_num++;
pad_dim = i;
}
}
if (need_pad_num == 1) {
// only need padding one dimension, we can reduce dimension.
// only the padding dimension is available for us.
// How to reduce dimension(5 to 3 for example):
// before(D=5):
// in_dims: [x1, x2, x3, x4, x5]
// padding.first: [0, 0, a, 0, 0]
// padding.second: [0, 0, b, 0, 0]
// | |
// V V
// after(D=3):
// reshaped_in_dims: [x1*x2, x3, x4*x5]
// reshaped_padding.first: [0, a, 0]
// reshaped_padding.second: [0, b, 0]
if (pad_dim == D - 1) {
// only last dimension need padding,
// reshape the dimension of tensor in 2: [preceding, padding]
std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
// first dimension is the accumulate of preceding dimension
for (int i = 0; i < pad_dim; i++) {
in_tore_shape[0] *= in_dims[i];
out_tore_shape[0] *= out_dims[i];
}
// second dimension is the padding dimension
in_tore_shape[1] = in_dims[pad_dim];
out_tore_shape[1] = out_dims[pad_dim];
// convert array from std::vector to DDim
DDim reshaped_in_dims = make_ddim(in_tore_shape);
DDim reshaped_out_dims = make_ddim(out_tore_shape);
// after reshape: the first dimension do not need padding,
// set padding[0] zero
reshaped_padding[0].first = reshaped_padding[0].second = 0;
// the second dimension is the previous padding dimension
reshaped_padding[1].first = paddings[pad_dim].first;
reshaped_padding[1].second = paddings[pad_dim].second;
LaunchEigenPadding<T, Context>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
} else if (pad_dim == 0) {
// only first dimension need padding,
// reshape the dimension of tensor in 2: [padding, succeeding]
// similar to (D - 1)
std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
// first dimension is the padding dimension
in_tore_shape[0] = in_dims[pad_dim];
out_tore_shape[0] = out_dims[pad_dim];
// second dimension is the accumulate of succeeding dimension
for (size_t i = pad_dim + 1; i < D; i++) {
in_tore_shape[1] *= in_dims[i];
out_tore_shape[1] *= out_dims[i];
}
// convert array from std::vector to DDim
DDim reshaped_in_dims = make_ddim(in_tore_shape);
DDim reshaped_out_dims = make_ddim(out_tore_shape);
// after reshape:
// the first dimension is the previous padding dimension
reshaped_padding[0].first = paddings[pad_dim].first;
reshaped_padding[0].second = paddings[pad_dim].second;
// the second dimension do not need padding, set padding[1] zero
reshaped_padding[1].first = reshaped_padding[1].second = 0;
LaunchEigenPadding<T, Context, 2>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
} else {
// other dimension need padding
// reshape the dimension of tensor in 3:
// [preceding, padding, succeeding]
std::vector<int64_t> in_tore_shape(3, 1), out_tore_shape(3, 1);
Eigen::array<std::pair<int64_t, int64_t>, 3> reshaped_padding;
// first dimension is the accumulate of preceding dimension
for (int i = 0; i < pad_dim; i++) {
in_tore_shape[0] *= in_dims[i];
out_tore_shape[0] *= out_dims[i];
}
// second dimension is the padding dimension
in_tore_shape[1] = in_dims[pad_dim];
out_tore_shape[1] = out_dims[pad_dim];
// third dimension is the accumulate of succeeding dimension
for (size_t i = pad_dim + 1; i < D; i++) {
in_tore_shape[2] *= in_dims[i];
out_tore_shape[2] *= out_dims[i];
}
// convert array from std::vector to DDim
DDim reshaped_in_dims = make_ddim(in_tore_shape);
DDim reshaped_out_dims = make_ddim(out_tore_shape);
// after reshape:
// the first dimension do not need padding, set padding[0] zero
reshaped_padding[0].first = reshaped_padding[2].second = 0;
// the second dimension is the previous padding dimension
reshaped_padding[1].first = paddings[pad_dim].first;
reshaped_padding[1].second = paddings[pad_dim].second;
// the third dimension do not need padding, set padding[2] zero
reshaped_padding[2].first = reshaped_padding[2].second = 0;
LaunchEigenPadding<T, Context, 3>(context,
d_input,
reshaped_in_dims,
d_out,
reshaped_out_dims,
reshaped_padding);
}
} else {
// need padding at many dimension, cannot reduce dimension
LaunchEigenPadding<T, Context>(
context, d_input, in_dims, d_out, out_dims, paddings);
}
}
}
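As a worked illustration of the shape folding in the need_pad_num == 1 branch above, here is a hypothetical standalone sketch (made-up dims, not Paddle code) showing how a pad on a single axis of a 5-D gradient becomes a 3-D pad over [preceding, padding, succeeding]:
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>
int main() {
  // 5-D gradient, padding only on axis 2 (values are made up).
  std::vector<int64_t> in_dims{2, 3, 7, 4, 5};
  std::pair<int64_t, int64_t> pad_on_axis2{1, 2};
  int pad_dim = 2;
  // Fold into [preceding, padding, succeeding] = [x1*x2, x3, x4*x5].
  std::vector<int64_t> folded{1, in_dims[pad_dim], 1};
  for (int i = 0; i < pad_dim; ++i) folded[0] *= in_dims[i];
  for (size_t i = pad_dim + 1; i < in_dims.size(); ++i) folded[2] *= in_dims[i];
  std::cout << folded[0] << " " << folded[1] << " " << folded[2] << "\n";  // prints: 6 7 20
  std::cout << "pad on folded axis 1: {" << pad_on_axis2.first << ", "
            << pad_on_axis2.second << "}\n";
  // Padding {1, 2} on the middle axis of the folded 3-D view touches exactly the
  // same elements as padding axis 2 of the original 5-D view, which is why the
  // Eigen pad can run on the lower-rank tensor.
  return 0;
}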
template <typename T, typename Context, size_t D>
void SliceGradCompute(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad) {
auto* d_out = &out_grad;
auto* d_input = input_grad;
ctx.template Alloc<T>(d_input);
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease
std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = make_ddim(origin_out_shape);
}
}
auto offsets = Eigen::array<int64_t, D>();
auto extents = Eigen::array<int64_t, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axis] = start;
}
Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
}
EigenPaddingCompute<T, Context, D>(
ctx, d_input, in_dims, d_out, out_dims, paddings);
}
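For intuition about SliceGradCompute, a standalone 1-D sketch (made-up sizes, not the Paddle kernel) of why the slice gradient is just out_grad zero-padded back to the input shape, using the same offset/padding arithmetic as above:
#include <cstdint>
#include <iostream>
#include <vector>
int main() {
  // Forward: y = x[1:4] on a length-6 input, so d_x is d_y padded with zeros.
  const int64_t in_len = 6, start = 1, end = 4;
  std::vector<double> d_out{10.0, 20.0, 30.0};      // gradient of the slice, length end - start
  std::vector<double> d_in(in_len, 0.0);
  const int64_t pad_before = start;                                 // paddings[axis].first
  const int64_t pad_after = in_len - (end - start) - pad_before;    // paddings[axis].second
  for (size_t i = 0; i < d_out.size(); ++i) d_in[pad_before + i] = d_out[i];
  std::cout << "pad = {" << pad_before << ", " << pad_after << "}\n";  // prints: pad = {1, 2}
  for (double v : d_in) std::cout << v << " ";      // prints: 0 10 20 30 0 0
  std::cout << "\n";
  return 0;
}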
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const ScalarArray& starts_arr,
const ScalarArray& ends_arr,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad) {
size_t rank = input.dims().size();
auto& starts = starts_arr.GetData();
auto& ends = ends_arr.GetData();
switch (rank) {
case 1:
SliceGradCompute<T, Context, 1>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 2:
SliceGradCompute<T, Context, 2>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 3:
SliceGradCompute<T, Context, 3>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 4:
SliceGradCompute<T, Context, 4>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 5:
SliceGradCompute<T, Context, 5>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 6:
SliceGradCompute<T, Context, 6>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace phi {
template <typename T, typename Context, size_t D>
void SliceCompute(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts_t,
const std::vector<int64_t>& ends_t,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out) {
// Step 1: Get the accurate attribute value of starts and ends
std::vector<int64_t> starts = starts_t;
std::vector<int64_t> ends = ends_t;
PADDLE_ENFORCE_EQ(
starts.size(),
axes.size(),
phi::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(ends.size(),
axes.size(),
phi::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
// Step 2: Compute output
auto in = &input;
auto in_dims = in->dims();
auto out_dims = out->dims();
auto slice_dims = out_dims;
// 2.1 Infer output dims
for (size_t i = 0; i < axes.size(); ++i) {
// when start == -1 && end == start+1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = in_dims[axes[i]];
}
}
}
funcs::CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends);
slice_dims = funcs::GetSliceDims<int64_t>(
in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = funcs::GetDecreasedDims<int64_t>(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = slice_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
offsets[axes[i]] = starts[i];
}
out->Resize(slice_dims);
ctx.template Alloc<T>(out);
auto in_t = EigenTensor<T, D>::From(*in, in_dims);
auto out_t = EigenTensor<T, D>::From(*out, slice_dims);
auto& eigen_place = *ctx.eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
// if element number less than INT_MAX, change the type of index to int
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place,
To32BitIndex(out_t),
To32BitIndex(in_t),
offsets_32bit,
extents_32bit);
} else {
funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
}
out->Resize(out_dims);
}
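For intuition, a standalone 2-D sketch (made-up shapes, plain loops instead of the Eigen path used above) of the copy that the offsets/extents in SliceCompute describe:
#include <cstddef>
#include <iostream>
#include <vector>
int main() {
  // Row-major [2 x 4] input holding 0..7; slice a window of extents [2 x 3]
  // starting at offsets (0, 1), i.e. columns 1..3 of every row.
  const size_t rows = 2, cols = 4;
  std::vector<int> in(rows * cols);
  for (size_t i = 0; i < in.size(); ++i) in[i] = static_cast<int>(i);
  const size_t offset_r = 0, offset_c = 1;
  const size_t extent_r = 2, extent_c = 3;
  std::vector<int> out(extent_r * extent_c);
  for (size_t r = 0; r < extent_r; ++r)
    for (size_t c = 0; c < extent_c; ++c)
      out[r * extent_c + c] = in[(offset_r + r) * cols + (offset_c + c)];
  for (int v : out) std::cout << v << " ";  // prints: 1 2 3 5 6 7
  std::cout << "\n";
  return 0;
}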
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const ScalarArray& starts_arr,
const ScalarArray& ends_arr,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out) {
int rank = input.dims().size();
auto& starts = starts_arr.GetData();
auto& ends = ends_arr.GetData();
switch (rank) {
case 1:
SliceCompute<T, Context, 1>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 2:
SliceCompute<T, Context, 2>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 3:
SliceCompute<T, Context, 3>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 4:
SliceCompute<T, Context, 4>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 5:
SliceCompute<T, Context, 5>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 6:
SliceCompute<T, Context, 6>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
default:
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
const DenseTensor& input,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {
// if input is Tensor Array
if (ctx.IsDenseTensorVectorInput("Input")) {
return KernelSignature("unregistered", {}, {}, {});
}
if (ctx.HasInput("StartsTensor")) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensor",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{"Out"});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensor",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{"Out"});
} else {
return KernelSignature(
"slice",
{"Input"},
{"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
} else if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensorList",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{"Out"});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice",
{"Input"},
{"axes",
"StartsTensorList",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{"Out"});
} else {
return KernelSignature(
"slice",
{"Input"},
{"axes", "StartsTensorList", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
} else {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"},
{"Out"});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"},
{"Out"});
} else {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
}
}
KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorVectorInput("Input")) {
return KernelSignature("unregistered", {}, {}, {});
}
if (ctx.HasInput("StartsTensor")) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensor",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensor",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "StartsTensor", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
} else if (ctx.InputSize("StartsTensorList") > 0) {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensorList",
"EndsTensor",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature("slice_grad",
{"Input", GradVarName("Out")},
{"axes",
"StartsTensorList",
"EndsTensorList",
"infer_flags",
"decrease_axis"},
{GradVarName("Input")});
} else {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "StartsTensorList", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
} else {
if (ctx.HasInput("EndsTensor")) {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensor", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
} else if (ctx.InputSize("EndsTensorList") > 0) {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "starts", "EndsTensorList", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
} else {
return KernelSignature(
"slice_grad",
{"Input", GradVarName("Out")},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(slice, phi::SliceOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(slice_grad, phi::SliceGradOpArgumentMapping);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar_array.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const ScalarArray& starts,
const ScalarArray& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
} // namespace phi
......@@ -796,4 +796,5 @@ class TestImperativeCUDAPinnedInput(unittest.TestCase):
if __name__ == '__main__':
paddle.enable_static()
unittest.main()