Commit 13f44099 authored by P phlrain

move slice to pten

Parent 5bb3b668
This diff is collapsed.
...@@ -13,13 +13,12 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
-#include <paddle/fluid/framework/operator.h>
+#include <paddle/fluid/framework/dim.h>
 #include <string>
 #include <vector>
 namespace paddle {
 namespace operators {
-using Tensor = framework::Tensor;
 template <typename T = int64_t>
 inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/slice_grad_kernel.h"
#include "paddle/pten/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice_grad,
CPU,
ALL_LAYOUT,
pten::SliceGradRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16,
pten::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/slice_kernel.h"
#include "paddle/pten/kernels/impl/slice_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice,
CPU,
ALL_LAYOUT,
pten::SliceRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16) {}
...@@ -56,7 +56,8 @@ struct EigenPad<Eigen::DefaultDevice, T, Rank> {
   template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 3>; \
   template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 4>; \
   template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 5>; \
-  template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>
+  template struct FUNCTOR<Eigen::DefaultDevice, TYPE, 6>;
 INSTANTIATION(EigenPad, bool);
 INSTANTIATION(EigenPad, int);
 INSTANTIATION(EigenPad, int64_t);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice_grad,
GPU,
ALL_LAYOUT,
pten::SliceGradRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16,
pten::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/impl/slice_kernel_impl.h"
#include "paddle/pten/kernels/slice_kernel.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice,
GPU,
ALL_LAYOUT,
pten::SliceRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
namespace pten {
template <typename T, typename Context, size_t D>
void LaunchEigenPadding(
const Context& context,
DenseTensor* d_input,
const DDim& in_dims,
const DenseTensor* d_out,
const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
auto& place = *context.eigen_device();
auto d_in_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_input, in_dims);
auto d_out_t = EigenTensor<T, D, Eigen::RowMajor, Eigen::DenseIndex>::From(
*d_out, out_dims);
if (d_input->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.pad:
// if the element count is less than INT_MAX, use 32-bit indices
Eigen::array<std::pair<int, int>, D> paddings_32bit;
for (size_t i = 0; i < D; i++) {
paddings_32bit[i] = std::make_pair(paddings[i].first, paddings[i].second);
}
funcs::EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place,
To32BitIndex(d_in_t),
To32BitIndex(d_out_t),
paddings_32bit,
static_cast<T>(0));
} else {
funcs::EigenPad<std::decay_t<decltype(place)>, T, D>::Eval(
place, d_in_t, d_out_t, paddings, static_cast<T>(0));
}
}
template <typename T, typename Context, size_t D>
void EigenPaddingCompute(
const Context& context,
DenseTensor* d_input,
const DDim& in_dims,
const DenseTensor* d_out,
const DDim& out_dims,
const Eigen::array<std::pair<int64_t, int64_t>, D>& paddings) {
if (D <= 3) {
// rank <= 3: nothing to fold, pad directly
LaunchEigenPadding<T, Context, D>(
context, d_input, in_dims, d_out, out_dims, paddings);
} else {
// The dimension-reduction fast path below is still commented out; pad at
// the original rank so D > 3 is not silently skipped (sketch follows below).
LaunchEigenPadding<T, Context, D>(
context, d_input, in_dims, d_out, out_dims, paddings);
}
// } else { // else we can reduce dimension
// // count not-zero padding number, and record the dimension
// int need_pad_num = 0, pad_dim = -1;
// for (size_t i = 0; i < D; i++) {
// if (paddings[i].first != 0 || paddings[i].second != 0) {
// need_pad_num++;
// pad_dim = i;
// }
// }
// if (need_pad_num == 1) {
// // only need padding one dimension, we can reduce dimension.
// // only the padding dimension is available for us.
// // How to reduce dimension(5 to 3 for example):
// // before(D=5):
// // in_dims: [x1, x2, x3, x4, x5]
// // padding.first: [0, 0, a, 0, 0]
// // padding.second: [0, 0, b, 0, 0]
// // | |
// // V V
// // after(D=3):
// // reshaped_in_dims: [x1*x2, x3, x4*x5]
// // reshaped_padding.first: [0, a, 0]
// // reshaped_padding.second: [0, b, 0]
// if (pad_dim == D - 1) {
// // only last dimension need padding,
// // reshape the dimension of tensor in 2: [preceding, padding]
// std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
// Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
// // first dimension is the accumulate of preceding dimension
// for (int i = 0; i < pad_dim; i++) {
// in_tore_shape[0] *= in_dims[i];
// out_tore_shape[0] *= out_dims[i];
// }
// // second dimension is the padding dimension
// in_tore_shape[1] = in_dims[pad_dim];
// out_tore_shape[1] = out_dims[pad_dim];
// // convert array from std::vector to DDim
// DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
// DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// // after reshape: the first dimension do not need padding,
// // set padding[0] zero
// reshaped_padding[0].first = reshaped_padding[0].second = 0;
// // the second dimension is the previous padding dimension
// reshaped_padding[1].first = paddings[pad_dim].first;
// reshaped_padding[1].second = paddings[pad_dim].second;
// LaunchEigenPadding<T, Context, D>(context, d_input, reshaped_in_dims,
// d_out,
// reshaped_out_dims, reshaped_padding);
// } else if (pad_dim == 0) {
// // only first dimension need padding,
// // reshape the dimension of tensor in 2: [padding, succeeding]
// // similar to (D - 1)
// std::vector<int64_t> in_tore_shape(2, 1), out_tore_shape(2, 1);
// Eigen::array<std::pair<int64_t, int64_t>, 2> reshaped_padding;
// // first dimension is the padding dimension
// in_tore_shape[0] = in_dims[pad_dim];
// out_tore_shape[0] = out_dims[pad_dim];
// // second dimension is the accumulate of succeeding dimension
// for (size_t i = pad_dim + 1; i < D; i++) {
// in_tore_shape[1] *= in_dims[i];
// out_tore_shape[1] *= out_dims[i];
// }
// // convert array from std::vector to DDim
// DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
// DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// // after reshape:
// // the first dimension is the previous padding dimension
// reshaped_padding[0].first = paddings[pad_dim].first;
// reshaped_padding[0].second = paddings[pad_dim].second;
// // the second dimension do not need padding, set padding[1] zero
// reshaped_padding[1].first = reshaped_padding[1].second = 0;
// LaunchEigenPadding<T, Context, D>(context, d_input, reshaped_in_dims,
// d_out,
// reshaped_out_dims, reshaped_padding);
// } else {
// // other dimension need padding
// // reshape the dimension of tensor in 3:
// // [preceding, padding, succeeding]
// std::vector<int64_t> in_tore_shape(3, 1), out_tore_shape(3, 1);
// Eigen::array<std::pair<int64_t, int64_t>, 3> reshaped_padding;
// // first dimension is the accumulate of preceding dimension
// for (int i = 0; i < pad_dim; i++) {
// in_tore_shape[0] *= in_dims[i];
// out_tore_shape[0] *= out_dims[i];
// }
// // second dimension is the padding dimension
// in_tore_shape[1] = in_dims[pad_dim];
// out_tore_shape[1] = out_dims[pad_dim];
// // third dimension is the accumulate of succeeding dimension
// for (size_t i = pad_dim + 1; i < D; i++) {
// in_tore_shape[2] *= in_dims[i];
// out_tore_shape[2] *= out_dims[i];
// }
// // convert array from std::vector to DDim
// DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
// DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// // after reshape:
// // the first dimension do not need padding, set padding[0] zero
// reshaped_padding[0].first = reshaped_padding[0].second = 0;
// // the second dimension is the previous padding dimension
// reshaped_padding[1].first = paddings[pad_dim].first;
// reshaped_padding[1].second = paddings[pad_dim].second;
// // the third dimension do not need padding, set padding[2] zero
// reshaped_padding[2].first = reshaped_padding[2].second = 0;
// LaunchEigenPadding<T, Context, D>(context, d_input, reshaped_in_dims,
// d_out,
// reshaped_out_dims, reshaped_padding);
// }
// } else {
// // need padding at many dimension, cannot reduce dimension
// LaunchEigenPadding<T, Context, D>(context, d_input, in_dims, d_out,
// out_dims,
// paddings);
// }
// }
}
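
The reshape trick described in the disabled branch above can be checked in isolation. Below is a minimal, standalone C++ sketch (an editor's illustration, not part of the kernel and not a Paddle API) that folds the axes before and after a single padded axis of a rank-5 shape into a rank-3 shape, matching the [x1, x2, x3, x4, x5] -> [x1*x2, x3, x4*x5] diagram in the comments; all names and values in it are hypothetical.

// Standalone illustration (not Paddle code): fold a rank-5 pad with a single
// padded axis into an equivalent rank-3 pad, as the disabled branch proposes.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // in_dims: [x1, x2, x3, x4, x5]; only axis 2 (x3) is padded.
  std::vector<int64_t> in_dims = {2, 3, 4, 5, 6};
  std::vector<std::pair<int64_t, int64_t>> paddings = {
      {0, 0}, {0, 0}, {1, 2}, {0, 0}, {0, 0}};
  const int pad_dim = 2;

  // Fold the axes before the padded axis and the axes after it.
  int64_t pre = 1, post = 1;
  for (int i = 0; i < pad_dim; ++i) pre *= in_dims[i];
  for (size_t i = pad_dim + 1; i < in_dims.size(); ++i) post *= in_dims[i];

  // reshaped_in_dims: [x1*x2, x3, x4*x5]; only the middle axis keeps padding.
  std::vector<int64_t> reshaped_in_dims = {pre, in_dims[pad_dim], post};
  std::vector<std::pair<int64_t, int64_t>> reshaped_padding = {
      {0, 0}, paddings[pad_dim], {0, 0}};

  std::cout << "reshaped_in_dims: [" << reshaped_in_dims[0] << ", "
            << reshaped_in_dims[1] << ", " << reshaped_in_dims[2] << "]\n";
  std::cout << "reshaped_padding[1]: (" << reshaped_padding[1].first << ", "
            << reshaped_padding[1].second << ")\n";
  return 0;
}
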
template <typename T, typename Context, size_t D>
void SliceGradCompute(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad) {
auto* d_out = &out_grad;
auto* d_input = input_grad;
d_input->mutable_data<T>(ctx.GetPlace());
auto out_dims = d_out->dims();
auto in_dims = d_input->dims();
auto decrease_size = decrease_axis.size();
if (decrease_size > 0) {
if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease
std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
origin_out_shape[decrease_axis[i]] = 1;
}
int index = 0;
for (size_t i = 0; i < origin_out_shape.size(); ++i) {
if (origin_out_shape[i] == -1) {
origin_out_shape[i] = out_dims[index];
++index;
}
}
out_dims = framework::make_ddim(origin_out_shape);
}
}
auto offsets = Eigen::array<int64_t, D>();
auto extents = Eigen::array<int64_t, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = out_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
int axis = axes[i];
int64_t start = starts[i] < 0 ? (starts[i] + in_dims[axis]) : starts[i];
start = std::max(start, static_cast<int64_t>(0));
offsets[axis] = start;
}
Eigen::array<std::pair<int64_t, int64_t>, D> paddings;
for (size_t i = 0; i < paddings.size(); ++i) {
paddings[i].first = offsets[i];
paddings[i].second = (in_dims[i] - out_dims[i]) - offsets[i];
}
EigenPaddingCompute<T, Context, D>(
ctx, d_input, in_dims, d_out, out_dims, paddings);
}
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad) {
size_t rank = out_grad.dims().size();
switch (rank) {
case 1:
SliceGradCompute<T, Context, 1>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 2:
SliceGradCompute<T, Context, 2>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 3:
SliceGradCompute<T, Context, 3>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 4:
SliceGradCompute<T, Context, 4>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 5:
SliceGradCompute<T, Context, 5>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
case 6:
SliceGradCompute<T, Context, 6>(ctx,
out_grad,
axes,
starts,
ends,
infer_flags,
decrease_axis,
input_grad);
break;
default:
PADDLE_THROW(pten::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
namespace pten {
template <typename T, typename Context, size_t D>
void SliceCompute(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts_t,
const std::vector<int64_t>& ends_t,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out) {
// Step 1: Get the accurate attribute value of starts and ends
std::vector<int64_t> starts = starts_t;
std::vector<int64_t> ends = ends_t;
PADDLE_ENFORCE_EQ(
starts.size(),
axes.size(),
pten::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(ends.size(),
axes.size(),
pten::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
// Step 2: Compute output
auto in = &input;
auto in_dims = in->dims();
auto out_dims = out->dims();
auto slice_dims = out_dims;
// 2.1 Infer output dims
for (size_t i = 0; i < axes.size(); ++i) {
// when start == -1 && end == start+1
if (starts[i] == -1 && ends[i] == 0 && infer_flags[i] == -1) {
auto ret = std::find(decrease_axis.begin(), decrease_axis.end(), axes[i]);
if (ret != decrease_axis.end()) {
ends[i] = in_dims[axes[i]];
}
}
}
paddle::operators::CheckAndUpdateSliceAttrs<int64_t>(
in_dims, axes, &starts, &ends);
slice_dims = paddle::operators::GetSliceDims<int64_t>(
in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = paddle::operators::GetDecreasedDims(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
auto extents = Eigen::DSizes<Eigen::DenseIndex, D>();
for (size_t i = 0; i < D; ++i) {
offsets[i] = 0;
extents[i] = slice_dims[i];
}
for (size_t i = 0; i < axes.size(); ++i) {
offsets[axes[i]] = starts[i];
}
out->Resize(slice_dims);
out->mutable_data<T>(ctx.GetPlace());
auto in_t = EigenTensor<T, D>::From(*in, in_dims);
auto out_t = EigenTensor<T, D>::From(*out, slice_dims);
auto& eigen_place = *ctx.eigen_device();
if (in->numel() <= Eigen::NumTraits<int>::highest()) {
// similar to tf.slice:
// if the element count is less than INT_MAX, use 32-bit indices
Eigen::DSizes<int, D> offsets_32bit, extents_32bit;
for (size_t i = 0; i < D; i++) {
offsets_32bit[i] = offsets[i];
extents_32bit[i] = extents[i];
}
funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place,
To32BitIndex(out_t),
To32BitIndex(in_t),
offsets_32bit,
extents_32bit);
} else {
funcs::EigenSlice<std::decay_t<decltype(eigen_place)>, T, D>::Eval(
eigen_place, out_t, in_t, offsets, extents);
}
out->Resize(out_dims);
}
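
For reference, the per-axis start/end normalization delegated to paddle::operators::CheckAndUpdateSliceAttrs above can be pictured with a small standalone sketch. This is an editor's illustration assuming the usual negative-index wrapping and clamping semantics, not the helper's actual implementation; the axis length and slice values below are made up.

// Standalone illustration (assumed semantics, not the real helper): wrap
// negative slice indices by the axis length and clamp them into range.
#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  const int64_t dim = 10;  // length of the sliced axis
  int64_t start = -3;      // value from the `starts` attribute
  int64_t end = 1000;      // `ends` may exceed the axis length

  if (start < 0) start += dim;
  if (end < 0) end += dim;
  start = std::min(std::max<int64_t>(start, 0), dim);
  end = std::min(std::max<int64_t>(end, 0), dim);

  // The sliced extent along this axis.
  const int64_t extent = std::max<int64_t>(end - start, 0);
  std::cout << "start=" << start << " end=" << end
            << " extent=" << extent << "\n";  // start=7 end=10 extent=3
  return 0;
}
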
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out) {
int rank = input.dims().size();
switch (rank) {
case 1:
SliceCompute<T, Context, 1>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 2:
SliceCompute<T, Context, 2>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 3:
SliceCompute<T, Context, 3>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 4:
SliceCompute<T, Context, 4>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 5:
SliceCompute<T, Context, 5>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
case 6:
SliceCompute<T, Context, 6>(
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
default:
PADDLE_THROW(pten::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
const DenseTensor& out_grad,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pten/core/dense_tensor.h"
namespace pten {
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
} // namespace pten
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
namespace pten {
KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"slice",
{"Input"},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{"Out"});
}
KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
"slice_grad",
{GradVarName("Out")},
{"axes", "starts", "ends", "infer_flags", "decrease_axis"},
{GradVarName("Input")});
}
} // namespace pten
PT_REGISTER_ARG_MAPPING_FN(slice, pten::SliceOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(slice_grad, pten::SliceGradOpArgumentMapping);