Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
8c25dfaa
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8c25dfaa
编写于
10月 13, 2020
作者:
T
Thunderbrook
提交者:
GitHub
10月 13, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
op error info (#27856)
* op error info * style * code format
上级
79b5db13
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
125 addition
and
72 deletion
+125
-72
paddle/fluid/operators/slice_op.cc
paddle/fluid/operators/slice_op.cc
+29
-15
paddle/fluid/operators/slice_op.h
paddle/fluid/operators/slice_op.h
+3
-2
paddle/fluid/operators/space_to_depth_op.cc
paddle/fluid/operators/space_to_depth_op.cc
+57
-30
paddle/fluid/operators/split_op.cc
paddle/fluid/operators/split_op.cc
+8
-5
paddle/fluid/operators/split_op.h
paddle/fluid/operators/split_op.h
+28
-20
未找到文件。
paddle/fluid/operators/slice_op.cc
浏览文件 @
8c25dfaa
...
@@ -29,10 +29,12 @@ class SliceOp : public framework::OperatorWithKernel {
...
@@ -29,10 +29,12 @@ class SliceOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Input"
),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Input"
),
true
,
"Input (Input) of slice op should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input (Input) of slice op should not be null."
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasOutput
(
"Out"
),
true
,
"Output (Out) of slice op should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Output (Out) of slice op should not be null."
));
auto
x_var_type
=
ctx
->
GetInputsVarType
(
"Input"
)[
0
];
auto
x_var_type
=
ctx
->
GetInputsVarType
(
"Input"
)[
0
];
auto
axes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"axes"
);
auto
axes
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"axes"
);
if
(
x_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
if
(
x_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
...
@@ -57,7 +59,8 @@ class SliceOp : public framework::OperatorWithKernel {
...
@@ -57,7 +59,8 @@ class SliceOp : public framework::OperatorWithKernel {
}
}
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"Input"
);
PADDLE_ENFORCE_LT
(
in_dims
.
size
(),
7
,
PADDLE_ENFORCE_LT
(
in_dims
.
size
(),
7
,
"The rank of input should be less than 7."
);
platform
::
errors
::
InvalidArgument
(
"The rank of input should be less than 7."
));
framework
::
DDim
out_dims
(
in_dims
);
framework
::
DDim
out_dims
(
in_dims
);
auto
starts
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"starts"
);
auto
starts
=
ctx
->
Attrs
().
Get
<
std
::
vector
<
int
>>
(
"starts"
);
...
@@ -76,31 +79,37 @@ class SliceOp : public framework::OperatorWithKernel {
...
@@ -76,31 +79,37 @@ class SliceOp : public framework::OperatorWithKernel {
if
(
ctx
->
HasInputs
(
"StartsTensorList"
))
{
if
(
ctx
->
HasInputs
(
"StartsTensorList"
))
{
auto
StartsTensorList
=
ctx
->
Inputs
(
"StartsTensorList"
);
auto
StartsTensorList
=
ctx
->
Inputs
(
"StartsTensorList"
);
PADDLE_ENFORCE_GT
(
StartsTensorList
.
size
(),
0
,
PADDLE_ENFORCE_GT
(
StartsTensorList
.
size
(),
0
,
"StartsTensorList size can't be zero"
);
platform
::
errors
::
InvalidArgument
(
"StartsTensorList size can't be zero"
));
starts_size
=
StartsTensorList
.
size
();
starts_size
=
StartsTensorList
.
size
();
}
}
if
(
ctx
->
HasInputs
(
"EndsTensorList"
))
{
if
(
ctx
->
HasInputs
(
"EndsTensorList"
))
{
auto
EndsTensorList
=
ctx
->
Inputs
(
"EndsTensorList"
);
auto
EndsTensorList
=
ctx
->
Inputs
(
"EndsTensorList"
);
PADDLE_ENFORCE_GT
(
EndsTensorList
.
size
(),
0
,
PADDLE_ENFORCE_GT
(
EndsTensorList
.
size
(),
0
,
"EndsTensorList size can't be zero"
);
platform
::
errors
::
InvalidArgument
(
"EndsTensorList size can't be zero"
));
ends_size
=
EndsTensorList
.
size
();
ends_size
=
EndsTensorList
.
size
();
}
}
if
(
ctx
->
HasInput
(
"StartsTensor"
)
==
false
)
{
if
(
ctx
->
HasInput
(
"StartsTensor"
)
==
false
)
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
starts_size
,
axes
.
size
(),
starts_size
,
axes
.
size
(),
"The size of starts must be equal to the size of axes."
);
platform
::
errors
::
InvalidArgument
(
"The size of starts must be equal to the size of axes."
));
}
}
if
(
ctx
->
HasInput
(
"EndsTensor"
)
==
false
)
{
if
(
ctx
->
HasInput
(
"EndsTensor"
)
==
false
)
{
PADDLE_ENFORCE_EQ
(
ends_size
,
axes
.
size
(),
PADDLE_ENFORCE_EQ
(
"The size of ends must be equal to the size of axes."
);
ends_size
,
axes
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The size of ends must be equal to the size of axes."
));
}
}
int
dim_value
,
start
,
end
;
int
dim_value
,
start
,
end
;
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
axes
.
size
();
++
i
)
{
PADDLE_ENFORCE_LT
(
static_cast
<
int
>
(
axes
[
i
]),
in_dims
.
size
(),
PADDLE_ENFORCE_LT
(
static_cast
<
int
>
(
axes
[
i
]),
in_dims
.
size
(),
"The index of dimension in axes must be less "
platform
::
errors
::
InvalidArgument
(
"than the size of input shape."
);
"The index of dimension in axes must be less "
"than the size of input shape."
));
if
(
infer_flags
[
i
]
==
-
1
)
{
if
(
infer_flags
[
i
]
==
-
1
)
{
out_dims
[
axes
[
i
]]
=
-
1
;
out_dims
[
axes
[
i
]]
=
-
1
;
}
else
{
}
else
{
...
@@ -112,7 +121,8 @@ class SliceOp : public framework::OperatorWithKernel {
...
@@ -112,7 +121,8 @@ class SliceOp : public framework::OperatorWithKernel {
start
=
std
::
max
(
start
,
0
);
start
=
std
::
max
(
start
,
0
);
end
=
std
::
max
(
end
,
0
);
end
=
std
::
max
(
end
,
0
);
end
=
std
::
min
(
end
,
dim_value
);
end
=
std
::
min
(
end
,
dim_value
);
PADDLE_ENFORCE_GT
(
end
,
start
,
"end should greater than start"
);
PADDLE_ENFORCE_GT
(
end
,
start
,
platform
::
errors
::
InvalidArgument
(
"end should greater than start"
));
out_dims
[
axes
[
i
]]
=
end
-
start
;
out_dims
[
axes
[
i
]]
=
end
-
start
;
}
}
}
}
...
@@ -122,8 +132,9 @@ class SliceOp : public framework::OperatorWithKernel {
...
@@ -122,8 +132,9 @@ class SliceOp : public framework::OperatorWithKernel {
std
::
vector
<
int
>
new_out_shape
;
std
::
vector
<
int
>
new_out_shape
;
for
(
size_t
i
=
0
;
i
<
decrease_axis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
decrease_axis
.
size
();
++
i
)
{
if
(
ctx
->
IsRuntime
()
&&
infer_flags
[
i
]
!=
-
1
)
{
if
(
ctx
->
IsRuntime
()
&&
infer_flags
[
i
]
!=
-
1
)
{
PADDLE_ENFORCE_EQ
(
out_dims
[
decrease_axis
[
i
]],
1
,
PADDLE_ENFORCE_EQ
(
"decrease dim should be 1"
);
out_dims
[
decrease_axis
[
i
]],
1
,
platform
::
errors
::
InvalidArgument
(
"decrease dim should be 1"
));
}
}
out_dims
[
decrease_axis
[
i
]]
=
0
;
out_dims
[
decrease_axis
[
i
]]
=
0
;
}
}
...
@@ -284,9 +295,12 @@ class SliceOpGrad : public framework::OperatorWithKernel {
...
@@ -284,9 +295,12 @@ class SliceOpGrad : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Input"
),
true
,
"Input should not be null"
);
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"Input"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input should not be null"
));
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
true
,
"Input(Out@GRAD) should not be null"
);
platform
::
errors
::
InvalidArgument
(
"Input(Out@GRAD) should not be null"
));
auto
x_var_type
=
ctx
->
GetInputsVarType
(
"Input"
)[
0
];
auto
x_var_type
=
ctx
->
GetInputsVarType
(
"Input"
)[
0
];
if
(
x_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
if
(
x_var_type
==
framework
::
proto
::
VarType
::
LOD_TENSOR_ARRAY
)
{
// If the var type of input is LOD_TENSOR_ARRAY,
// If the var type of input is LOD_TENSOR_ARRAY,
...
...
paddle/fluid/operators/slice_op.h
浏览文件 @
8c25dfaa
...
@@ -191,8 +191,9 @@ class SliceKernel : public framework::OpKernel<T> {
...
@@ -191,8 +191,9 @@ class SliceKernel : public framework::OpKernel<T> {
if
(
decrease_axis
.
size
()
>
0
)
{
if
(
decrease_axis
.
size
()
>
0
)
{
std
::
vector
<
int64_t
>
new_out_shape
;
std
::
vector
<
int64_t
>
new_out_shape
;
for
(
size_t
i
=
0
;
i
<
decrease_axis
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
decrease_axis
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
out_dims
[
decrease_axis
[
i
]],
1
,
PADDLE_ENFORCE_EQ
(
"decrease dim should be 1"
);
out_dims
[
decrease_axis
[
i
]],
1
,
platform
::
errors
::
InvalidArgument
(
"decrease dim should be 1"
));
out_dims
[
decrease_axis
[
i
]]
=
0
;
out_dims
[
decrease_axis
[
i
]]
=
0
;
}
}
...
...
paddle/fluid/operators/space_to_depth_op.cc
浏览文件 @
8c25dfaa
...
@@ -31,51 +31,76 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
...
@@ -31,51 +31,76 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) of SpaceToDepthOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) of SpaceToDepthOp should not be null."
));
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SpaceToDepthOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Output(Out) of SpaceToDepthOp should not be null."
));
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
"input should be a 4D tensor"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"input should be a 4D tensor"
));
auto
blocksize
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"blocksize"
);
auto
blocksize
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"blocksize"
);
PADDLE_ENFORCE_GT
(
blocksize
,
1
,
"The blocksize should be Greater than 1"
);
PADDLE_ENFORCE_GT
(
blocksize
,
1
,
platform
::
errors
::
InvalidArgument
(
"The blocksize should be Greater than 1"
));
if
(
ctx
->
IsRuntime
())
{
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
"input channel should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
"input Height should be Greater than 0"
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
"input Width should be Greater than 0"
);
"input channel should be Greater than 0"
));
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
PADDLE_ENFORCE_EQ
(
x_dims
[
1
]
%
(
blocksize
*
blocksize
),
0
,
platform
::
errors
::
InvalidArgument
(
"input channel should be divisible of the square of "
"input Height should be Greater than 0"
));
"SpaceToDepthOp blocksize"
);
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
platform
::
errors
::
InvalidArgument
(
"input Width should be Greater than 0"
));
PADDLE_ENFORCE_EQ
(
x_dims
[
1
]
%
(
blocksize
*
blocksize
),
0
,
platform
::
errors
::
InvalidArgument
(
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize"
));
PADDLE_ENFORCE_EQ
(
x_dims
[
2
]
%
(
blocksize
),
0
,
PADDLE_ENFORCE_EQ
(
x_dims
[
2
]
%
(
blocksize
),
0
,
"input Height should be divisible of the square of "
platform
::
errors
::
InvalidArgument
(
"SpaceToDepthOp blocksize"
);
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize"
));
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
"input Width should be divisible of the square of "
platform
::
errors
::
InvalidArgument
(
"SpaceToDepthOp blocksize"
);
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize"
));
}
else
{
}
else
{
if
(
x_dims
[
1
]
!=
-
1
)
{
if
(
x_dims
[
1
]
!=
-
1
)
{
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
"input channel should be Greater than 0"
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_EQ
(
x_dims
[
1
]
%
(
blocksize
*
blocksize
),
0
,
"input channel should be Greater than 0"
));
"input channel should be divisible of the square of "
PADDLE_ENFORCE_EQ
(
"SpaceToDepthOp blocksize"
);
x_dims
[
1
]
%
(
blocksize
*
blocksize
),
0
,
platform
::
errors
::
InvalidArgument
(
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize"
));
}
}
if
(
x_dims
[
2
]
!=
-
1
)
{
if
(
x_dims
[
2
]
!=
-
1
)
{
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
"input Height should be Greater than 0"
);
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_EQ
(
x_dims
[
2
]
%
(
blocksize
),
0
,
"input Height should be Greater than 0"
));
"input Height should be divisible of the square of "
PADDLE_ENFORCE_EQ
(
"SpaceToDepthOp blocksize"
);
x_dims
[
2
]
%
(
blocksize
),
0
,
platform
::
errors
::
InvalidArgument
(
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize"
));
}
}
if
(
x_dims
[
3
]
!=
-
1
)
{
if
(
x_dims
[
3
]
!=
-
1
)
{
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
"input Width should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
"input Width should be Greater than 0"
));
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
platform
::
errors
::
InvalidArgument
(
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize"
));
}
}
}
}
...
@@ -156,9 +181,11 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
...
@@ -156,9 +181,11 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) shouldn't be null."
);
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
platform
::
errors
::
InvalidArgument
(
"Input(X) shouldn't be null."
));
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
PADDLE_ENFORCE
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) shouldn't be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(Out@GRAD) shouldn't be null."
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
ctx
->
GetInputDim
(
"X"
));
}
}
...
...
paddle/fluid/operators/split_op.cc
浏览文件 @
8c25dfaa
...
@@ -25,9 +25,11 @@ class SplitOp : public framework::OperatorWithKernel {
...
@@ -25,9 +25,11 @@ class SplitOp : public framework::OperatorWithKernel {
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
PADDLE_ENFORCE_EQ
(
ctx
->
HasInput
(
"X"
),
true
,
"Input(X) of SplitOp should not be null."
);
platform
::
errors
::
InvalidArgument
(
"Input(X) of SplitOp should not be null."
));
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
PADDLE_ENFORCE_GE
(
ctx
->
Outputs
(
"Out"
).
size
(),
1UL
,
"Outputs(Out) of SplitOp should not be empty."
);
platform
::
errors
::
InvalidArgument
(
"Outputs(Out) of SplitOp should not be empty."
));
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
in_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
outs_names
=
ctx
->
Outputs
(
"Out"
);
auto
outs_names
=
ctx
->
Outputs
(
"Out"
);
size_t
axis
=
static_cast
<
size_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
));
size_t
axis
=
static_cast
<
size_t
>
(
ctx
->
Attrs
().
Get
<
int
>
(
"axis"
));
...
@@ -37,9 +39,10 @@ class SplitOp : public framework::OperatorWithKernel {
...
@@ -37,9 +39,10 @@ class SplitOp : public framework::OperatorWithKernel {
const
size_t
outs_number
=
outs_names
.
size
();
const
size_t
outs_number
=
outs_names
.
size
();
if
(
sections
.
size
()
>
0
)
{
if
(
sections
.
size
()
>
0
)
{
PADDLE_ENFORCE_EQ
(
sections
.
size
(),
outs_number
,
PADDLE_ENFORCE_EQ
(
"tensor split sections size "
sections
.
size
(),
outs_number
,
"should be equal to output size."
);
platform
::
errors
::
InvalidArgument
(
"tensor split sections size "
"should be equal to output size."
));
}
}
if
(
ctx
->
HasInput
(
"AxisTensor"
))
{
if
(
ctx
->
HasInput
(
"AxisTensor"
))
{
...
...
paddle/fluid/operators/split_op.h
浏览文件 @
8c25dfaa
...
@@ -33,12 +33,14 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
...
@@ -33,12 +33,14 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
int64_t
input_axis_dim
=
in_dims
[
axis
];
int64_t
input_axis_dim
=
in_dims
[
axis
];
if
(
num
>
0
)
{
if
(
num
>
0
)
{
if
(
is_runtime
||
input_axis_dim
>
0
)
{
if
(
is_runtime
||
input_axis_dim
>
0
)
{
PADDLE_ENFORCE_EQ
(
input_axis_dim
%
num
,
0
,
PADDLE_ENFORCE_EQ
(
"The input's size along the split dimension "
input_axis_dim
%
num
,
0
,
"must be evenly divisible by Attr(num_or_sections). "
platform
::
errors
::
InvalidArgument
(
"But received Attr(num_or_sections) "
"The input's size along the split dimension "
"= %d, input(X)'s shape = [%s], Attr(dim) = %d."
,
"must be evenly divisible by Attr(num_or_sections). "
num
,
in_dims
,
axis
);
"But received Attr(num_or_sections) "
"= %d, input(X)'s shape = [%s], Attr(dim) = %d."
,
num
,
in_dims
,
axis
));
size_t
out_axis_dim
=
input_axis_dim
/
num
;
size_t
out_axis_dim
=
input_axis_dim
/
num
;
for
(
auto
&
out_dim
:
outs_dims
)
{
for
(
auto
&
out_dim
:
outs_dims
)
{
...
@@ -64,11 +66,13 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
...
@@ -64,11 +66,13 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
}
}
if
(
each_section_is_known
)
{
if
(
each_section_is_known
)
{
PADDLE_ENFORCE_LE
(
num_of_unk
,
1
,
PADDLE_ENFORCE_LE
(
"Only one dimension value of Attr(num_or_sections) "
num_of_unk
,
1
,
"in SplitOp can be -1. "
platform
::
errors
::
InvalidArgument
(
"But received Attr(num_or_sections) = [%s]."
,
"Only one dimension value of Attr(num_or_sections) "
framework
::
make_ddim
(
sections
));
"in SplitOp can be -1. "
"But received Attr(num_or_sections) = [%s]."
,
framework
::
make_ddim
(
sections
)));
}
}
if
(
unk_dim_idx
!=
-
1
)
{
if
(
unk_dim_idx
!=
-
1
)
{
...
@@ -77,21 +81,25 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
...
@@ -77,21 +81,25 @@ static inline std::vector<framework::DDim> UpdateOutsDims(
// the following check will fail.
// the following check will fail.
PADDLE_ENFORCE_LT
(
PADDLE_ENFORCE_LT
(
sum_of_section
,
input_axis_dim
,
sum_of_section
,
input_axis_dim
,
"Sum of Attr(num_or_sections) other than unknown section "
platform
::
errors
::
InvalidArgument
(
"must be less than the input's size "
"Sum of Attr(num_or_sections) other than unknown section "
"along the split dimension. But received Attr(num_or_sections) "
"must be less than the input's "
"= [%s], input(X)'s shape = [%s], Attr(dim) = %d."
,
"size "
framework
::
make_ddim
(
sections
),
in_dims
,
axis
);
"along the split dimension. But received Attr(num_or_sections) "
"= [%s], input(X)'s shape = [%s], Attr(dim) = %d."
,
framework
::
make_ddim
(
sections
),
in_dims
,
axis
));
if
(
each_section_is_known
)
{
if
(
each_section_is_known
)
{
sections
[
unk_dim_idx
]
=
input_axis_dim
-
sum_of_section
;
sections
[
unk_dim_idx
]
=
input_axis_dim
-
sum_of_section
;
}
}
}
else
{
}
else
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
sum_of_section
,
input_axis_dim
,
sum_of_section
,
input_axis_dim
,
"Sum of Attr(num_or_sections) must be equal to the input's size "
platform
::
errors
::
InvalidArgument
(
"along the split dimension. But received Attr(num_or_sections)"
"Sum of Attr(num_or_sections) must be equal to the input's "
" = [%s], input(X)'s shape = [%s], Attr(dim) = %d."
,
"size "
framework
::
make_ddim
(
sections
),
in_dims
,
axis
);
"along the split dimension. But received Attr(num_or_sections)"
" = [%s], input(X)'s shape = [%s], Attr(dim) = %d."
,
framework
::
make_ddim
(
sections
),
in_dims
,
axis
));
}
}
}
}
for
(
int
i
=
0
;
i
<
outs_number
;
++
i
)
{
for
(
int
i
=
0
;
i
<
outs_number
;
++
i
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录