提交 8cecc9bd 编写于 作者: P phlrain

merge develop; test=develop

上级 87099d12
...@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and ...@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <paddle/fluid/framework/dim.h> #include <paddle/pten/core/ddim.h>
#include <string> #include <string>
#include <vector> #include <vector>
namespace paddle { namespace pten {
namespace operators {
template <typename T = int64_t> template <typename T = int64_t>
inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...@@ -30,11 +29,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -30,11 +29,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i]; T axis = axes[i];
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
axis, in_dims.size(), axis,
platform::errors::InvalidArgument( in_dims.size(),
pten::errors::InvalidArgument(
"The axis value should be less than the rank of input, " "The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d.", "but received axes[%d] = %d, rank of input is %d.",
i, axis, in_dims.size())); i,
axis,
in_dims.size()));
if (infer_flags != nullptr && (*infer_flags)[i] == -1) { if (infer_flags != nullptr && (*infer_flags)[i] == -1) {
continue; continue;
...@@ -45,7 +47,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -45,7 +47,9 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if (dim_value > 0) { if (dim_value > 0) {
T step = steps == nullptr ? 1 : (*steps)[i]; T step = steps == nullptr ? 1 : (*steps)[i];
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
step, 0, platform::errors::InvalidArgument( step,
0,
pten::errors::InvalidArgument(
"Step should not be 0, but received step = %d.", step)); "Step should not be 0, but received step = %d.", step));
T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i]; T start = (*starts)[i] < 0 ? ((*starts)[i] + dim_value) : (*starts)[i];
...@@ -59,11 +63,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -59,11 +63,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start = std::min(start, dim_value); start = std::min(start, dim_value);
end = std::max(end, static_cast<T>(0)); end = std::max(end, static_cast<T>(0));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
end, start, end,
platform::errors::InvalidArgument( start,
pten::errors::InvalidArgument(
"When step > 0, end should be greater than start, but " "When step > 0, end should be greater than start, but "
"received end = %d, start = %d.", "received end = %d, start = %d.",
end, start)); end,
start));
} else { } else {
// NOTE(liym27): When step < 0, start should less and equal to // NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1 // dim_value-1
...@@ -71,11 +77,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -71,11 +77,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start = std::min(start, dim_value - 1); start = std::min(start, dim_value - 1);
end = std::max(end, static_cast<T>(-1)); end = std::max(end, static_cast<T>(-1));
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
start, end, start,
platform::errors::InvalidArgument( end,
pten::errors::InvalidArgument(
"When step < 0, start should be greater than end, but " "When step < 0, start should be greater than end, but "
"received start = %d, end = %d.", "received start = %d, end = %d.",
start, end)); start,
end));
} }
(*starts)[i] = start; (*starts)[i] = start;
...@@ -88,13 +96,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims, ...@@ -88,13 +96,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
} }
template <typename T = int64_t> template <typename T = int64_t>
inline framework::DDim GetSliceDims(const framework::DDim in_dims, inline pten::framework::DDim GetSliceDims(
const pten::framework::DDim in_dims,
const std::vector<T>& axes, const std::vector<T>& axes,
const std::vector<T>& starts, const std::vector<T>& starts,
const std::vector<T>& ends, const std::vector<T>& ends,
std::vector<T>* steps = nullptr, std::vector<T>* steps = nullptr,
std::vector<T>* infer_flags = nullptr) { std::vector<T>* infer_flags = nullptr) {
framework::DDim slice_dims(in_dims); pten::framework::DDim slice_dims(in_dims);
for (size_t i = 0; i < axes.size(); ++i) { for (size_t i = 0; i < axes.size(); ++i) {
T axis = axes[i]; T axis = axes[i];
...@@ -127,8 +136,9 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, ...@@ -127,8 +136,9 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
T axis = decrease_axes[i]; T axis = decrease_axes[i];
decrease_flag[axis] = 1; decrease_flag[axis] = 1;
if (infer_flags && (*infer_flags)[i] != -1) { if (infer_flags && (*infer_flags)[i] != -1) {
PADDLE_ENFORCE_EQ(decreased_dims[axis], 1, PADDLE_ENFORCE_EQ(decreased_dims[axis],
platform::errors::InvalidArgument( 1,
pten::errors::InvalidArgument(
"Decrease dim should be 1, but now received %d", "Decrease dim should be 1, but now received %d",
decreased_dims[axis])); decreased_dims[axis]));
} }
...@@ -152,5 +162,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims, ...@@ -152,5 +162,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
return decreased_dims; return decreased_dims;
} }
} // namespace operators } // namespace pten
} // namespace paddle
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h" #include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" #include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
#include "paddle/pten/kernels/slice_grad_kernel.h" #include "paddle/pten/kernels/slice_grad_kernel.h"
namespace pten { namespace pten {
......
...@@ -14,9 +14,9 @@ ...@@ -14,9 +14,9 @@
#pragma once #pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h" #include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h" #include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
namespace pten { namespace pten {
...@@ -60,11 +60,10 @@ void SliceCompute(const Context& ctx, ...@@ -60,11 +60,10 @@ void SliceCompute(const Context& ctx,
} }
} }
paddle::operators::CheckAndUpdateSliceAttrs<int64_t>( CheckAndUpdateSliceAttrs<int64_t>(in_dims, axes, &starts, &ends);
in_dims, axes, &starts, &ends); slice_dims =
slice_dims = paddle::operators::GetSliceDims<int64_t>( GetSliceDims<int64_t>(in_dims, axes, starts, ends, nullptr, nullptr);
in_dims, axes, starts, ends, nullptr, nullptr); out_dims = GetDecreasedDims<int64_t>(slice_dims, decrease_axis);
out_dims = paddle::operators::GetDecreasedDims(slice_dims, decrease_axis);
// 2.2 Get output // 2.2 Get output
auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>(); auto offsets = Eigen::DSizes<Eigen::DenseIndex, D>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册