Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
fc7c39f5
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
“11518a436183474458e46be1239084fba6775d99”上不存在“python/paddle/distributed/run/plugins/ip.py”
提交
fc7c39f5
编写于
3月 02, 2022
作者:
P
phlrain
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix slice bug;
上级
d02df1ed
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
123 addition
and
93 deletion
+123
-93
paddle/phi/kernels/cpu/slice_grad_kernel.cc
paddle/phi/kernels/cpu/slice_grad_kernel.cc
+10
-10
paddle/phi/kernels/cpu/slice_kernel.cc
paddle/phi/kernels/cpu/slice_kernel.cc
+9
-9
paddle/phi/kernels/funcs/slice_utils.h
paddle/phi/kernels/funcs/slice_utils.h
+20
-21
paddle/phi/kernels/gpu/slice_grad_kernel.cu
paddle/phi/kernels/gpu/slice_grad_kernel.cu
+10
-10
paddle/phi/kernels/gpu/slice_kernel.cu.cc
paddle/phi/kernels/gpu/slice_kernel.cu.cc
+9
-9
paddle/phi/kernels/impl/slice_grad_kernel_impl.h
paddle/phi/kernels/impl/slice_grad_kernel_impl.h
+15
-15
paddle/phi/kernels/impl/slice_kernel_impl.h
paddle/phi/kernels/impl/slice_kernel_impl.h
+8
-8
paddle/phi/kernels/slice_grad_kernel.h
paddle/phi/kernels/slice_grad_kernel.h
+3
-3
paddle/phi/kernels/slice_kernel.h
paddle/phi/kernels/slice_kernel.h
+31
-0
paddle/phi/ops/compat/slice_sig.cc
paddle/phi/ops/compat/slice_sig.cc
+5
-5
paddle/pten/kernels/slice_kernel.h
paddle/pten/kernels/slice_kernel.h
+3
-3
未找到文件。
paddle/p
ten
/kernels/cpu/slice_grad_kernel.cc
→
paddle/p
hi
/kernels/cpu/slice_grad_kernel.cc
浏览文件 @
fc7c39f5
...
@@ -12,22 +12,22 @@
...
@@ -12,22 +12,22 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/p
ten
/kernels/slice_grad_kernel.h"
#include "paddle/p
hi
/kernels/slice_grad_kernel.h"
#include "paddle/p
ten
/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/p
hi
/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/p
ten
/backends/cpu/cpu_context.h"
#include "paddle/p
hi
/backends/cpu/cpu_context.h"
#include "paddle/p
ten
/core/kernel_registry.h"
#include "paddle/p
hi
/core/kernel_registry.h"
P
T
_REGISTER_KERNEL
(
slice_grad
,
P
D
_REGISTER_KERNEL
(
slice_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
p
ten
::
SliceGradRawKernel
,
p
hi
::
SliceGradRawKernel
,
bool
,
bool
,
int
,
int
,
int64_t
,
int64_t
,
float
,
float
,
double
,
double
,
p
ten
::
dtype
::
complex
<
float
>
,
p
hi
::
dtype
::
complex
<
float
>
,
p
ten
::
dtype
::
complex
<
double
>
,
p
hi
::
dtype
::
complex
<
double
>
,
p
ten
::
dtype
::
bfloat16
,
p
hi
::
dtype
::
bfloat16
,
p
ten
::
dtype
::
float16
)
{}
p
hi
::
dtype
::
float16
)
{}
paddle/p
ten
/kernels/cpu/slice_kernel.cc
→
paddle/p
hi
/kernels/cpu/slice_kernel.cc
浏览文件 @
fc7c39f5
...
@@ -12,21 +12,21 @@
...
@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/p
ten
/kernels/slice_kernel.h"
#include "paddle/p
hi
/kernels/slice_kernel.h"
#include "paddle/p
ten
/kernels/impl/slice_kernel_impl.h"
#include "paddle/p
hi
/kernels/impl/slice_kernel_impl.h"
#include "paddle/p
ten
/backends/cpu/cpu_context.h"
#include "paddle/p
hi
/backends/cpu/cpu_context.h"
#include "paddle/p
ten
/core/kernel_registry.h"
#include "paddle/p
hi
/core/kernel_registry.h"
P
T
_REGISTER_KERNEL
(
slice
,
P
D
_REGISTER_KERNEL
(
slice
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
p
ten
::
SliceRawKernel
,
p
hi
::
SliceRawKernel
,
bool
,
bool
,
int
,
int
,
int64_t
,
int64_t
,
float
,
float
,
double
,
double
,
p
ten
::
dtype
::
complex
<
float
>
,
p
hi
::
dtype
::
complex
<
float
>
,
p
ten
::
dtype
::
complex
<
double
>
,
p
hi
::
dtype
::
complex
<
double
>
,
p
ten
::
dtype
::
bfloat16
)
{}
p
hi
::
dtype
::
bfloat16
)
{}
paddle/p
ten
/kernels/funcs/slice_utils.h
→
paddle/p
hi
/kernels/funcs/slice_utils.h
浏览文件 @
fc7c39f5
...
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
...
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#pragma once
#pragma once
#include <paddle/p
ten
/core/ddim.h>
#include <paddle/p
hi
/core/ddim.h>
#include <string>
#include <string>
#include <vector>
#include <vector>
namespace
p
ten
{
namespace
p
hi
{
template
<
typename
T
=
int64_t
>
template
<
typename
T
=
int64_t
>
inline
void
CheckAndUpdateSliceAttrs
(
const
framework
::
DDim
in_dims
,
inline
void
CheckAndUpdateSliceAttrs
(
const
DDim
in_dims
,
const
std
::
vector
<
T
>&
axes
,
const
std
::
vector
<
T
>&
axes
,
std
::
vector
<
T
>*
starts
,
std
::
vector
<
T
>*
starts
,
std
::
vector
<
T
>*
ends
,
std
::
vector
<
T
>*
ends
,
...
@@ -31,7 +31,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -31,7 +31,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
axis
,
axis
,
in_dims
.
size
(),
in_dims
.
size
(),
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The axis value should be less than the rank of input, "
"The axis value should be less than the rank of input, "
"but received axes[%d] = %d, rank of input is %d."
,
"but received axes[%d] = %d, rank of input is %d."
,
i
,
i
,
...
@@ -49,7 +49,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -49,7 +49,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_NE
(
PADDLE_ENFORCE_NE
(
step
,
step
,
0
,
0
,
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"Step should not be 0, but received step = %d."
,
step
));
"Step should not be 0, but received step = %d."
,
step
));
T
start
=
(
*
starts
)[
i
]
<
0
?
((
*
starts
)[
i
]
+
dim_value
)
:
(
*
starts
)[
i
];
T
start
=
(
*
starts
)[
i
]
<
0
?
((
*
starts
)[
i
]
+
dim_value
)
:
(
*
starts
)[
i
];
...
@@ -65,7 +65,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -65,7 +65,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
end
,
end
,
start
,
start
,
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"When step > 0, end should be greater than start, but "
"When step > 0, end should be greater than start, but "
"received end = %d, start = %d."
,
"received end = %d, start = %d."
,
end
,
end
,
...
@@ -79,7 +79,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -79,7 +79,7 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
start
,
start
,
end
,
end
,
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"When step < 0, start should be greater than end, but "
"When step < 0, start should be greater than end, but "
"received start = %d, end = %d."
,
"received start = %d, end = %d."
,
start
,
start
,
...
@@ -96,14 +96,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
...
@@ -96,14 +96,13 @@ inline void CheckAndUpdateSliceAttrs(const framework::DDim in_dims,
}
}
template
<
typename
T
=
int64_t
>
template
<
typename
T
=
int64_t
>
inline
pten
::
framework
::
DDim
GetSliceDims
(
inline
phi
::
DDim
GetSliceDims
(
const
phi
::
DDim
in_dims
,
const
pten
::
framework
::
DDim
in_dims
,
const
std
::
vector
<
T
>&
axes
,
const
std
::
vector
<
T
>&
axes
,
const
std
::
vector
<
T
>&
starts
,
const
std
::
vector
<
T
>&
starts
,
const
std
::
vector
<
T
>&
ends
,
const
std
::
vector
<
T
>&
ends
,
std
::
vector
<
T
>*
steps
=
nullptr
,
std
::
vector
<
T
>*
steps
=
nullptr
,
std
::
vector
<
T
>*
infer_flags
=
nullptr
)
{
std
::
vector
<
T
>*
infer_flags
=
nullptr
)
{
phi
::
DDim
slice_dims
(
in_dims
);
pten
::
framework
::
DDim
slice_dims
(
in_dims
);
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
T
axis
=
axes
[
i
];
T
axis
=
axes
[
i
];
...
@@ -126,10 +125,10 @@ inline pten::framework::DDim GetSliceDims(
...
@@ -126,10 +125,10 @@ inline pten::framework::DDim GetSliceDims(
}
}
template
<
typename
T
=
int64_t
>
template
<
typename
T
=
int64_t
>
inline
framework
::
DDim
GetDecreasedDims
(
const
framework
::
DDim
slice_dims
,
inline
DDim
GetDecreasedDims
(
const
DDim
slice_dims
,
const
std
::
vector
<
T
>&
decrease_axes
,
const
std
::
vector
<
T
>&
decrease_axes
,
std
::
vector
<
T
>*
infer_flags
=
nullptr
)
{
std
::
vector
<
T
>*
infer_flags
=
nullptr
)
{
framework
::
DDim
decreased_dims
(
slice_dims
);
DDim
decreased_dims
(
slice_dims
);
std
::
vector
<
uint8_t
>
decrease_flag
(
slice_dims
.
size
(),
0
);
std
::
vector
<
uint8_t
>
decrease_flag
(
slice_dims
.
size
(),
0
);
if
(
decrease_axes
.
size
()
>
0
)
{
if
(
decrease_axes
.
size
()
>
0
)
{
for
(
size_t
i
=
0
;
i
<
decrease_axes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
decrease_axes
.
size
();
++
i
)
{
...
@@ -138,7 +137,7 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
...
@@ -138,7 +137,7 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
if
(
infer_flags
&&
(
*
infer_flags
)[
i
]
!=
-
1
)
{
if
(
infer_flags
&&
(
*
infer_flags
)[
i
]
!=
-
1
)
{
PADDLE_ENFORCE_EQ
(
decreased_dims
[
axis
],
PADDLE_ENFORCE_EQ
(
decreased_dims
[
axis
],
1
,
1
,
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"Decrease dim should be 1, but now received %d"
,
"Decrease dim should be 1, but now received %d"
,
decreased_dims
[
axis
]));
decreased_dims
[
axis
]));
}
}
...
@@ -162,4 +161,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
...
@@ -162,4 +161,4 @@ inline framework::DDim GetDecreasedDims(const framework::DDim slice_dims,
return
decreased_dims
;
return
decreased_dims
;
}
}
}
// namespace p
ten
}
// namespace p
hi
paddle/p
ten
/kernels/gpu/slice_grad_kernel.cu
→
paddle/p
hi
/kernels/gpu/slice_grad_kernel.cu
浏览文件 @
fc7c39f5
...
@@ -12,22 +12,22 @@
...
@@ -12,22 +12,22 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/p
ten
/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/p
hi
/kernels/impl/slice_grad_kernel_impl.h"
#include "paddle/p
ten
/kernels/slice_grad_kernel.h"
#include "paddle/p
hi
/kernels/slice_grad_kernel.h"
#include "paddle/p
ten
/backends/gpu/gpu_context.h"
#include "paddle/p
hi
/backends/gpu/gpu_context.h"
#include "paddle/p
ten
/core/kernel_registry.h"
#include "paddle/p
hi
/core/kernel_registry.h"
P
T
_REGISTER_KERNEL
(
slice_grad
,
P
D
_REGISTER_KERNEL
(
slice_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
p
ten
::
SliceGradRawKernel
,
p
hi
::
SliceGradRawKernel
,
bool
,
bool
,
int
,
int
,
int64_t
,
int64_t
,
float
,
float
,
double
,
double
,
p
ten
::
dtype
::
complex
<
float
>
,
p
hi
::
dtype
::
complex
<
float
>
,
p
ten
::
dtype
::
complex
<
double
>
,
p
hi
::
dtype
::
complex
<
double
>
,
p
ten
::
dtype
::
bfloat16
,
p
hi
::
dtype
::
bfloat16
,
p
ten
::
dtype
::
float16
)
{}
p
hi
::
dtype
::
float16
)
{}
paddle/p
ten/kernels/gpu/slice_kernel.cu
→
paddle/p
hi/kernels/gpu/slice_kernel.cu.cc
浏览文件 @
fc7c39f5
...
@@ -12,21 +12,21 @@
...
@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/p
ten/kernels/impl/slice_kernel_imp
l.h"
#include "paddle/p
hi/kernels/slice_kerne
l.h"
#include "paddle/p
ten/kernels/slice_kerne
l.h"
#include "paddle/p
hi/kernels/impl/slice_kernel_imp
l.h"
#include "paddle/p
ten
/backends/gpu/gpu_context.h"
#include "paddle/p
hi
/backends/gpu/gpu_context.h"
#include "paddle/p
ten
/core/kernel_registry.h"
#include "paddle/p
hi
/core/kernel_registry.h"
P
T
_REGISTER_KERNEL
(
slice
,
P
D
_REGISTER_KERNEL
(
slice
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
p
ten
::
SliceRawKernel
,
p
hi
::
SliceRawKernel
,
bool
,
bool
,
int
,
int
,
int64_t
,
int64_t
,
float
,
float
,
double
,
double
,
p
ten
::
dtype
::
complex
<
float
>
,
p
hi
::
dtype
::
complex
<
float
>
,
p
ten
::
dtype
::
complex
<
double
>
,
p
hi
::
dtype
::
complex
<
double
>
,
p
ten
::
dtype
::
bfloat16
)
{}
p
hi
::
dtype
::
bfloat16
)
{}
paddle/p
ten
/kernels/impl/slice_grad_kernel_impl.h
→
paddle/p
hi
/kernels/impl/slice_grad_kernel_impl.h
浏览文件 @
fc7c39f5
...
@@ -14,12 +14,12 @@
...
@@ -14,12 +14,12 @@
#pragma once
#pragma once
#include "paddle/p
ten
/kernels/funcs/eigen/common.h"
#include "paddle/p
hi
/kernels/funcs/eigen/common.h"
#include "paddle/p
ten
/kernels/funcs/eigen/eigen_function.h"
#include "paddle/p
hi
/kernels/funcs/eigen/eigen_function.h"
#include "paddle/p
ten
/kernels/funcs/slice_utils.h"
#include "paddle/p
hi
/kernels/funcs/slice_utils.h"
#include "paddle/p
ten
/kernels/slice_grad_kernel.h"
#include "paddle/p
hi
/kernels/slice_grad_kernel.h"
namespace
p
ten
{
namespace
p
hi
{
template
<
typename
T
,
typename
Context
,
size_t
D
>
template
<
typename
T
,
typename
Context
,
size_t
D
>
void
LaunchEigenPadding
(
void
LaunchEigenPadding
(
...
@@ -108,8 +108,8 @@ void EigenPaddingCompute(
...
@@ -108,8 +108,8 @@ void EigenPaddingCompute(
// out_tore_shape[1] = out_dims[pad_dim];
// out_tore_shape[1] = out_dims[pad_dim];
// // convert array from std::vector to DDim
// // convert array from std::vector to DDim
// DDim reshaped_in_dims =
framework::
make_ddim(in_tore_shape);
// DDim reshaped_in_dims = make_ddim(in_tore_shape);
// DDim reshaped_out_dims =
framework::
make_ddim(out_tore_shape);
// DDim reshaped_out_dims = make_ddim(out_tore_shape);
// // after reshape: the first dimension do not need padding,
// // after reshape: the first dimension do not need padding,
// // set padding[0] zero
// // set padding[0] zero
...
@@ -138,8 +138,8 @@ void EigenPaddingCompute(
...
@@ -138,8 +138,8 @@ void EigenPaddingCompute(
// }
// }
// // convert array from std::vector to DDim
// // convert array from std::vector to DDim
// DDim reshaped_in_dims =
framework::
make_ddim(in_tore_shape);
// DDim reshaped_in_dims = make_ddim(in_tore_shape);
// DDim reshaped_out_dims =
framework::
make_ddim(out_tore_shape);
// DDim reshaped_out_dims = make_ddim(out_tore_shape);
// // after reshape:
// // after reshape:
// // the first dimension is the previous padding dimension
// // the first dimension is the previous padding dimension
...
@@ -173,8 +173,8 @@ void EigenPaddingCompute(
...
@@ -173,8 +173,8 @@ void EigenPaddingCompute(
// }
// }
// // convert array from std::vector to DDim
// // convert array from std::vector to DDim
// DDim reshaped_in_dims =
framework::
make_ddim(in_tore_shape);
// DDim reshaped_in_dims = make_ddim(in_tore_shape);
// DDim reshaped_out_dims =
framework::
make_ddim(out_tore_shape);
// DDim reshaped_out_dims = make_ddim(out_tore_shape);
// // after reshape:
// // after reshape:
// // the first dimension do not need padding, set padding[0] zero
// // the first dimension do not need padding, set padding[0] zero
...
@@ -219,7 +219,7 @@ void SliceGradCompute(const Context& ctx,
...
@@ -219,7 +219,7 @@ void SliceGradCompute(const Context& ctx,
if
(
decrease_size
==
static_cast
<
size_t
>
(
in_dims
.
size
()))
{
if
(
decrease_size
==
static_cast
<
size_t
>
(
in_dims
.
size
()))
{
// all dims decrease
// all dims decrease
std
::
vector
<
int
>
origin_out_shape
(
decrease_size
,
1
);
std
::
vector
<
int
>
origin_out_shape
(
decrease_size
,
1
);
out_dims
=
framework
::
make_ddim
(
std
::
vector
<
int
>
(
decrease_size
,
1
));
out_dims
=
make_ddim
(
std
::
vector
<
int
>
(
decrease_size
,
1
));
}
else
{
}
else
{
std
::
vector
<
int
>
origin_out_shape
(
out_dims
.
size
()
+
decrease_size
,
-
1
);
std
::
vector
<
int
>
origin_out_shape
(
out_dims
.
size
()
+
decrease_size
,
-
1
);
for
(
size_t
i
=
0
;
i
<
decrease_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
decrease_size
;
++
i
)
{
...
@@ -234,7 +234,7 @@ void SliceGradCompute(const Context& ctx,
...
@@ -234,7 +234,7 @@ void SliceGradCompute(const Context& ctx,
}
}
}
}
out_dims
=
framework
::
make_ddim
(
origin_out_shape
);
out_dims
=
make_ddim
(
origin_out_shape
);
}
}
}
}
...
@@ -334,9 +334,9 @@ void SliceGradRawKernel(const Context& ctx,
...
@@ -334,9 +334,9 @@ void SliceGradRawKernel(const Context& ctx,
input_grad
);
input_grad
);
break
;
break
;
default:
default:
PADDLE_THROW
(
p
ten
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
p
hi
::
errors
::
InvalidArgument
(
"The rank of input should be less than 7, but received %d."
,
rank
));
"The rank of input should be less than 7, but received %d."
,
rank
));
}
}
}
}
}
// namespace p
ten
}
// namespace p
hi
paddle/p
ten
/kernels/impl/slice_kernel_impl.h
→
paddle/p
hi
/kernels/impl/slice_kernel_impl.h
浏览文件 @
fc7c39f5
...
@@ -14,11 +14,11 @@
...
@@ -14,11 +14,11 @@
#pragma once
#pragma once
#include "paddle/p
ten
/kernels/funcs/eigen/common.h"
#include "paddle/p
hi
/kernels/funcs/eigen/common.h"
#include "paddle/p
ten
/kernels/funcs/eigen/eigen_function.h"
#include "paddle/p
hi
/kernels/funcs/eigen/eigen_function.h"
#include "paddle/p
ten
/kernels/funcs/slice_utils.h"
#include "paddle/p
hi
/kernels/funcs/slice_utils.h"
namespace
p
ten
{
namespace
p
hi
{
template
<
typename
T
,
typename
Context
,
size_t
D
>
template
<
typename
T
,
typename
Context
,
size_t
D
>
void
SliceCompute
(
const
Context
&
ctx
,
void
SliceCompute
(
const
Context
&
ctx
,
...
@@ -35,11 +35,11 @@ void SliceCompute(const Context& ctx,
...
@@ -35,11 +35,11 @@ void SliceCompute(const Context& ctx,
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
starts
.
size
(),
starts
.
size
(),
axes
.
size
(),
axes
.
size
(),
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The size of starts must be equal to the size of axes."
));
"The size of starts must be equal to the size of axes."
));
PADDLE_ENFORCE_EQ
(
ends
.
size
(),
PADDLE_ENFORCE_EQ
(
ends
.
size
(),
axes
.
size
(),
axes
.
size
(),
p
ten
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"The size of ends must be equal to the size of axes."
));
"The size of ends must be equal to the size of axes."
));
// Step 2: Compute output
// Step 2: Compute output
...
@@ -143,9 +143,9 @@ void SliceRawKernel(const Context& ctx,
...
@@ -143,9 +143,9 @@ void SliceRawKernel(const Context& ctx,
ctx
,
input
,
axes
,
starts
,
ends
,
infer_flags
,
decrease_axis
,
out
);
ctx
,
input
,
axes
,
starts
,
ends
,
infer_flags
,
decrease_axis
,
out
);
break
;
break
;
default:
default:
PADDLE_THROW
(
p
ten
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
p
hi
::
errors
::
InvalidArgument
(
"The rank of input should be less than 7, but received %d."
,
rank
));
"The rank of input should be less than 7, but received %d."
,
rank
));
}
}
}
}
}
// namespace p
ten
}
// namespace p
hi
paddle/p
ten
/kernels/slice_grad_kernel.h
→
paddle/p
hi
/kernels/slice_grad_kernel.h
浏览文件 @
fc7c39f5
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#pragma once
#pragma once
#include "paddle/p
ten
/core/dense_tensor.h"
#include "paddle/p
hi
/core/dense_tensor.h"
namespace
p
ten
{
namespace
p
hi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
SliceGradRawKernel
(
const
Context
&
ctx
,
void
SliceGradRawKernel
(
const
Context
&
ctx
,
...
@@ -28,4 +28,4 @@ void SliceGradRawKernel(const Context& ctx,
...
@@ -28,4 +28,4 @@ void SliceGradRawKernel(const Context& ctx,
const
std
::
vector
<
int64_t
>&
decrease_axis
,
const
std
::
vector
<
int64_t
>&
decrease_axis
,
DenseTensor
*
input_grad
);
DenseTensor
*
input_grad
);
}
// namespace p
ten
}
// namespace p
hi
paddle/phi/kernels/slice_kernel.h
0 → 100644
浏览文件 @
fc7c39f5
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
SliceRawKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
input
,
const
std
::
vector
<
int64_t
>&
axes
,
const
std
::
vector
<
int64_t
>&
starts
,
const
std
::
vector
<
int64_t
>&
ends
,
const
std
::
vector
<
int64_t
>&
infer_flags
,
const
std
::
vector
<
int64_t
>&
decrease_axis
,
DenseTensor
*
out
);
}
// namespace phi
paddle/p
ten
/ops/compat/slice_sig.cc
→
paddle/p
hi
/ops/compat/slice_sig.cc
浏览文件 @
fc7c39f5
...
@@ -12,9 +12,9 @@
...
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include "paddle/p
ten
/core/compat/op_utils.h"
#include "paddle/p
hi
/core/compat/op_utils.h"
namespace
p
ten
{
namespace
p
hi
{
KernelSignature
SliceOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
KernelSignature
SliceOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
)
{
return
KernelSignature
(
return
KernelSignature
(
...
@@ -32,7 +32,7 @@ KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
...
@@ -32,7 +32,7 @@ KernelSignature SliceGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
{
GradVarName
(
"Input"
)});
{
GradVarName
(
"Input"
)});
}
}
}
// namespace p
ten
}
// namespace p
hi
P
T_REGISTER_ARG_MAPPING_FN
(
slice
,
pten
::
SliceOpArgumentMapping
);
P
D_REGISTER_ARG_MAPPING_FN
(
slice
,
phi
::
SliceOpArgumentMapping
);
P
T_REGISTER_ARG_MAPPING_FN
(
slice_grad
,
pten
::
SliceGradOpArgumentMapping
);
P
D_REGISTER_ARG_MAPPING_FN
(
slice_grad
,
phi
::
SliceGradOpArgumentMapping
);
paddle/pten/kernels/slice_kernel.h
浏览文件 @
fc7c39f5
...
@@ -14,9 +14,9 @@
...
@@ -14,9 +14,9 @@
#pragma once
#pragma once
#include "paddle/p
ten
/core/dense_tensor.h"
#include "paddle/p
hi
/core/dense_tensor.h"
namespace
p
ten
{
namespace
p
hi
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
SliceRawKernel
(
const
Context
&
ctx
,
void
SliceRawKernel
(
const
Context
&
ctx
,
...
@@ -28,4 +28,4 @@ void SliceRawKernel(const Context& ctx,
...
@@ -28,4 +28,4 @@ void SliceRawKernel(const Context& ctx,
const
std
::
vector
<
int64_t
>&
decrease_axis
,
const
std
::
vector
<
int64_t
>&
decrease_axis
,
DenseTensor
*
out
);
DenseTensor
*
out
);
}
// namespace p
ten
}
// namespace p
hi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录