未验证 提交 c0f993f6 编写于 作者: Z zyfncg 提交者: GitHub

Fix slice error in jit.to_static mode (#39251)

* fix slice bug

* fix syntax error
上级 a1addeef
...@@ -43,6 +43,10 @@ inline void DealTensorArray(const framework::ExecutionContext& ctx, ...@@ -43,6 +43,10 @@ inline void DealTensorArray(const framework::ExecutionContext& ctx,
end = std::max(end, static_cast<int64_t>(0)); end = std::max(end, static_cast<int64_t>(0));
end = std::min(end, in_size); end = std::min(end, in_size);
if (starts[0] == -1 && end == 0) {
end = start + 1;
}
PADDLE_ENFORCE_GT(end, start, PADDLE_ENFORCE_GT(end, start,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Attr(ends) should be greater than attr(starts) in " "Attr(ends) should be greater than attr(starts) in "
...@@ -330,7 +334,7 @@ class SliceGradKernel : public framework::OpKernel<T> { ...@@ -330,7 +334,7 @@ class SliceGradKernel : public framework::OpKernel<T> {
auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis"); auto decrease_axis = ctx.Attr<std::vector<int>>("decrease_axis");
auto decrease_size = decrease_axis.size(); auto decrease_size = decrease_axis.size();
if (decrease_size > 0) { if (decrease_size > 0) {
if (decrease_size == (size_t)in_dims.size()) { if (decrease_size == static_cast<size_t>(in_dims.size())) {
// all dims decrease // all dims decrease
std::vector<int> origin_out_shape(decrease_size, 1); std::vector<int> origin_out_shape(decrease_size, 1);
out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1)); out_dims = framework::make_ddim(std::vector<int>(decrease_size, 1));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册