Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
7daae985
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
7daae985
编写于
7月 20, 2022
作者:
Y
yaozhixin
提交者:
GitHub
7月 20, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[IPU] Add more Ops (#44414)
* [IPU] Add more Ops * update boost API
上级
1047cb17
变更
5
显示空白变更内容
内联
并排
Showing
5 changed file
with
696 addition
and
170 deletion
+696
-170
paddle/fluid/platform/device/ipu/popart_canonicalization/detection_ops.cc
...tform/device/ipu/popart_canonicalization/detection_ops.cc
+444
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
...uid/platform/device/ipu/popart_canonicalization/nn_ops.cc
+8
-24
paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc
...platform/device/ipu/popart_canonicalization/op_builder.cc
+63
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h
.../platform/device/ipu/popart_canonicalization/op_builder.h
+16
-0
paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
...platform/device/ipu/popart_canonicalization/tensor_ops.cc
+165
-146
未找到文件。
paddle/fluid/platform/device/ipu/popart_canonicalization/detection_ops.cc
0 → 100644
浏览文件 @
7daae985
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace
paddle
{
namespace
platform
{
namespace
ipu
{
namespace
{
Node
*
yolo_box_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
clip_bbox
=
PADDLE_GET_CONST
(
bool
,
op
->
GetAttr
(
"clip_bbox"
));
auto
iou_aware
=
PADDLE_GET_CONST
(
bool
,
op
->
GetAttr
(
"iou_aware"
));
auto
conf_thresh
=
PADDLE_GET_CONST
(
float
,
op
->
GetAttr
(
"conf_thresh"
));
auto
iou_aware_factor
=
PADDLE_GET_CONST
(
float
,
op
->
GetAttr
(
"iou_aware_factor"
));
auto
class_num
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"class_num"
));
auto
downsample_ratio
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"downsample_ratio"
));
auto
scale_x_y
=
PADDLE_GET_CONST
(
float
,
op
->
GetAttr
(
"scale_x_y"
));
auto
anchors
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"anchors"
));
// For Slice Op, while value is very large, it equals to the ends.
int
max_int
=
INT_MAX
;
int
anchor_num
=
anchors
.
size
()
/
2
;
// FP32 or FP16
auto
target_dtype
=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetDataType
();
Node
*
input_x
=
GetInputVarNode
(
"X"
,
node
);
if
(
iou_aware
)
{
input_x
=
CreateSlice
(
graph
,
node
,
{
input_x
},
{},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
},
std
::
vector
<
int
>
{
max_int
,
anchor_num
,
max_int
,
max_int
},
std
::
vector
<
int
>
{
0
,
1
,
2
,
3
},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
})
->
outputs
[
0
];
}
auto
nchw
=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
();
// Channel `C` = anchor_num * (5 + class_num)
auto
*
reshaped_x
=
CreateReshape
(
graph
,
node
,
{
input_x
},
{},
std
::
vector
<
int64_t
>
{
nchw
[
0
],
anchor_num
,
-
1
,
nchw
[
2
],
nchw
[
3
]})
->
outputs
[
0
];
auto
*
transposed_x
=
CreateBaseOp
(
graph
,
node
,
"popart_transpose"
,
{
reshaped_x
},
{},
{{
"perm"
,
std
::
vector
<
int64_t
>
{
0
,
1
,
3
,
4
,
2
}}})
->
outputs
[
0
];
// Build the grid
// grid_x_0 shape is [w]
std
::
vector
<
float
>
grid_x_0
(
nchw
[
3
]);
std
::
iota
(
grid_x_0
.
begin
(),
grid_x_0
.
end
(),
0.0
f
);
// grid_y_0 shape is [h]
std
::
vector
<
float
>
grid_y_0
(
nchw
[
2
]);
std
::
iota
(
grid_y_0
.
begin
(),
grid_y_0
.
end
(),
0.0
f
);
// grid_x_1 shape is [w * h]
std
::
vector
<
float
>
grid_x_1
;
for
(
int
i
=
0
;
i
<
nchw
[
2
];
i
++
)
{
grid_x_1
.
insert
(
grid_x_1
.
end
(),
grid_x_0
.
begin
(),
grid_x_0
.
end
());
}
auto
*
grid_x_1_node
=
CreateConst
(
graph
,
node
,
grid_x_1
,
{
int64_t
(
grid_x_1
.
size
())},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
// grid_y_1 shape is [h * w]
std
::
vector
<
float
>
grid_y_1
;
for
(
int
i
=
0
;
i
<
nchw
[
3
];
i
++
)
{
grid_y_1
.
insert
(
grid_y_1
.
end
(),
grid_y_0
.
begin
(),
grid_y_0
.
end
());
}
auto
*
grid_y_1_node
=
CreateConst
(
graph
,
node
,
grid_y_1
,
{
int64_t
(
grid_y_1
.
size
())},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
grid_x_node
=
CreateReshape
(
graph
,
node
,
{
grid_x_1_node
},
{},
std
::
vector
<
int64_t
>
{
nchw
[
2
],
nchw
[
3
],
1
})
->
outputs
[
0
];
auto
*
grid_y_2_node
=
CreateReshape
(
graph
,
node
,
{
grid_y_1_node
},
{},
std
::
vector
<
int64_t
>
{
nchw
[
3
],
nchw
[
2
],
1
})
->
outputs
[
0
];
auto
*
grid_y_node
=
CreateBaseOp
(
graph
,
node
,
"popart_transpose"
,
{
grid_y_2_node
},
{},
{{
"perm"
,
std
::
vector
<
int64_t
>
{
1
,
0
,
2
}}})
->
outputs
[
0
];
auto
*
grid_node
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
grid_x_node
,
grid_y_node
},
{},
{{
"axis"
,
int64_t
(
2
)}})
->
outputs
[
0
];
// Generate the positions(x, y) of boxes
// pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0]) *
// scale_x_y + bias_x_y) / w pred_box[:, :, :, :, 1] = (grid_y +
// sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
auto
*
pred_box_xy
=
CreateSlice
(
graph
,
node
,
{
transposed_x
},
{},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
,
0
},
std
::
vector
<
int
>
{
max_int
,
max_int
,
max_int
,
max_int
,
2
},
std
::
vector
<
int
>
{
0
,
1
,
2
,
3
,
4
},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
,
1
})
->
outputs
[
0
];
auto
*
scale_x_y_node
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
scale_x_y
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
bias_x_y_node
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{(
1.0
f
-
scale_x_y
)
/
2.0
f
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
wh
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
static_cast
<
float
>
(
nchw
[
3
]),
static_cast
<
float
>
(
nchw
[
2
])},
{
int64_t
(
2
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
pred_box_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_sigmoid"
,
{
pred_box_xy
},
{})
->
outputs
[
0
];
pred_box_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_box_xy
,
scale_x_y_node
},
{})
->
outputs
[
0
];
pred_box_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_add"
,
{
pred_box_xy
,
bias_x_y_node
},
{})
->
outputs
[
0
];
pred_box_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_add"
,
{
pred_box_xy
,
grid_node
},
{})
->
outputs
[
0
];
pred_box_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_div"
,
{
pred_box_xy
,
wh
},
{})
->
outputs
[
0
];
// Generate Width and Height of boxes
// anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
// anchors_s = np.array(
// [(an_w / input_w, an_h / input_h) for an_w, an_h in anchors])
// anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
// anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
auto
*
anchors_node
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
anchors
.
begin
(),
anchors
.
begin
()
+
anchor_num
*
2
},
{
int64_t
(
anchor_num
*
2
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
anchors_node
=
CreateReshape
(
graph
,
node
,
{
anchors_node
},
{},
std
::
vector
<
int64_t
>
{
anchor_num
,
2
})
->
outputs
[
0
];
auto
*
downsample_node
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
static_cast
<
float
>
(
downsample_ratio
)},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
ori_wh
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
wh
,
downsample_node
},
{})
->
outputs
[
0
];
anchors_node
=
CreateBaseOp
(
graph
,
node
,
"popart_div"
,
{
anchors_node
,
ori_wh
},
{})
->
outputs
[
0
];
anchors_node
=
CreateReshape
(
graph
,
node
,
{
anchors_node
},
{},
std
::
vector
<
int64_t
>
{
1
,
anchor_num
,
1
,
1
,
2
})
->
outputs
[
0
];
auto
*
pred_box_wh
=
CreateSlice
(
graph
,
node
,
{
transposed_x
},
{},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
,
2
},
std
::
vector
<
int
>
{
max_int
,
max_int
,
max_int
,
max_int
,
4
},
std
::
vector
<
int
>
{
0
,
1
,
2
,
3
,
4
},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
,
1
})
->
outputs
[
0
];
pred_box_wh
=
CreateBaseOp
(
graph
,
node
,
"popart_exp"
,
{
pred_box_wh
},
{})
->
outputs
[
0
];
pred_box_wh
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_box_wh
,
anchors_node
},
{})
->
outputs
[
0
];
// Ignore the boxes whose confidience lower than the threshold
// if iou_aware:
// pred_conf = sigmoid(x[:, :, :, :, 4:5])**(
// 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor
// else:
// pred_conf = sigmoid(x[:, :, :, :, 4:5])
auto
*
confidence
=
CreateSlice
(
graph
,
node
,
{
transposed_x
},
{},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
,
4
},
std
::
vector
<
int
>
{
max_int
,
max_int
,
max_int
,
max_int
,
5
},
std
::
vector
<
int
>
{
0
,
1
,
2
,
3
,
4
},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
,
1
})
->
outputs
[
0
];
auto
*
pred_conf
=
CreateBaseOp
(
graph
,
node
,
"popart_sigmoid"
,
{
confidence
},
{})
->
outputs
[
0
];
if
(
iou_aware
)
{
auto
*
ioup
=
CreateSlice
(
graph
,
node
,
{
GetInputVarNode
(
"X"
,
node
)},
{},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
},
std
::
vector
<
int
>
{
max_int
,
anchor_num
,
max_int
,
max_int
},
std
::
vector
<
int
>
{
0
,
1
,
2
,
3
},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
})
->
outputs
[
0
];
ioup
=
CreateBaseOp
(
graph
,
node
,
"popart_unsqueeze"
,
{
ioup
},
{},
{{
"axes"
,
std
::
vector
<
int64_t
>
{
4
}}})
->
outputs
[
0
];
ioup
=
CreateBaseOp
(
graph
,
node
,
"popart_sigmoid"
,
{
ioup
},
{})
->
outputs
[
0
];
auto
*
power_0
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
1.0
f
-
iou_aware_factor
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
power_1
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
iou_aware_factor
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
ioup
=
CreateBaseOp
(
graph
,
node
,
"popart_pow"
,
{
ioup
,
power_1
},
{})
->
outputs
[
0
];
pred_conf
=
CreateBaseOp
(
graph
,
node
,
"popart_pow"
,
{
pred_conf
,
power_0
},
{})
->
outputs
[
0
];
pred_conf
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_conf
,
ioup
},
{})
->
outputs
[
0
];
}
// pred_conf[pred_conf < conf_thresh] = 0.
// pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
// pred_box = pred_box * (pred_conf > 0.).astype('float32')
auto
*
value_2
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
2.0
f
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
center
=
CreateBaseOp
(
graph
,
node
,
"popart_div"
,
{
pred_box_wh
,
value_2
},
{})
->
outputs
[
0
];
auto
*
min_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_sub"
,
{
pred_box_xy
,
center
},
{})
->
outputs
[
0
];
auto
*
max_xy
=
CreateBaseOp
(
graph
,
node
,
"popart_add"
,
{
pred_box_xy
,
center
},
{})
->
outputs
[
0
];
auto
*
conf_thresh_node
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
conf_thresh
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
auto
*
filter
=
CreateBaseOp
(
graph
,
node
,
"popart_greater"
,
{
pred_conf
,
conf_thresh_node
},
{})
->
outputs
[
0
];
filter
=
CreateCast
(
graph
,
node
,
{
filter
},
{},
target_dtype
)
->
outputs
[
0
];
pred_conf
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_conf
,
filter
},
{})
->
outputs
[
0
];
auto
*
pred_score
=
CreateSlice
(
graph
,
node
,
{
transposed_x
},
{},
std
::
vector
<
int
>
{
0
,
0
,
0
,
0
,
5
},
std
::
vector
<
int
>
{
max_int
,
max_int
,
max_int
,
max_int
,
max_int
},
std
::
vector
<
int
>
{
0
,
1
,
2
,
3
,
4
},
std
::
vector
<
int
>
{
1
,
1
,
1
,
1
,
1
})
->
outputs
[
0
];
pred_score
=
CreateBaseOp
(
graph
,
node
,
"popart_sigmoid"
,
{
pred_score
},
{})
->
outputs
[
0
];
pred_score
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_score
,
pred_conf
},
{})
->
outputs
[
0
];
auto
*
pred_box
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
min_xy
,
max_xy
},
{},
{{
"axis"
,
int64_t
(
4
)}})
->
outputs
[
0
];
pred_box
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_box
,
filter
},
{})
->
outputs
[
0
];
pred_box
=
CreateReshape
(
graph
,
node
,
{
pred_box
},
{},
std
::
vector
<
int64_t
>
{
nchw
[
0
],
-
1
,
4
})
->
outputs
[
0
];
// Clip the boxes to img_size
auto
*
float_img_size
=
CreateCast
(
graph
,
node
,
{
GetInputVarNode
(
"ImgSize"
,
node
)},
{},
target_dtype
)
->
outputs
[
0
];
float_img_size
=
CreateBaseOp
(
graph
,
node
,
"popart_unsqueeze"
,
{
float_img_size
},
{},
{{
"axes"
,
std
::
vector
<
int64_t
>
(
1
)}})
->
outputs
[
0
];
auto
split_im_hw
=
CreateSplit
(
graph
,
node
,
{
float_img_size
},
{},
std
::
vector
<
int64_t
>
{
1
,
1
},
2
)
->
outputs
;
auto
*
im_whwh
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
split_im_hw
[
1
],
split_im_hw
[
0
],
split_im_hw
[
1
],
split_im_hw
[
0
]},
{},
{{
"axis"
,
int64_t
(
2
)}})
->
outputs
[
0
];
if
(
!
clip_bbox
)
{
auto
*
out
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_box
,
im_whwh
},
{})
->
outputs
[
0
];
CreateCast
(
graph
,
node
,
{
out
},
{
GetOutputVarNode
(
"Boxes"
,
node
)},
GetOutputVarNode
(
"Boxes"
,
node
)
->
Var
()
->
GetDataType
());
}
else
{
pred_box
=
CreateBaseOp
(
graph
,
node
,
"popart_mul"
,
{
pred_box
,
im_whwh
},
{})
->
outputs
[
0
];
auto
*
im_wh
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
{
split_im_hw
[
1
],
split_im_hw
[
0
]},
{},
{{
"axis"
,
int64_t
(
2
)}})
->
outputs
[
0
];
auto
*
float_value_1
=
CreateConst
(
graph
,
node
,
std
::
vector
<
float
>
{
1.0
f
},
{
int64_t
(
1
)},
VarType2OnnxDType
(
target_dtype
))
->
outputs
[
0
];
im_wh
=
CreateBaseOp
(
graph
,
node
,
"popart_sub"
,
{
im_wh
,
float_value_1
},
{})
->
outputs
[
0
];
auto
pred_box_xymin_xymax
=
CreateSplit
(
graph
,
node
,
{
pred_box
},
{},
std
::
vector
<
int64_t
>
{
2
,
2
},
2
)
->
outputs
;
pred_box_xymin_xymax
[
0
]
=
CreateBaseOp
(
graph
,
node
,
"popart_relu"
,
{
pred_box_xymin_xymax
[
0
]},
{})
->
outputs
[
0
];
pred_box_xymin_xymax
[
1
]
=
CreateBaseOp
(
graph
,
node
,
"popart_min"
,
{
pred_box_xymin_xymax
[
1
],
im_wh
},
{})
->
outputs
[
0
];
auto
*
out
=
CreateBaseOp
(
graph
,
node
,
"popart_concat"
,
pred_box_xymin_xymax
,
{},
{{
"axis"
,
int64_t
(
2
)}})
->
outputs
[
0
];
CreateCast
(
graph
,
node
,
{
out
},
{
GetOutputVarNode
(
"Boxes"
,
node
)},
GetOutputVarNode
(
"Boxes"
,
node
)
->
Var
()
->
GetDataType
());
}
auto
*
score_out
=
CreateReshape
(
graph
,
node
,
{
pred_score
},
{},
std
::
vector
<
int64_t
>
{
nchw
[
0
],
-
1
,
class_num
})
->
outputs
[
0
];
return
CreateCast
(
graph
,
node
,
{
score_out
},
{
GetOutputVarNode
(
"Scores"
,
node
)},
GetOutputVarNode
(
"Scores"
,
node
)
->
Var
()
->
GetDataType
());
}
}
// namespace
}
// namespace ipu
}
// namespace platform
}
// namespace paddle
REGISTER_HANDLER
(
yolo_box
,
yolo_box_handler
);
paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
浏览文件 @
7daae985
...
@@ -656,30 +656,14 @@ Node *interp_handler(Graph *graph, Node *node, const std::string &mode) {
...
@@ -656,30 +656,14 @@ Node *interp_handler(Graph *graph, Node *node, const std::string &mode) {
CreateBaseOp
(
CreateBaseOp
(
graph
,
node
,
"popart_shape"
,
{
GetInputVarNode
(
"X"
,
node
)},
{})
graph
,
node
,
"popart_shape"
,
{
GetInputVarNode
(
"X"
,
node
)},
{})
->
outputs
[
0
];
->
outputs
[
0
];
Node
*
start
=
CreateConst
(
graph
,
Node
*
nc
=
CreateSlice
(
graph
,
node
,
node
,
{
input_shape
},
{},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int64_t
>
{
1
},
ONNXDataType
::
INT32
)
->
outputs
[
0
];
Node
*
end
=
CreateConst
(
graph
,
node
,
std
::
vector
<
int
>
{
2
},
std
::
vector
<
int
>
{
2
},
std
::
vector
<
int64_t
>
{
1
},
ONNXDataType
::
INT32
)
->
outputs
[
0
];
Node
*
axes
=
CreateConst
(
graph
,
node
,
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int64_t
>
{
1
},
std
::
vector
<
int
>
{
1
})
ONNXDataType
::
INT32
)
->
outputs
[
0
];
Node
*
nc
=
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
input_shape
,
start
,
end
,
axes
},
{},
{})
->
outputs
[
0
];
->
outputs
[
0
];
size
=
CreateBaseOp
(
graph
,
size
=
CreateBaseOp
(
graph
,
node
,
node
,
...
...
paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc
浏览文件 @
7daae985
...
@@ -256,6 +256,69 @@ Node *CreateSoftmaxOpset11(Graph *graph,
...
@@ -256,6 +256,69 @@ Node *CreateSoftmaxOpset11(Graph *graph,
}
}
}
}
Node
*
CreateSlice
(
Graph
*
graph
,
Node
*
node
,
const
std
::
vector
<
Node
*>
&
inputs
,
const
std
::
vector
<
Node
*>
&
outputs
,
const
std
::
vector
<
int
>
&
starts
,
const
std
::
vector
<
int
>
&
ends
,
const
std
::
vector
<
int
>
&
axes
,
const
std
::
vector
<
int
>
&
strides
)
{
auto
*
starts_node
=
CreateConst
(
graph
,
node
,
starts
,
{
int64_t
(
starts
.
size
())},
ONNXDataType
::
INT32
)
->
outputs
[
0
];
auto
*
ends_node
=
CreateConst
(
graph
,
node
,
ends
,
{
int64_t
(
ends
.
size
())},
ONNXDataType
::
INT32
)
->
outputs
[
0
];
auto
*
axes_node
=
CreateConst
(
graph
,
node
,
axes
,
{
int64_t
(
axes
.
size
())},
ONNXDataType
::
INT32
)
->
outputs
[
0
];
auto
*
strides_node
=
CreateConst
(
graph
,
node
,
strides
,
{
int64_t
(
strides
.
size
())},
ONNXDataType
::
INT32
)
->
outputs
[
0
];
return
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
inputs
[
0
],
starts_node
,
ends_node
,
axes_node
,
strides_node
},
outputs
);
}
Node
*
CreateSplit
(
Graph
*
graph
,
Node
*
node
,
const
std
::
vector
<
Node
*>
&
inputs
,
const
std
::
vector
<
Node
*>
&
outputs
,
const
std
::
vector
<
int64_t
>
&
split
,
const
int64_t
axis
)
{
if
(
!
outputs
.
empty
())
{
return
CreateBaseOp
(
graph
,
node
,
"popart_split"
,
inputs
,
outputs
,
{{
"num_outputs"
,
int64_t
(
split
.
size
())},
{
"axis"
,
int64_t
(
axis
)},
{
"split"
,
split
}});
}
else
{
std
::
vector
<
Node
*>
splits_output_nodes
;
for
(
int
j
=
0
;
j
<
split
.
size
();
j
++
)
{
splits_output_nodes
.
push_back
(
MakeVarNode
(
graph
,
node
));
}
return
CreateBaseOp
(
graph
,
node
,
"popart_split"
,
inputs
,
{
splits_output_nodes
},
{{
"num_outputs"
,
int64_t
(
split
.
size
())},
{
"axis"
,
int64_t
(
axis
)},
{
"split"
,
split
}});
}
}
}
// namespace ipu
}
// namespace ipu
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h
浏览文件 @
7daae985
...
@@ -104,6 +104,22 @@ Node *CreateSoftmaxOpset11(Graph *graph,
...
@@ -104,6 +104,22 @@ Node *CreateSoftmaxOpset11(Graph *graph,
const
std
::
vector
<
Node
*>
&
outputs
,
const
std
::
vector
<
Node
*>
&
outputs
,
int64_t
axis
);
int64_t
axis
);
Node
*
CreateSlice
(
Graph
*
graph
,
Node
*
node
,
const
std
::
vector
<
Node
*>
&
inputs
,
const
std
::
vector
<
Node
*>
&
outputs
,
const
std
::
vector
<
int
>
&
starts
,
const
std
::
vector
<
int
>
&
ends
,
const
std
::
vector
<
int
>
&
axes
,
const
std
::
vector
<
int
>
&
strides
);
Node
*
CreateSplit
(
Graph
*
graph
,
Node
*
node
,
const
std
::
vector
<
Node
*>
&
inputs
,
const
std
::
vector
<
Node
*>
&
outputs
,
const
std
::
vector
<
int64_t
>
&
split
,
const
int64_t
axis
);
}
// namespace ipu
}
// namespace ipu
}
// namespace platform
}
// namespace platform
}
// namespace paddle
}
// namespace paddle
paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
浏览文件 @
7daae985
...
@@ -249,94 +249,57 @@ Node *lookup_table_op_handler(Graph *graph,
...
@@ -249,94 +249,57 @@ Node *lookup_table_op_handler(Graph *graph,
{{
"value"
,
const_value_
},
{{
"value"
,
const_value_
},
{
"dims"
,
const_shape_
},
{
"dims"
,
const_shape_
},
{
"dtype"
,
GetOutputVarDType
(
node
)}});
{
"dtype"
,
GetOutputVarDType
(
node
)}});
auto
axes
=
CreateConst
(
graph
,
if
(
padding_idx_
==
0
)
{
node
,
auto
right_slice
=
{},
CreateSlice
(
graph
,
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
0
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
step
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
left_start
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
0
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
left_end
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
padding_idx_
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
right_start
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
padding_idx_
+
1
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
right_end
=
CreateConst
(
graph
,
node
,
{},
{},
{{
"value"
,
std
::
vector
<
int64_t
>
{
table_size_
}},
{
"dims"
,
std
::
vector
<
int64_t
>
{
1
}},
{
"dtype"
,
ONNXDataType
::
INT64
}});
auto
left_slice
=
CreateBaseOp
(
graph
,
node
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"W"
,
node
)},
{
GetInputVarNode
(
"W"
,
node
),
left_start
->
outputs
[
0
],
left_end
->
outputs
[
0
],
axes
->
outputs
[
0
],
step
->
outputs
[
0
]},
{},
{});
auto
right_slice
=
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"W"
,
node
),
right_start
->
outputs
[
0
],
right_end
->
outputs
[
0
],
axes
->
outputs
[
0
],
step
->
outputs
[
0
]},
{},
{},
{});
std
::
vector
<
int
>
{
static_cast
<
int
>
(
padding_idx_
)
+
1
},
std
::
vector
<
int
>
{
static_cast
<
int
>
(
table_size_
)},
if
(
padding_idx_
==
0
)
{
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
1
});
w_node
=
CreateBaseOp
(
graph
,
w_node
=
CreateBaseOp
(
graph
,
node
,
node
,
"popart_concat"
,
"popart_concat"
,
{
concat_const
->
outputs
[
0
],
right_slice
->
outputs
[
0
]},
{
concat_const
->
outputs
[
0
],
right_slice
->
outputs
[
0
]},
{},
{},
{{
"axis"
,
int64_t
(
0
)}});
{{
"axis"
,
int64_t
(
0
)}});
ClearNode
(
left_start
);
ClearNode
(
left_end
);
ClearNode
(
left_slice
);
}
else
if
(
padding_idx_
==
table_size_
-
1
)
{
}
else
if
(
padding_idx_
==
table_size_
-
1
)
{
auto
left_slice
=
CreateSlice
(
graph
,
node
,
{
GetInputVarNode
(
"W"
,
node
)},
{},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
static_cast
<
int
>
(
padding_idx_
)},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
1
});
w_node
=
CreateBaseOp
(
graph
,
w_node
=
CreateBaseOp
(
graph
,
node
,
node
,
"popart_concat"
,
"popart_concat"
,
{
left_slice
->
outputs
[
0
],
concat_const
->
outputs
[
0
]},
{
left_slice
->
outputs
[
0
],
concat_const
->
outputs
[
0
]},
{},
{},
{{
"axis"
,
int64_t
{
0
}}});
{{
"axis"
,
int64_t
{
0
}}});
ClearNode
(
right_start
);
ClearNode
(
right_end
);
ClearNode
(
right_slice
);
}
else
{
}
else
{
auto
left_slice
=
CreateSlice
(
graph
,
node
,
{
GetInputVarNode
(
"W"
,
node
)},
{},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
static_cast
<
int
>
(
padding_idx_
)},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
1
});
auto
right_slice
=
CreateSlice
(
graph
,
node
,
{
GetInputVarNode
(
"W"
,
node
)},
{},
std
::
vector
<
int
>
{
static_cast
<
int
>
(
padding_idx_
)
+
1
},
std
::
vector
<
int
>
{
static_cast
<
int
>
(
table_size_
)},
std
::
vector
<
int
>
{
0
},
std
::
vector
<
int
>
{
1
});
w_node
=
CreateBaseOp
(
graph
,
w_node
=
CreateBaseOp
(
graph
,
node
,
node
,
"popart_concat"
,
"popart_concat"
,
...
@@ -441,72 +404,69 @@ Node *shape_handler(Graph *graph, Node *node) {
...
@@ -441,72 +404,69 @@ Node *shape_handler(Graph *graph, Node *node) {
Node
*
slice_handler
(
Graph
*
graph
,
Node
*
node
)
{
Node
*
slice_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
*
op
=
node
->
Op
();
Node
*
starts
=
nullptr
;
auto
inputs
=
op
->
Inputs
();
if
(
!
op
->
HasAttr
(
"starts"
))
{
starts
=
GetInputVarNode
(
"StartsTensor"
,
node
);
auto
axes_value
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"axes"
));
std
::
vector
<
std
::
vector
<
int
>>
slice_values
(
3
);
std
::
vector
<
std
::
string
>
tensor_names
{
"Starts"
,
"Ends"
,
"Strides"
};
std
::
vector
<
std
::
string
>
attr_names
{
"starts"
,
"ends"
,
"strides"
};
for
(
int
i
=
0
;
i
<
3
;
i
++
)
{
// Starts and Ends are default keys in inputs, but Strides.
bool
is_tensor
=
(
inputs
.
find
(
tensor_names
[
i
]
+
"TensorList"
)
!=
inputs
.
end
()
&&
!
inputs
.
at
(
tensor_names
[
i
]
+
"TensorList"
).
empty
())
||
(
inputs
.
find
(
tensor_names
[
i
]
+
"Tensor"
)
!=
inputs
.
end
()
&&
!
inputs
.
at
(
tensor_names
[
i
]
+
"Tensor"
).
empty
());
if
(
is_tensor
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Do not support starts, ends and strides as tensors."
));
}
else
{
}
else
{
auto
starts_
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"starts"
));
if
(
i
==
2
&&
!
op
->
HasAttr
(
"strides"
))
{
auto
dim
=
int64_t
(
starts_
.
size
());
slice_values
[
i
]
=
std
::
vector
<
int
>
(
axes_value
.
size
(),
1
);
starts
=
CreateConst
(
graph
,
node
,
std
::
vector
<
int
>
{
starts_
},
{
dim
},
ONNXDataType
::
INT32
);
starts
=
starts
->
outputs
[
0
];
}
Node
*
ends
=
nullptr
;
if
(
!
op
->
HasAttr
(
"ends"
))
{
ends
=
GetInputVarNode
(
"EndsTensor"
,
node
);
}
else
{
}
else
{
auto
ends_
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"ends"
));
slice_values
[
i
]
=
auto
dim
=
int64_t
(
ends_
.
size
());
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
attr_names
[
i
]));
ends
=
CreateConst
(
}
graph
,
node
,
std
::
vector
<
int
>
{
ends_
},
{
dim
},
ONNXDataType
::
INT32
);
ends
=
ends
->
outputs
[
0
];
}
}
Node
*
axes
=
nullptr
;
{
auto
axes_
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"axes"
));
auto
dim
=
int64_t
(
axes_
.
size
());
axes
=
CreateConst
(
graph
,
node
,
std
::
vector
<
int
>
{
axes_
},
{
dim
},
ONNXDataType
::
INT32
);
}
}
auto
decrease_axis_
=
auto
decrease_axis_
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"decrease_axis"
));
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"decrease_axis"
));
auto
input_shape_
=
GetInputVarNode
(
"Input"
,
node
)
->
Var
()
->
GetShape
();
auto
output_shape_
=
GetOutputVarNode
(
"Out"
,
node
)
->
Var
()
->
GetShape
();
if
(
decrease_axis_
.
size
()
==
0
)
{
if
(
decrease_axis_
.
size
()
==
0
)
{
return
CreateBaseOp
(
return
CreateSlice
(
graph
,
graph
,
node
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"Input"
,
node
)},
{
GetInputVarNode
(
"Input"
,
node
),
starts
,
ends
,
axes
->
outputs
[
0
]},
{
GetOutputVarNode
(
"Out"
,
node
)},
node
->
outputs
);
slice_values
[
0
],
}
else
if
(
output_shape_
==
std
::
vector
<
int64_t
>
{
0
}
||
slice_values
[
1
],
input_shape_
.
size
()
>
output_shape_
.
size
())
{
axes_value
,
auto
slice
=
CreateBaseOp
(
slice_values
[
2
]);
graph
,
}
else
{
auto
*
slice
=
CreateSlice
(
graph
,
node
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"Input"
,
node
)},
{
GetInputVarNode
(
"Input"
,
node
),
starts
,
ends
,
axes
->
outputs
[
0
]},
{},
{},
{});
slice_values
[
0
],
slice_values
[
1
],
axes_value
,
slice_values
[
2
])
->
outputs
[
0
];
return
CreateBaseOp
(
return
CreateBaseOp
(
graph
,
graph
,
node
,
node
,
"popart_squeeze"
,
"popart_squeeze"
,
{
slice
->
outputs
[
0
]
},
{
slice
},
{
GetOutputVarNode
(
"Out"
,
node
)},
{
GetOutputVarNode
(
"Out"
,
node
)},
{{
"axes"
,
{{
"axes"
,
std
::
vector
<
int64_t
>
{
decrease_axis_
.
begin
(),
decrease_axis_
.
end
()}}});
std
::
vector
<
int64_t
>
{
decrease_axis_
.
begin
(),
decrease_axis_
.
end
()}}});
}
else
{
return
CreateBaseOp
(
graph
,
node
,
"popart_slice"
,
{
GetInputVarNode
(
"Input"
,
node
),
starts
,
ends
,
axes
->
outputs
[
0
]},
node
->
outputs
);
}
}
}
}
Node
*
strided_slice_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
slice_handler
(
graph
,
node
);
}
Node
*
expand_handler
(
Graph
*
graph
,
Node
*
node
)
{
Node
*
expand_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
*
op
=
node
->
Op
();
if
(
!
op
->
Input
(
"expand_times_tensor"
).
empty
())
{
if
(
!
op
->
Input
(
"expand_times_tensor"
).
empty
())
{
...
@@ -552,6 +512,10 @@ Node *assign_handler(Graph *graph, Node *node) {
...
@@ -552,6 +512,10 @@ Node *assign_handler(Graph *graph, Node *node) {
{});
{});
}
}
Node
*
share_data_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
assign_handler
(
graph
,
node
);
}
Node
*
assign_value_handler
(
Graph
*
graph
,
Node
*
node
)
{
Node
*
assign_value_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
*
op
=
node
->
Op
();
auto
dtype_
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"dtype"
));
auto
dtype_
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"dtype"
));
...
@@ -731,15 +695,12 @@ Node *split_handler(Graph *graph, Node *node) {
...
@@ -731,15 +695,12 @@ Node *split_handler(Graph *graph, Node *node) {
auto
*
op
=
node
->
Op
();
auto
*
op
=
node
->
Op
();
auto
axis
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
));
auto
axis
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
));
auto
sections
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"sections"
));
auto
sections
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"sections"
));
return
CreateBaseOp
(
return
CreateSplit
(
graph
,
graph
,
node
,
node
,
"popart_split"
,
{
GetInputVarNode
(
"X"
,
node
)},
{
GetInputVarNode
(
"X"
,
node
)},
node
->
outputs
,
node
->
outputs
,
{{
"num_outputs"
,
int64_t
(
sections
.
size
())},
std
::
vector
<
int64_t
>
{
sections
.
begin
(),
sections
.
end
()},
{
"axis"
,
int64_t
(
axis
)},
axis
);
{
"split"
,
std
::
vector
<
int64_t
>
{
sections
.
begin
(),
sections
.
end
()}}});
}
}
Node
*
dot_handler
(
Graph
*
graph
,
Node
*
node
)
{
Node
*
dot_handler
(
Graph
*
graph
,
Node
*
node
)
{
...
@@ -1116,19 +1077,8 @@ Node *flip_handler(Graph *graph, Node *node) {
...
@@ -1116,19 +1077,8 @@ Node *flip_handler(Graph *graph, Node *node) {
auto
axis
=
axes
[
i
];
auto
axis
=
axes
[
i
];
std
::
vector
<
int64_t
>
split
;
std
::
vector
<
int64_t
>
split
;
split
.
resize
(
input_shape
[
axis
],
1
);
split
.
resize
(
input_shape
[
axis
],
1
);
std
::
vector
<
Node
*>
splits_output_nodes
;
auto
splits_outputs
=
for
(
int
j
=
0
;
j
<
split
.
size
();
j
++
)
{
CreateSplit
(
graph
,
node
,
{
temp_node
},
{},
split
,
axis
)
->
outputs
;
splits_output_nodes
.
push_back
(
MakeVarNode
(
graph
,
node
));
}
auto
splits_outputs
=
CreateBaseOp
(
graph
,
node
,
"popart_split"
,
{
temp_node
},
{
splits_output_nodes
},
{{
"num_outputs"
,
int64_t
(
split
.
size
())},
{
"axis"
,
int64_t
(
axis
)},
{
"split"
,
split
}})
->
outputs
;
std
::
reverse
(
splits_outputs
.
begin
(),
splits_outputs
.
end
());
std
::
reverse
(
splits_outputs
.
begin
(),
splits_outputs
.
end
());
if
(
i
!=
axes
.
size
()
-
1
)
{
if
(
i
!=
axes
.
size
()
-
1
)
{
temp_node
=
CreateBaseOp
(
graph
,
temp_node
=
CreateBaseOp
(
graph
,
...
@@ -1244,6 +1194,70 @@ Node *p_norm_handler(Graph *graph, Node *node) {
...
@@ -1244,6 +1194,70 @@ Node *p_norm_handler(Graph *graph, Node *node) {
{
GetOutputVarNode
(
"Out"
,
node
)});
{
GetOutputVarNode
(
"Out"
,
node
)});
}
}
Node
*
tile_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
inputs
=
op
->
Inputs
();
bool
is_repeat_tensors
=
(
inputs
.
find
(
"RepeatTimes"
)
!=
inputs
.
end
()
&&
!
inputs
.
at
(
"RepeatTimes"
).
empty
())
||
(
inputs
.
find
(
"repeat_times_tensor"
)
!=
inputs
.
end
()
&&
!
inputs
.
at
(
"repeat_times_tensor"
).
empty
());
if
(
is_repeat_tensors
)
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Do not support repeats as tensors."
));
}
auto
repeat_times
=
PADDLE_GET_CONST
(
std
::
vector
<
int
>
,
op
->
GetAttr
(
"repeat_times"
));
int
nums
=
repeat_times
.
size
();
std
::
vector
<
int
>
ones
(
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
().
size
()
-
nums
,
1
);
repeat_times
.
insert
(
repeat_times
.
begin
(),
ones
.
begin
(),
ones
.
end
());
auto
*
repeat_node
=
CreateConst
(
graph
,
node
,
std
::
vector
<
int64_t
>
{
repeat_times
.
begin
(),
repeat_times
.
end
()},
{
int64_t
(
repeat_times
.
size
())},
ONNXDataType
::
INT64
)
->
outputs
[
0
];
return
CreateBaseOp
(
graph
,
node
,
"popart_tile"
,
{
GetInputVarNode
(
"X"
,
node
),
repeat_node
},
{
GetOutputVarNode
(
"Out"
,
node
)});
}
Node
*
unstack_handler
(
Graph
*
graph
,
Node
*
node
)
{
auto
*
op
=
node
->
Op
();
auto
axis
=
PADDLE_GET_CONST
(
int
,
op
->
GetAttr
(
"axis"
));
if
(
axis
<
0
)
{
axis
+=
GetInputVarNode
(
"X"
,
node
)
->
Var
()
->
GetShape
().
size
();
}
std
::
vector
<
int64_t
>
split
(
node
->
outputs
.
size
(),
1
);
auto
split_output_nodes
=
CreateSplit
(
graph
,
node
,
{
GetInputVarNode
(
"X"
,
node
)},
{},
split
,
axis
)
->
outputs
;
Node
*
output
=
nullptr
;
for
(
int
i
=
0
;
i
<
split
.
size
();
i
++
)
{
output
=
CreateBaseOp
(
graph
,
node
,
"popart_squeeze"
,
{
split_output_nodes
[
i
]},
{
node
->
outputs
[
i
]},
{{
"axes"
,
std
::
vector
<
int64_t
>
{
axis
}}});
}
return
output
;
}
Node
*
where_handler
(
Graph
*
graph
,
Node
*
node
)
{
return
CreateBaseOp
(
graph
,
node
,
"popart_where"
,
{
GetInputVarNode
(
"Condition"
,
node
),
GetInputVarNode
(
"X"
,
node
),
GetInputVarNode
(
"Y"
,
node
)},
{
GetOutputVarNode
(
"Out"
,
node
)});
}
}
// namespace
}
// namespace
}
// namespace ipu
}
// namespace ipu
}
// namespace platform
}
// namespace platform
...
@@ -1265,6 +1279,7 @@ REGISTER_HANDLER(concat, concat_handler);
...
@@ -1265,6 +1279,7 @@ REGISTER_HANDLER(concat, concat_handler);
REGISTER_HANDLER
(
stack
,
stack_handler
);
REGISTER_HANDLER
(
stack
,
stack_handler
);
REGISTER_HANDLER
(
shape
,
shape_handler
);
REGISTER_HANDLER
(
shape
,
shape_handler
);
REGISTER_HANDLER
(
slice
,
slice_handler
);
REGISTER_HANDLER
(
slice
,
slice_handler
);
REGISTER_HANDLER
(
strided_slice
,
strided_slice_handler
);
REGISTER_HANDLER
(
expand
,
expand_handler
);
REGISTER_HANDLER
(
expand
,
expand_handler
);
REGISTER_HANDLER
(
expand_v2
,
expand_v2_handler
);
REGISTER_HANDLER
(
expand_v2
,
expand_v2_handler
);
REGISTER_HANDLER
(
expand_as_v2
,
expand_as_v2_handler
);
REGISTER_HANDLER
(
expand_as_v2
,
expand_as_v2_handler
);
...
@@ -1281,3 +1296,7 @@ REGISTER_HANDLER(dist, dist_handler);
...
@@ -1281,3 +1296,7 @@ REGISTER_HANDLER(dist, dist_handler);
REGISTER_HANDLER
(
flip
,
flip_handler
);
REGISTER_HANDLER
(
flip
,
flip_handler
);
REGISTER_HANDLER
(
meshgrid
,
meshgrid_handler
);
REGISTER_HANDLER
(
meshgrid
,
meshgrid_handler
);
REGISTER_HANDLER
(
p_norm
,
p_norm_handler
);
REGISTER_HANDLER
(
p_norm
,
p_norm_handler
);
REGISTER_HANDLER
(
share_data
,
share_data_handler
);
REGISTER_HANDLER
(
tile
,
tile_handler
);
REGISTER_HANDLER
(
unstack
,
unstack_handler
);
REGISTER_HANDLER
(
where
,
where_handler
);
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录