Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8cecc9bd
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8cecc9bd
编写于
2月 17, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
merge develop; test=develop
上级
87099d12
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
40 addition
and
32 deletion
+40
-32
paddle/pten/kernels/funcs/slice_utils.h
paddle/pten/kernels/funcs/slice_utils.h
+34
-25
paddle/pten/kernels/impl/slice_grad_kernel_impl.h
paddle/pten/kernels/impl/slice_grad_kernel_impl.h
+1
-1
paddle/pten/kernels/impl/slice_kernel_impl.h
paddle/pten/kernels/impl/slice_kernel_impl.h
+5
-6
未找到文件。
paddle/
fluid/operator
s/slice_utils.h
→
paddle/
pten/kernels/func
s/slice_utils.h
浏览文件 @
8cecc9bd
...
@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
...
@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <paddle/
fluid/framework/
dim.h>
#include <paddle/
pten/core/d
dim.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
namespace
paddle
{
namespace
pten
{
namespace
operators
{
template
<
typename
T
=
int64_t
>
template
<
typename
T
=
int64_t
>
inline
void
CheckAndUpdateSliceAttrs
(
const
framework
::
DDim
in_dims
,
inline
void
CheckAndUpdateSliceAttrs
(
const
framework
::
DDim
in_dims
,
...
@@ -30,11 +29,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -30,11 +29,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
T
axis
=
axes
[
i
];
T
axis
=
axes
[
i
];
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
axis
,
in_dims
.
size
(),
axis
,
platform
::
errors
::
InvalidArgument
(
in_dims
.
size
(),
pten
::
errors
::
InvalidArgument
(
"The axis value should be less than the rank of input, "
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d."
,
"but received axes[%d] = %d, rank of input is %d."
,
i
,
axis
,
in_dims
.
size
()));
i
,
axis
,
in_dims
.
size
()));
if
(
infer_flags
!=
nullptr
&&
(
*
infer_flags
)[
i
]
==
-
1
)
{
if
(
infer_flags
!=
nullptr
&&
(
*
infer_flags
)[
i
]
==
-
1
)
{
continue
;
continue
;
...
@@ -45,8 +47,10 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -45,8 +47,10 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
if
(
dim_value
>
0
)
{
if
(
dim_value
>
0
)
{
T
step
=
steps
==
nullptr
?
1
:
(
*
steps
)[
i
];
T
step
=
steps
==
nullptr
?
1
:
(
*
steps
)[
i
];
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
step
,
0
,
platform
::
errors
::
InvalidArgument
(
step
,
"Step should not be 0, but received step = %d."
,
step
));
0
,
pten
::
errors
::
InvalidArgument
(
"Step should not be 0, but received step = %d."
,
step
));
T
start
=
(
*
starts
)[
i
]
<
0
?
((
*
starts
)[
i
]
+
dim_value
)
:
(
*
starts
)[
i
];
T
start
=
(
*
starts
)[
i
]
<
0
?
((
*
starts
)[
i
]
+
dim_value
)
:
(
*
starts
)[
i
];
start
=
std
::
max
(
start
,
static_cast
<
T
>
(
0
));
start
=
std
::
max
(
start
,
static_cast
<
T
>
(
0
));
...
@@ -59,11 +63,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -59,11 +63,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start
=
std
::
min
(
start
,
dim_value
);
start
=
std
::
min
(
start
,
dim_value
);
end
=
std
::
max
(
end
,
static_cast
<
T
>
(
0
));
end
=
std
::
max
(
end
,
static_cast
<
T
>
(
0
));
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
end
,
start
,
end
,
platform
::
errors
::
InvalidArgument
(
start
,
pten
::
errors
::
InvalidArgument
(
"When step > 0, end should be greater than start, but "
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d."
,
"received end = %d, start = %d."
,
end
,
start
));
end
,
start
));
}
else
{
}
else
{
// NOTE(liym27): When step < 0, start should less and equal to
// NOTE(liym27): When step < 0, start should less and equal to
// dim_value-1
// dim_value-1
...
@@ -71,11 +77,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -71,11 +77,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
start
=
std
::
min
(
start
,
dim_value
-
1
);
start
=
std
::
min
(
start
,
dim_value
-
1
);
end
=
std
::
max
(
end
,
static_cast
<
T
>
(
-
1
));
end
=
std
::
max
(
end
,
static_cast
<
T
>
(
-
1
));
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
start
,
end
,
start
,
platform
::
errors
::
InvalidArgument
(
end
,
pten
::
errors
::
InvalidArgument
(
"When step < 0, start should be greater than end, but "
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d."
,
"received start = %d, end = %d."
,
start
,
end
));
start
,
end
));
}
}
(
*
starts
)[
i
]
=
start
;
(
*
starts
)[
i
]
=
start
;
...
@@ -88,13 +96,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -88,13 +96,14 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
}
}
template
<
typename
T
=
int64_t
>
template
<
typename
T
=
int64_t
>
inline
framework
::
DDim
GetSliceDims
(
const
framework
::
DDim
in_dims
,
inline
pten
::
framework
::
DDim
GetSliceDims
(
const
std
::
vector
<
T
>&
axes
,
const
pten
::
framework
::
DDim
in_dims
,
const
std
::
vector
<
T
>&
starts
,
const
std
::
vector
<
T
>&
axes
,
const
std
::
vector
<
T
>&
ends
,
const
std
::
vector
<
T
>&
starts
,
std
::
vector
<
T
>*
steps
=
nullptr
,
const
std
::
vector
<
T
>&
ends
,
std
::
vector
<
T
>*
infer_flags
=
nullptr
)
{
std
::
vector
<
T
>*
steps
=
nullptr
,
framework
::
DDim
slice_dims
(
in_dims
);
std
::
vector
<
T
>*
infer_flags
=
nullptr
)
{
pten
::
framework
::
DDim
slice_dims
(
in_dims
);
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
T
axis
=
axes
[
i
];
T
axis
=
axes
[
i
];
...
@@ -127,8 +136,9 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
...
@@ -127,8 +136,9 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
T
axis
=
decrease_axes
[
i
];
T
axis
=
decrease_axes
[
i
];
decrease_flag
[
axis
]
=
1
;
decrease_flag
[
axis
]
=
1
;
if
(
infer_flags
&&
(
*
infer_flags
)[
i
]
!=
-
1
)
{
if
(
infer_flags
&&
(
*
infer_flags
)[
i
]
!=
-
1
)
{
PADDLE_ENFORCE_EQ
(
decreased_dims
[
axis
],
1
,
PADDLE_ENFORCE_EQ
(
decreased_dims
[
axis
],
platform
::
errors
::
InvalidArgument
(
1
,
pten
::
errors
::
InvalidArgument
(
"Decrease dim should be 1, but now received %d"
,
"Decrease dim should be 1, but now received %d"
,
decreased_dims
[
axis
]));
decreased_dims
[
axis
]));
}
}
...
@@ -152,5 +162,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
...
@@ -152,5 +162,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
return
decreased_dims
;
return
decreased_dims
;
}
}
}
// namespace operators
}
// namespace pten
}
// namespace paddle
paddle/pten/kernels/impl/slice_grad_kernel_impl.h
浏览文件 @
8cecc9bd
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#pragma once
#pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
#include "paddle/pten/kernels/slice_grad_kernel.h"
namespace
pten
{
namespace
pten
{
...
...
paddle/pten/kernels/impl/slice_kernel_impl.h
浏览文件 @
8cecc9bd
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#pragma once
#pragma once
#include "paddle/fluid/operators/slice_utils.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/common.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/eigen/eigen_function.h"
#include "paddle/pten/kernels/funcs/slice_utils.h"
namespace
pten
{
namespace
pten
{
...
@@ -60,11 +60,10 @@ void SliceCompute(const Context& ctx,
...
@@ -60,11 +60,10 @@ void SliceCompute(const Context& ctx,
}
}
}
}
paddle
::
operators
::
CheckAndUpdateSliceAttrs
<
int64_t
>
(
CheckAndUpdateSliceAttrs
<
int64_t
>
(
in_dims
,
axes
,
&
starts
,
&
ends
);
in_dims
,
axes
,
&
starts
,
&
ends
);
slice_dims
=
slice_dims
=
paddle
::
operators
::
GetSliceDims
<
int64_t
>
(
GetSliceDims
<
int64_t
>
(
in_dims
,
axes
,
starts
,
ends
,
nullptr
,
nullptr
);
in_dims
,
axes
,
starts
,
ends
,
nullptr
,
nullptr
);
out_dims
=
GetDecreasedDims
<
int64_t
>
(
slice_dims
,
decrease_axis
);
out_dims
=
paddle
::
operators
::
GetDecreasedDims
(
slice_dims
,
decrease_axis
);
// 2.2 Get output
// 2.2 Get output
auto
offsets
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
D
>
();
auto
offsets
=
Eigen
::
DSizes
<
Eigen
::
DenseIndex
,
D
>
();
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录