提交 8cecc9bd 编写于 作者: P phlrain

merge develop; test=develop

上级 87099d12
......@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <paddle/fluid/framework/dim.h>
#include <paddle/pten/core/ddim.h>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
namespace pten {
template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
......@@ -30,11 +29,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
PADDLE_ENFORCE_LT(
axis, in_dims.size(),
platform::errors::InvalidArgument(
axis,
in_dims.size(),
pten::errors::InvalidArgument(
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.",
i, axis, in_dims.size()));
i,
axis,
in_dims.size()));
if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue;
......@@ -45,7 +47,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument(
step,
0,
pten::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step));
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
......@@ -59,11 +63,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GE(
end, start,
platform::errors::InvalidArgument(
end,
start,
pten::errors::InvalidArgument(
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d.",
end, start));
end,
start));
} else {
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
......@@ -71,11 +77,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GE(
start, end,
platform::errors::InvalidArgument(
start,
end,
pten::errors::InvalidArgument(
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d.",
start, end));
start,
end));
}
(*starts)[i] = start;
......@@ -88,13 +96,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
}
template <typename T = int64_t>
inline framework::DDim GetSliceDims(const framework::DDim in_dims,
inline pten::framework::DDim GetSliceDims(
const pten::framework::DDim in_dims,
const std::vector<T>& axes,
const std::vector<T>& starts,
const std::vector<T>& ends,
std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) {
framework::DDim slice_dims(in_dims);
pten::framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i];
......@@ -127,8 +136,9 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
T axis = decrease_axes[i];
decrease_flag[axis] = 1;
if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(decreased_dims[axis], 1,
platform::errors::InvalidArgument(
PADDLE_ENFORCE_EQ(decreased_dims[axis],
1,
pten::errors::InvalidArgument(
"Decrease dim should be 1, but now received %d",
decreased_dims[axis]));
}
......@@ -152,5 +162,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
return decreased_dims;
}
} // namespace operators
} // namespace paddle
} // namespace pten
......@@ -14,9 +14,9 @@
#pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
namespace pten {
......
......@@ -14,9 +14,9 @@
#pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
namespace pten {
......@@ -60,11 +60,10 @@ void SliceCompute(const Context& ctx,
}
}
paddle::operators::CheckAndUpdateSliceAttrs<int64_t>(
in_dims, axes, &starts, &ends);
slice_dims = paddle::operators::GetSliceDims<int64_t>(
in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = paddle::operators::GetDecreasedDims(slice_dims, decrease_axis);
CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends);
slice_dims =
GetSliceDims<int64_t>(in_dims, axes, starts, ends, nullptr, nullptr);
out_dims = GetDecreasedDims<int64_t>(slice_dims, decrease_axis);
// 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册