Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
84540384
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
84540384
编写于
4月 17, 2019
作者:
H
Hongyu Liu
提交者:
phlrain
4月 17, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Merge pull request #16840 from phlrain/fix_shape_check_many
fix shape check many by hongyu
上级
3063449f
变更
9
显示空白变更内容
内联
并排
Showing
9 changed file
with
144 addition
and
58 deletion
+144
-58
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
...fluid/operators/controlflow/tensor_array_read_write_op.cc
+4
-2
paddle/fluid/operators/data_norm_op.cc
paddle/fluid/operators/data_norm_op.cc
+6
-3
paddle/fluid/operators/huber_loss_op.cc
paddle/fluid/operators/huber_loss_op.cc
+9
-4
paddle/fluid/operators/layer_norm_op.cc
paddle/fluid/operators/layer_norm_op.cc
+9
-2
paddle/fluid/operators/metrics/precision_recall_op.cc
paddle/fluid/operators/metrics/precision_recall_op.cc
+29
-19
paddle/fluid/operators/minus_op.cc
paddle/fluid/operators/minus_op.cc
+7
-3
paddle/fluid/operators/modified_huber_loss_op.cc
paddle/fluid/operators/modified_huber_loss_op.cc
+16
-7
paddle/fluid/operators/space_to_depth_op.cc
paddle/fluid/operators/space_to_depth_op.cc
+38
-13
paddle/fluid/operators/tree_conv_op.cc
paddle/fluid/operators/tree_conv_op.cc
+26
-5
未找到文件。
paddle/fluid/operators/controlflow/tensor_array_read_write_op.cc
浏览文件 @
84540384
...
@@ -81,8 +81,10 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
...
@@ -81,8 +81,10 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
public:
public:
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
void
operator
()(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE
(
context
->
HasInput
(
"I"
),
"Must set the subscript index"
);
PADDLE_ENFORCE
(
context
->
HasInput
(
"I"
),
"Must set the subscript index"
);
if
(
context
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
framework
::
product
(
context
->
GetInputDim
(
"I"
)),
1
,
PADDLE_ENFORCE_EQ
(
framework
::
product
(
context
->
GetInputDim
(
"I"
)),
1
,
"The number of element of subscript index must be 1"
);
"The number of element of subscript index must be 1"
);
}
if
(
!
context
->
HasInput
(
"X"
))
{
if
(
!
context
->
HasInput
(
"X"
))
{
return
;
return
;
}
}
...
...
paddle/fluid/operators/data_norm_op.cc
浏览文件 @
84540384
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/data_norm_op.h"
#include "paddle/fluid/operators/data_norm_op.h"
#include <memory>
#include <string>
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/data_layout.h"
#ifdef PADDLE_WITH_MKLDNN
#ifdef PADDLE_WITH_MKLDNN
...
@@ -65,9 +66,11 @@ class DataNormOp : public framework::OperatorWithKernel {
...
@@ -65,9 +66,11 @@ class DataNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSize"
).
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSize"
).
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSum"
).
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSum"
).
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSquareSum"
).
size
(),
1UL
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSquareSum"
).
size
(),
1UL
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSize"
)[
0
],
C
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSize"
)[
0
],
C
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSum"
)[
0
],
C
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSum"
)[
0
],
C
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSquareSum"
)[
0
],
C
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"BatchSquareSum"
)[
0
],
C
);
}
ctx
->
SetOutputDim
(
"Y"
,
x_dims
);
ctx
->
SetOutputDim
(
"Y"
,
x_dims
);
ctx
->
SetOutputDim
(
"Means"
,
{
C
});
ctx
->
SetOutputDim
(
"Means"
,
{
C
});
...
...
paddle/fluid/operators/huber_loss_op.cc
浏览文件 @
84540384
...
@@ -28,13 +28,18 @@ class HuberLossOp : public framework::OperatorWithKernel {
...
@@ -28,13 +28,18 @@ class HuberLossOp : public framework::OperatorWithKernel {
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"The rank of Input(X) must be 2 and the shape is "
"The rank of Input(X) must be 2 and the shape is "
"[batch_size, 1]."
);
"[batch_size, 1]."
);
if
(
ctx
->
IsRuntime
()
||
(
framework
::
product
(
x_dims
)
>
0
&&
framework
::
product
(
y_dims
)
>
0
))
{
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
"Shape of X and Y should be same"
);
}
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
1
,
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
1
,
"Each row of Input(X) contains a real value, "
"Each row of Input(X) contains a real value, "
"so the 2nd dimension of Input(X) must be 1."
);
"so the 2nd dimension of Input(X) must be 1."
);
}
ctx
->
SetOutputDim
(
"Residual"
,
x_dims
);
ctx
->
SetOutputDim
(
"Residual"
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
1
});
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
1
});
...
...
paddle/fluid/operators/layer_norm_op.cc
浏览文件 @
84540384
...
@@ -46,11 +46,18 @@ class LayerNormOp : public framework::OperatorWithKernel {
...
@@ -46,11 +46,18 @@ class LayerNormOp : public framework::OperatorWithKernel {
int
right
=
static_cast
<
int
>
(
matrix_dim
[
1
]);
int
right
=
static_cast
<
int
>
(
matrix_dim
[
1
]);
if
(
ctx
->
HasInput
(
"Scale"
))
{
if
(
ctx
->
HasInput
(
"Scale"
))
{
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Scale"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Scale"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Scale"
)[
0
],
right
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Scale"
)[
0
],
right
,
"scale should with right"
);
}
}
}
if
(
ctx
->
HasInput
(
"Bias"
))
{
if
(
ctx
->
HasInput
(
"Bias"
))
{
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Bias"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Bias"
).
size
(),
1
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Bias"
)[
0
],
right
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Bias"
)[
0
],
right
,
"bias should with right"
);
}
}
}
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
ctx
->
SetOutputDim
(
"Y"
,
ctx
->
GetInputDim
(
"X"
));
...
...
paddle/fluid/operators/metrics/precision_recall_op.cc
浏览文件 @
84540384
...
@@ -40,31 +40,41 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
...
@@ -40,31 +40,41 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
auto
max_probs_dims
=
ctx
->
GetInputDim
(
"MaxProbs"
);
auto
max_probs_dims
=
ctx
->
GetInputDim
(
"MaxProbs"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Labels"
);
auto
labels_dims
=
ctx
->
GetInputDim
(
"Labels"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
1
],
1
,
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
1
],
1
,
"Each instance contains one max probability, so the "
"Each instance contains one max probability, so the "
"shape of Input(MaxProbs) should be [batch_size, 1]."
);
"shape of Input(MaxProbs) should be [batch_size, 1]."
);
PADDLE_ENFORCE_EQ
(
ctx
->
GetInputDim
(
"Indices"
),
max_probs_dims
,
PADDLE_ENFORCE_EQ
(
"The shape of Input(Indices) should be [batch_size, 1]."
);
ctx
->
GetInputDim
(
"Indices"
),
max_probs_dims
,
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
0
],
labels_dims
[
0
],
"The shape of Input(Indices) should bes same with max_probs_dims"
);
PADDLE_ENFORCE_EQ
(
max_probs_dims
[
0
],
labels_dims
[
0
],
"The 1st dimension of Input(MaxProbs) and "
"The 1st dimension of Input(MaxProbs) and "
"Input(Labels) both are batch_size and the shape should "
"Input(Labels) both are batch_size and the shape should "
"be the same."
);
"be the same."
);
PADDLE_ENFORCE_EQ
(
labels_dims
[
1
],
1
,
PADDLE_ENFORCE_EQ
(
labels_dims
[
1
],
1
,
"The 2nd dimension of Input(Labels) contains instance "
"The 2nd dimension of Input(Labels) contains instance "
"label and the shape should be equal to 1."
);
"label and the shape should be equal to 1."
);
}
if
(
ctx
->
HasInput
(
"Weights"
))
{
if
(
ctx
->
HasInput
(
"Weights"
))
{
auto
weights_dims
=
ctx
->
GetInputDim
(
"Weights"
);
auto
weights_dims
=
ctx
->
GetInputDim
(
"Weights"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
weights_dims
,
PADDLE_ENFORCE_EQ
(
weights_dims
,
framework
::
make_ddim
({
max_probs_dims
[
0
],
1
}),
framework
::
make_ddim
({
max_probs_dims
[
0
],
1
}),
"The shape of Input(Weights) should be "
"The shape of Input(Weights) should be "
"[batch_size, 1]."
);
"[batch_size, 1]."
);
}
}
}
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
if
(
ctx
->
HasInput
(
"StatesInfo"
))
{
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
auto
states_dims
=
ctx
->
GetInputDim
(
"StatesInfo"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
cls_num
,
4
}),
PADDLE_ENFORCE_EQ
(
states_dims
,
framework
::
make_ddim
({
cls_num
,
4
}),
"The shape of Input(StatesInfo) should be "
"The shape of Input(StatesInfo) should be "
"[class_number, 4]."
);
"[class_number, 4]."
);
}
}
}
// Layouts of BatchMetrics and AccumMetrics both are:
// Layouts of BatchMetrics and AccumMetrics both are:
// [
// [
...
...
paddle/fluid/operators/minus_op.cc
浏览文件 @
84540384
...
@@ -14,6 +14,7 @@ limitations under the License. */
...
@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/minus_op.h"
#include "paddle/fluid/operators/minus_op.h"
#include <memory>
#include <string>
#include <string>
#include <vector>
#include <vector>
...
@@ -38,9 +39,12 @@ class MinusOp : public framework::OperatorWithKernel {
...
@@ -38,9 +39,12 @@ class MinusOp : public framework::OperatorWithKernel {
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
if
(
ctx
->
IsRuntime
()
||
(
framework
::
product
(
x_dims
)
>
0
&&
framework
::
product
(
y_dims
)
>
0
))
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
x_dims
,
y_dims
,
"Minus operator must take two tensor with same num of elements"
);
"Minus operator must take two tensor with same num of elements"
);
}
ctx
->
SetOutputDim
(
"Out"
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
ctx
->
ShareLoD
(
"X"
,
/*->*/
"Out"
);
}
}
...
...
paddle/fluid/operators/modified_huber_loss_op.cc
浏览文件 @
84540384
...
@@ -28,9 +28,16 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
...
@@ -28,9 +28,16 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
auto
y_dims
=
ctx
->
GetInputDim
(
"Y"
);
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
"The shape of X and Y must be the same."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"The tensor rank of X must be 2."
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
2
,
"The tensor rank of X must be 2."
);
if
(
ctx
->
IsRuntime
()
||
(
framework
::
product
(
x_dims
)
>
0
&&
framework
::
product
(
y_dims
)
>
0
))
{
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
"The shape of X and Y must be the same."
);
}
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
1
,
"The 2nd dimension of X must be 1."
);
PADDLE_ENFORCE_EQ
(
x_dims
[
1
],
1
,
"The 2nd dimension of X must be 1."
);
}
ctx
->
SetOutputDim
(
"IntermediateVal"
,
x_dims
);
ctx
->
SetOutputDim
(
"IntermediateVal"
,
x_dims
);
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
1
});
ctx
->
SetOutputDim
(
"Out"
,
{
x_dims
[
0
],
1
});
...
@@ -90,11 +97,13 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
...
@@ -90,11 +97,13 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
auto
intermediate_dims
=
ctx
->
GetInputDim
(
"IntermediateVal"
);
auto
intermediate_dims
=
ctx
->
GetInputDim
(
"IntermediateVal"
);
auto
out_grad_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
out_grad_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
intermediate_dims
,
x_dims
,
intermediate_dims
,
x_dims
,
"The shape of X and intermediate value must be the same."
);
"The shape of X and intermediate value must be the same."
);
PADDLE_ENFORCE_EQ
(
out_grad_dims
,
x_dims
,
PADDLE_ENFORCE_EQ
(
out_grad_dims
,
x_dims
,
"The shape of Input(Out@Grad) and X must be the same."
);
"The shape of Input(Out@Grad) and X must be the same."
);
}
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
if
(
ctx
->
HasOutput
(
framework
::
GradVarName
(
"X"
)))
{
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
...
...
paddle/fluid/operators/space_to_depth_op.cc
浏览文件 @
84540384
...
@@ -34,6 +34,7 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
...
@@ -34,6 +34,7 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
auto
blocksize
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"blocksize"
);
auto
blocksize
=
ctx
->
Attrs
().
Get
<
int64_t
>
(
"blocksize"
);
PADDLE_ENFORCE_GT
(
blocksize
,
1
,
"The blocksize should be Greater than 1"
);
PADDLE_ENFORCE_GT
(
blocksize
,
1
,
"The blocksize should be Greater than 1"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
"input channel should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
"input channel should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
"input Height should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
"input Height should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
"input Width should be Greater than 0"
);
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
"input Width should be Greater than 0"
);
...
@@ -47,6 +48,30 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
...
@@ -47,6 +48,30 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
"input Width should be divisible of the square of "
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
"SpaceToDepthOp blocksize"
);
}
else
{
if
(
x_dims
[
1
]
!=
-
1
)
{
PADDLE_ENFORCE_GT
(
x_dims
[
1
],
0
,
"input channel should be Greater than 0"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
1
]
%
(
blocksize
*
blocksize
),
0
,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
}
if
(
x_dims
[
2
]
!=
-
1
)
{
PADDLE_ENFORCE_GT
(
x_dims
[
2
],
0
,
"input Height should be Greater than 0"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
2
]
%
(
blocksize
),
0
,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
}
if
(
x_dims
[
3
]
!=
-
1
)
{
PADDLE_ENFORCE_GT
(
x_dims
[
3
],
0
,
"input Width should be Greater than 0"
);
PADDLE_ENFORCE_EQ
(
x_dims
[
3
]
%
(
blocksize
),
0
,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize"
);
}
}
VLOG
(
3
)
<<
"SpaceToDepthOp operator x.shape="
<<
x_dims
VLOG
(
3
)
<<
"SpaceToDepthOp operator x.shape="
<<
x_dims
<<
"Attribute blocksize"
<<
blocksize
<<
std
::
endl
;
<<
"Attribute blocksize"
<<
blocksize
<<
std
::
endl
;
...
...
paddle/fluid/operators/tree_conv_op.cc
浏览文件 @
84540384
...
@@ -62,17 +62,38 @@ class TreeConvOp : public framework::OperatorWithKernel {
...
@@ -62,17 +62,38 @@ class TreeConvOp : public framework::OperatorWithKernel {
auto
edge_dims
=
ctx
->
GetInputDim
(
"EdgeSet"
);
auto
edge_dims
=
ctx
->
GetInputDim
(
"EdgeSet"
);
auto
vector_dims
=
ctx
->
GetInputDim
(
"NodesVector"
);
auto
vector_dims
=
ctx
->
GetInputDim
(
"NodesVector"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
auto
filter_dims
=
ctx
->
GetInputDim
(
"Filter"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
edge_dims
[
2
],
2
,
"Input(EdgeSet) dim[2] should be 2"
);
}
else
{
if
(
edge_dims
[
2
]
!=
-
1
)
{
PADDLE_ENFORCE_EQ
(
edge_dims
[
2
],
2
,
"Input(EdgeSet) dim[2] should be 2"
);
PADDLE_ENFORCE_EQ
(
edge_dims
[
2
],
2
,
"Input(EdgeSet) dim[2] should be 2"
);
}
}
PADDLE_ENFORCE_EQ
(
edge_dims
.
size
(),
3
,
PADDLE_ENFORCE_EQ
(
edge_dims
.
size
(),
3
,
"The dimension of EdgeSet Tensor should be 3"
);
"The dimension of EdgeSet Tensor should be 3"
);
PADDLE_ENFORCE_EQ
(
vector_dims
.
size
(),
3
,
PADDLE_ENFORCE_EQ
(
vector_dims
.
size
(),
3
,
"The dimension of NodesVector Tensor should be 3"
);
"The dimension of NodesVector Tensor should be 3"
);
PADDLE_ENFORCE_EQ
(
filter_dims
.
size
(),
4
,
PADDLE_ENFORCE_EQ
(
filter_dims
.
size
(),
4
,
"The dimension of Filter Tensor should be 4"
);
"The dimension of Filter Tensor should be 4"
);
if
(
ctx
->
IsRuntime
())
{
PADDLE_ENFORCE_EQ
(
filter_dims
[
1
],
3
,
"Input(Filter) dim[1] should be 3"
);
PADDLE_ENFORCE_EQ
(
filter_dims
[
1
],
3
,
"Input(Filter) dim[1] should be 3"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
filter_dims
[
0
],
vector_dims
[
2
],
filter_dims
[
0
],
vector_dims
[
2
],
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]"
);
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]"
);
}
else
{
if
(
filter_dims
[
1
]
!=
-
1
)
{
PADDLE_ENFORCE_EQ
(
filter_dims
[
1
],
3
,
"Input(Filter) dim[1] should be 3"
);
}
if
(
filter_dims
[
0
]
!=
-
1
&&
vector_dims
[
2
]
!=
-
1
)
{
PADDLE_ENFORCE_EQ
(
filter_dims
[
0
],
vector_dims
[
2
],
"Input(Filter) dim[0] must equal to Input(NodesVector) dim[2]"
);
}
}
auto
output_dims
=
framework
::
make_ddim
(
auto
output_dims
=
framework
::
make_ddim
(
{
vector_dims
[
0
],
vector_dims
[
1
],
filter_dims
[
2
],
filter_dims
[
3
]});
{
vector_dims
[
0
],
vector_dims
[
1
],
filter_dims
[
2
],
filter_dims
[
3
]});
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
ctx
->
SetOutputDim
(
"Out"
,
output_dims
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录