提交 fc7c39f5 编写于 作者: P phlrain

fix slice bug;

上级 d02df1ed
......@@ -12,22 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/slice_grad_kernel.h"
#include "paddle/pten/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/phi/kernels/slice_grad_kernel.h"
#include "paddle/phi/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice_grad,
PD_REGISTER_KERNEL(slice_grad,
CPU,
ALL_LAYOUT,
pten::SliceGradRawKernel,
phi::SliceGradRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16,
pten::dtype::float16) {}
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/slice_kernel.h"
#include "paddle/pten/kernels/impl/slice_kernel_impl.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/impl/slice_kernel_impl.h"
#include "paddle/pten/backends/cpu/cpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice,
PD_REGISTER_KERNEL(slice,
CPU,
ALL_LAYOUT,
pten::SliceRawKernel,
phi::SliceRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16) {}
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
......@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/pten/core/ddim.h>
#include <paddle/phi/core/ddim.h>
#include <string>
#include <vector>
namespace pten {
namespace phi {
template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
inline void CheckAndUpdateSliceAttrs(const DDim in_dims,
const std::vector<T>& axes,
std::vector<T>* starts,
std::vector<T>* ends,
......@@ -31,7 +31,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_LT(
axis,
in_dims.size(),
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.",
i,
......@@ -49,7 +49,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_NE(
step,
0,
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
......@@ -65,7 +65,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_GE(
end,
start,
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end,
......@@ -79,7 +79,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_GE(
start,
end,
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start,
......@@ -96,14 +96,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
}
template <typename T = int64_t>
inline pten::framework::DDim GetSliceDims(
const pten::framework::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
pten::framework::DDim slice_dims(in_dims);
inline phi::DDim GetSliceDims(const phi::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
phi::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
......@@ -126,10 +125,10 @@ inline pten::framework::DDim GetSliceDims(
}
template <typename T = int64_t>
inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
framework::DDim decreased_dims(slice_dims);
inline DDim GetDecreasedDims(const DDim slice_dims,
const std::vector<T>& decrease_axes,
std::vector<T>* infer_flags = nullptr) {
DDim decreased_dims(slice_dims);
std::vector<uint8_t> decrease_flag(slice_dims.size(), 0);
if (decrease_axes.size() > 0) {
for (size_t i = 0; i < decrease_axes.size(); ++i) {
......@@ -138,7 +137,7 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(decreased_dims[axis],
1,
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Decrease dim should be 1, but now received %d",
decreased_dims[axis]));
}
......@@ -162,4 +161,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
return decreased_dims;
}
} // namespace pten
} // namespace phi
......@@ -12,22 +12,22 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
#include "paddle/phi/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/phi/kernels/slice_grad_kernel.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice_grad,
PD_REGISTER_KERNEL(slice_grad,
GPU,
ALL_LAYOUT,
pten::SliceGradRawKernel,
phi::SliceGradRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16,
pten::dtype::float16) {}
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/kernels/impl/slice_kernel_impl.h"
#include "paddle/pten/kernels/slice_kernel.h"
#include "paddle/phi/kernels/slice_kernel.h"
#include "paddle/phi/kernels/impl/slice_kernel_impl.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
PT_REGISTER_KERNEL(slice,
PD_REGISTER_KERNEL(slice,
GPU,
ALL_LAYOUT,
pten::SliceRawKernel,
phi::SliceRawKernel,
bool,
int,
int64_t,
float,
double,
pten::dtype::complex<float>,
pten::dtype::complex<double>,
pten::dtype::bfloat16) {}
phi::dtype::complex<float>,
phi::dtype::complex<double>,
phi::dtype::bfloat16) {}
......@@ -14,12 +14,12 @@
#pragma once
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/slice_grad_kernel.h"
namespace pten {
namespace phi {
template <typename T, typename Context, size_t D>
void LaunchEigenPadding(
......@@ -108,8 +108,8 @@ void EigenPaddingCompute(
// out_tore_shape[1] = out_dims[pad_dim];
// // convert array from std::vector to DDim
// DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
// DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// DDim reshaped_in_dims = make_ddim(in_tore_shape);
// DDim reshaped_out_dims = make_ddim(out_tore_shape);
// // after reshape: the first dimension do not need padding,
// // set padding[0] zero
......@@ -138,8 +138,8 @@ void EigenPaddingCompute(
// }
// // convert array from std::vector to DDim
// DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
// DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// DDim reshaped_in_dims = make_ddim(in_tore_shape);
// DDim reshaped_out_dims = make_ddim(out_tore_shape);
// // after reshape:
// // the first dimension is the previous padding dimension
......@@ -173,8 +173,8 @@ void EigenPaddingCompute(
// }
// // convert array from std::vector to DDim
// DDim reshaped_in_dims = framework::make_ddim(in_tore_shape);
// DDim reshaped_out_dims = framework::make_ddim(out_tore_shape);
// DDim reshaped_in_dims = make_ddim(in_tore_shape);
// DDim reshaped_out_dims = make_ddim(out_tore_shape);
// // after reshape:
// // the first dimension do not need padding, set padding[0] zero
......@@ -219,7 +219,7 @@ void SliceGradCompute(const Context& ctx,
if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease
std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1));
out_dims = make_ddim(std::vector<int>(decrease_size, 1));
} else {
std::vector<int> origin_out_shape(out_dims.size() + decrease_size, -1);
for (size_t i = 0; i < decrease_size; ++i) {
......@@ -234,7 +234,7 @@ void SliceGradCompute(const Context& ctx,
}
}
out_dims = framework::make_ddim(origin_out_shape);
out_dims = make_ddim(origin_out_shape);
}
}
......@@ -334,9 +334,9 @@ void SliceGradRawKernel(const Context& ctx,
input_grad);
break;
default:
PADDLE_THROW(pten::errors::InvalidArgument(
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
} // namespace pten
} // namespace phi
......@@ -14,11 +14,11 @@
#pragma once
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
namespace pten {
namespace phi {
template <typename T, typename Context, size_t D>
void SliceCompute(const Context& ctx,
......@@ -35,11 +35,11 @@ void SliceCompute(const Context& ctx,
PADDLE_ENFORCE_EQ(
starts.size(),
axes.size(),
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The size of starts must be equal to the size of axes."));
PADDLE_ENFORCE_EQ(ends.size(),
axes.size(),
pten::errors::InvalidArgument(
phi::errors::InvalidArgument(
"The size of ends must be equal to the size of axes."));
// Step 2: Compute output
......@@ -143,9 +143,9 @@ void SliceRawKernel(const Context& ctx,
ctx, input, axes, starts, ends, infer_flags, decrease_axis, out);
break;
default:
PADDLE_THROW(pten::errors::InvalidArgument(
PADDLE_THROW(phi::errors::InvalidArgument(
"The rank of input should be less than 7, but received %d.", rank));
}
}
} // namespace pten
} // namespace phi
......@@ -14,9 +14,9 @@
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace pten {
namespace phi {
template <typename T, typename Context>
void SliceGradRawKernel(const Context& ctx,
......@@ -28,4 +28,4 @@ void SliceGradRawKernel(const Context& ctx,
const std::vector<int64_t>& decrease_axis,
DenseTensor* input_grad);
} // namespace pten
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
// Declares the raw slice kernel: copies the sub-tensor of `input` selected
// along `axes` by the half-open ranges [starts, ends) into `out`.
//
// Parameters:
//   ctx           - device context (CPU/GPU) the kernel runs on.
//   input         - source tensor to slice.
//   axes          - dimensions of `input` to slice; each must be less than
//                   the rank of `input`.
//   starts        - per-axis start index; negative values are interpreted
//                   relative to the dimension size (start + dim).
//   ends          - per-axis end index (exclusive), paired with `starts`;
//                   must have the same length as `axes`.
//   infer_flags   - NOTE(review): appears to mark axes whose extent comes
//                   from shape inference (-1 = not inferable) — confirm
//                   against the op's InferShape.
//   decrease_axis - sliced dimensions (each of size 1) to squeeze out of
//                   the output shape.
//   out           - destination tensor for the slice result.
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
const DenseTensor& input,
const std::vector<int64_t>& axes,
const std::vector<int64_t>& starts,
const std::vector<int64_t>& ends,
const std::vector<int64_t>& infer_flags,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
} // namespace phi
......@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/phi/core/compat/op_utils.h"
namespace pten {
namespace phi {
KernelSignature SliceOpArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature(
......@@ -32,7 +32,7 @@ KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
{GradVarName("Input")});
}
} // namespace pten
} // namespace phi
PT_REGISTER_ARG_MAPPING_FN(slice, pten::SliceOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(slice_grad, pten::SliceGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(slice, phi::SliceOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(slice_grad, phi::SliceGradOpArgumentMapping);
......@@ -14,9 +14,9 @@
#pragma once
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/phi/core/dense_tensor.h"
namespace pten {
namespace phi {
template <typename T, typename Context>
void SliceRawKernel(const Context& ctx,
......@@ -28,4 +28,4 @@ void SliceRawKernel(const Context& ctx,
const std::vector<int64_t>& decrease_axis,
DenseTensor* out);
} // namespace pten
} // namespace phi
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册