未验证 提交 7daae985 编写于 作者: Y yaozhixin 提交者: GitHub

[IPU] Add more Ops (#44414)

* [IPU] Add more Ops

* update boost API
上级 1047cb17
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
namespace {
Node *yolo_box_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto clip_bbox = PADDLE_GET_CONST(bool, op->GetAttr("clip_bbox"));
auto iou_aware = PADDLE_GET_CONST(bool, op->GetAttr("iou_aware"));
auto conf_thresh = PADDLE_GET_CONST(float, op->GetAttr("conf_thresh"));
auto iou_aware_factor =
PADDLE_GET_CONST(float, op->GetAttr("iou_aware_factor"));
auto class_num = PADDLE_GET_CONST(int, op->GetAttr("class_num"));
auto downsample_ratio =
PADDLE_GET_CONST(int, op->GetAttr("downsample_ratio"));
auto scale_x_y = PADDLE_GET_CONST(float, op->GetAttr("scale_x_y"));
auto anchors = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("anchors"));
// For Slice Op, while value is very large, it equals to the ends.
int max_int = INT_MAX;
int anchor_num = anchors.size() / 2;
// FP32 or FP16
auto target_dtype = GetInputVarNode("X", node)->Var()->GetDataType();
Node *input_x = GetInputVarNode("X", node);
if (iou_aware) {
input_x =
CreateSlice(graph,
node,
{input_x},
{},
std::vector<int>{0, 0, 0, 0},
std::vector<int>{max_int, anchor_num, max_int, max_int},
std::vector<int>{0, 1, 2, 3},
std::vector<int>{1, 1, 1, 1})
->outputs[0];
}
auto nchw = GetInputVarNode("X", node)->Var()->GetShape();
// Channel `C` = anchor_num * (5 + class_num)
auto *reshaped_x =
CreateReshape(
graph,
node,
{input_x},
{},
std::vector<int64_t>{nchw[0], anchor_num, -1, nchw[2], nchw[3]})
->outputs[0];
auto *transposed_x =
CreateBaseOp(graph,
node,
"popart_transpose",
{reshaped_x},
{},
{{"perm", std::vector<int64_t>{0, 1, 3, 4, 2}}})
->outputs[0];
// Build the grid
// grid_x_0 shape is [w]
std::vector<float> grid_x_0(nchw[3]);
std::iota(grid_x_0.begin(), grid_x_0.end(), 0.0f);
// grid_y_0 shape is [h]
std::vector<float> grid_y_0(nchw[2]);
std::iota(grid_y_0.begin(), grid_y_0.end(), 0.0f);
// grid_x_1 shape is [w * h]
std::vector<float> grid_x_1;
for (int i = 0; i < nchw[2]; i++) {
grid_x_1.insert(grid_x_1.end(), grid_x_0.begin(), grid_x_0.end());
}
auto *grid_x_1_node = CreateConst(graph,
node,
grid_x_1,
{int64_t(grid_x_1.size())},
VarType2OnnxDType(target_dtype))
->outputs[0];
// grid_y_1 shape is [h * w]
std::vector<float> grid_y_1;
for (int i = 0; i < nchw[3]; i++) {
grid_y_1.insert(grid_y_1.end(), grid_y_0.begin(), grid_y_0.end());
}
auto *grid_y_1_node = CreateConst(graph,
node,
grid_y_1,
{int64_t(grid_y_1.size())},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *grid_x_node = CreateReshape(graph,
node,
{grid_x_1_node},
{},
std::vector<int64_t>{nchw[2], nchw[3], 1})
->outputs[0];
auto *grid_y_2_node = CreateReshape(graph,
node,
{grid_y_1_node},
{},
std::vector<int64_t>{nchw[3], nchw[2], 1})
->outputs[0];
auto *grid_y_node = CreateBaseOp(graph,
node,
"popart_transpose",
{grid_y_2_node},
{},
{{"perm", std::vector<int64_t>{1, 0, 2}}})
->outputs[0];
auto *grid_node = CreateBaseOp(graph,
node,
"popart_concat",
{grid_x_node, grid_y_node},
{},
{{"axis", int64_t(2)}})
->outputs[0];
// Generate the positions(x, y) of boxes
// pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0]) *
// scale_x_y + bias_x_y) / w pred_box[:, :, :, :, 1] = (grid_y +
// sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h
auto *pred_box_xy =
CreateSlice(graph,
node,
{transposed_x},
{},
std::vector<int>{0, 0, 0, 0, 0},
std::vector<int>{max_int, max_int, max_int, max_int, 2},
std::vector<int>{0, 1, 2, 3, 4},
std::vector<int>{1, 1, 1, 1, 1})
->outputs[0];
auto *scale_x_y_node = CreateConst(graph,
node,
std::vector<float>{scale_x_y},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *bias_x_y_node =
CreateConst(graph,
node,
std::vector<float>{(1.0f - scale_x_y) / 2.0f},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *wh = CreateConst(graph,
node,
std::vector<float>{static_cast<float>(nchw[3]),
static_cast<float>(nchw[2])},
{int64_t(2)},
VarType2OnnxDType(target_dtype))
->outputs[0];
pred_box_xy = CreateBaseOp(graph, node, "popart_sigmoid", {pred_box_xy}, {})
->outputs[0];
pred_box_xy =
CreateBaseOp(graph, node, "popart_mul", {pred_box_xy, scale_x_y_node}, {})
->outputs[0];
pred_box_xy =
CreateBaseOp(graph, node, "popart_add", {pred_box_xy, bias_x_y_node}, {})
->outputs[0];
pred_box_xy =
CreateBaseOp(graph, node, "popart_add", {pred_box_xy, grid_node}, {})
->outputs[0];
pred_box_xy = CreateBaseOp(graph, node, "popart_div", {pred_box_xy, wh}, {})
->outputs[0];
// Generate Width and Height of boxes
// anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)]
// anchors_s = np.array(
// [(an_w / input_w, an_h / input_h) for an_w, an_h in anchors])
// anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1))
// anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1))
auto *anchors_node =
CreateConst(
graph,
node,
std::vector<float>{anchors.begin(), anchors.begin() + anchor_num * 2},
{int64_t(anchor_num * 2)},
VarType2OnnxDType(target_dtype))
->outputs[0];
anchors_node =
CreateReshape(
graph, node, {anchors_node}, {}, std::vector<int64_t>{anchor_num, 2})
->outputs[0];
auto *downsample_node =
CreateConst(graph,
node,
std::vector<float>{static_cast<float>(downsample_ratio)},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *ori_wh =
CreateBaseOp(graph, node, "popart_mul", {wh, downsample_node}, {})
->outputs[0];
anchors_node =
CreateBaseOp(graph, node, "popart_div", {anchors_node, ori_wh}, {})
->outputs[0];
anchors_node = CreateReshape(graph,
node,
{anchors_node},
{},
std::vector<int64_t>{1, anchor_num, 1, 1, 2})
->outputs[0];
auto *pred_box_wh =
CreateSlice(graph,
node,
{transposed_x},
{},
std::vector<int>{0, 0, 0, 0, 2},
std::vector<int>{max_int, max_int, max_int, max_int, 4},
std::vector<int>{0, 1, 2, 3, 4},
std::vector<int>{1, 1, 1, 1, 1})
->outputs[0];
pred_box_wh =
CreateBaseOp(graph, node, "popart_exp", {pred_box_wh}, {})->outputs[0];
pred_box_wh =
CreateBaseOp(graph, node, "popart_mul", {pred_box_wh, anchors_node}, {})
->outputs[0];
// Ignore the boxes whose confidience lower than the threshold
// if iou_aware:
// pred_conf = sigmoid(x[:, :, :, :, 4:5])**(
// 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor
// else:
// pred_conf = sigmoid(x[:, :, :, :, 4:5])
auto *confidence =
CreateSlice(graph,
node,
{transposed_x},
{},
std::vector<int>{0, 0, 0, 0, 4},
std::vector<int>{max_int, max_int, max_int, max_int, 5},
std::vector<int>{0, 1, 2, 3, 4},
std::vector<int>{1, 1, 1, 1, 1})
->outputs[0];
auto *pred_conf =
CreateBaseOp(graph, node, "popart_sigmoid", {confidence}, {})->outputs[0];
if (iou_aware) {
auto *ioup =
CreateSlice(graph,
node,
{GetInputVarNode("X", node)},
{},
std::vector<int>{0, 0, 0, 0},
std::vector<int>{max_int, anchor_num, max_int, max_int},
std::vector<int>{0, 1, 2, 3},
std::vector<int>{1, 1, 1, 1})
->outputs[0];
ioup = CreateBaseOp(graph,
node,
"popart_unsqueeze",
{ioup},
{},
{{"axes", std::vector<int64_t>{4}}})
->outputs[0];
ioup = CreateBaseOp(graph, node, "popart_sigmoid", {ioup}, {})->outputs[0];
auto *power_0 = CreateConst(graph,
node,
std::vector<float>{1.0f - iou_aware_factor},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *power_1 = CreateConst(graph,
node,
std::vector<float>{iou_aware_factor},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
ioup = CreateBaseOp(graph, node, "popart_pow", {ioup, power_1}, {})
->outputs[0];
pred_conf =
CreateBaseOp(graph, node, "popart_pow", {pred_conf, power_0}, {})
->outputs[0];
pred_conf = CreateBaseOp(graph, node, "popart_mul", {pred_conf, ioup}, {})
->outputs[0];
}
// pred_conf[pred_conf < conf_thresh] = 0.
// pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf
// pred_box = pred_box * (pred_conf > 0.).astype('float32')
auto *value_2 = CreateConst(graph,
node,
std::vector<float>{2.0f},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *center =
CreateBaseOp(graph, node, "popart_div", {pred_box_wh, value_2}, {})
->outputs[0];
auto *min_xy =
CreateBaseOp(graph, node, "popart_sub", {pred_box_xy, center}, {})
->outputs[0];
auto *max_xy =
CreateBaseOp(graph, node, "popart_add", {pred_box_xy, center}, {})
->outputs[0];
auto *conf_thresh_node = CreateConst(graph,
node,
std::vector<float>{conf_thresh},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
auto *filter =
CreateBaseOp(
graph, node, "popart_greater", {pred_conf, conf_thresh_node}, {})
->outputs[0];
filter = CreateCast(graph, node, {filter}, {}, target_dtype)->outputs[0];
pred_conf = CreateBaseOp(graph, node, "popart_mul", {pred_conf, filter}, {})
->outputs[0];
auto *pred_score =
CreateSlice(graph,
node,
{transposed_x},
{},
std::vector<int>{0, 0, 0, 0, 5},
std::vector<int>{max_int, max_int, max_int, max_int, max_int},
std::vector<int>{0, 1, 2, 3, 4},
std::vector<int>{1, 1, 1, 1, 1})
->outputs[0];
pred_score =
CreateBaseOp(graph, node, "popart_sigmoid", {pred_score}, {})->outputs[0];
pred_score =
CreateBaseOp(graph, node, "popart_mul", {pred_score, pred_conf}, {})
->outputs[0];
auto *pred_box = CreateBaseOp(graph,
node,
"popart_concat",
{min_xy, max_xy},
{},
{{"axis", int64_t(4)}})
->outputs[0];
pred_box = CreateBaseOp(graph, node, "popart_mul", {pred_box, filter}, {})
->outputs[0];
pred_box =
CreateReshape(
graph, node, {pred_box}, {}, std::vector<int64_t>{nchw[0], -1, 4})
->outputs[0];
// Clip the boxes to img_size
auto *float_img_size =
CreateCast(
graph, node, {GetInputVarNode("ImgSize", node)}, {}, target_dtype)
->outputs[0];
float_img_size = CreateBaseOp(graph,
node,
"popart_unsqueeze",
{float_img_size},
{},
{{"axes", std::vector<int64_t>(1)}})
->outputs[0];
auto split_im_hw =
CreateSplit(
graph, node, {float_img_size}, {}, std::vector<int64_t>{1, 1}, 2)
->outputs;
auto *im_whwh =
CreateBaseOp(
graph,
node,
"popart_concat",
{split_im_hw[1], split_im_hw[0], split_im_hw[1], split_im_hw[0]},
{},
{{"axis", int64_t(2)}})
->outputs[0];
if (!clip_bbox) {
auto *out = CreateBaseOp(graph, node, "popart_mul", {pred_box, im_whwh}, {})
->outputs[0];
CreateCast(graph,
node,
{out},
{GetOutputVarNode("Boxes", node)},
GetOutputVarNode("Boxes", node)->Var()->GetDataType());
} else {
pred_box = CreateBaseOp(graph, node, "popart_mul", {pred_box, im_whwh}, {})
->outputs[0];
auto *im_wh = CreateBaseOp(graph,
node,
"popart_concat",
{split_im_hw[1], split_im_hw[0]},
{},
{{"axis", int64_t(2)}})
->outputs[0];
auto *float_value_1 = CreateConst(graph,
node,
std::vector<float>{1.0f},
{int64_t(1)},
VarType2OnnxDType(target_dtype))
->outputs[0];
im_wh = CreateBaseOp(graph, node, "popart_sub", {im_wh, float_value_1}, {})
->outputs[0];
auto pred_box_xymin_xymax =
CreateSplit(graph, node, {pred_box}, {}, std::vector<int64_t>{2, 2}, 2)
->outputs;
pred_box_xymin_xymax[0] =
CreateBaseOp(graph, node, "popart_relu", {pred_box_xymin_xymax[0]}, {})
->outputs[0];
pred_box_xymin_xymax[1] =
CreateBaseOp(
graph, node, "popart_min", {pred_box_xymin_xymax[1], im_wh}, {})
->outputs[0];
auto *out = CreateBaseOp(graph,
node,
"popart_concat",
pred_box_xymin_xymax,
{},
{{"axis", int64_t(2)}})
->outputs[0];
CreateCast(graph,
node,
{out},
{GetOutputVarNode("Boxes", node)},
GetOutputVarNode("Boxes", node)->Var()->GetDataType());
}
auto *score_out = CreateReshape(graph,
node,
{pred_score},
{},
std::vector<int64_t>{nchw[0], -1, class_num})
->outputs[0];
return CreateCast(graph,
node,
{score_out},
{GetOutputVarNode("Scores", node)},
GetOutputVarNode("Scores", node)->Var()->GetDataType());
}
} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
REGISTER_HANDLER(yolo_box, yolo_box_handler);
...@@ -656,30 +656,14 @@ Node *interp_handler(Graph *graph, Node *node, const std::string &mode) { ...@@ -656,30 +656,14 @@ Node *interp_handler(Graph *graph, Node *node, const std::string &mode) {
CreateBaseOp( CreateBaseOp(
graph, node, "popart_shape", {GetInputVarNode("X", node)}, {}) graph, node, "popart_shape", {GetInputVarNode("X", node)}, {})
->outputs[0]; ->outputs[0];
Node *start = CreateConst(graph, Node *nc = CreateSlice(graph,
node, node,
{input_shape},
{},
std::vector<int>{0}, std::vector<int>{0},
std::vector<int64_t>{1},
ONNXDataType::INT32)
->outputs[0];
Node *end = CreateConst(graph,
node,
std::vector<int>{2}, std::vector<int>{2},
std::vector<int64_t>{1},
ONNXDataType::INT32)
->outputs[0];
Node *axes = CreateConst(graph,
node,
std::vector<int>{0}, std::vector<int>{0},
std::vector<int64_t>{1}, std::vector<int>{1})
ONNXDataType::INT32)
->outputs[0];
Node *nc = CreateBaseOp(graph,
node,
"popart_slice",
{input_shape, start, end, axes},
{},
{})
->outputs[0]; ->outputs[0];
size = CreateBaseOp(graph, size = CreateBaseOp(graph,
node, node,
......
...@@ -256,6 +256,69 @@ Node *CreateSoftmaxOpset11(Graph *graph, ...@@ -256,6 +256,69 @@ Node *CreateSoftmaxOpset11(Graph *graph,
} }
} }
Node *CreateSlice(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
const std::vector<int> &starts,
const std::vector<int> &ends,
const std::vector<int> &axes,
const std::vector<int> &strides) {
auto *starts_node =
CreateConst(
graph, node, starts, {int64_t(starts.size())}, ONNXDataType::INT32)
->outputs[0];
auto *ends_node =
CreateConst(
graph, node, ends, {int64_t(ends.size())}, ONNXDataType::INT32)
->outputs[0];
auto *axes_node =
CreateConst(
graph, node, axes, {int64_t(axes.size())}, ONNXDataType::INT32)
->outputs[0];
auto *strides_node =
CreateConst(
graph, node, strides, {int64_t(strides.size())}, ONNXDataType::INT32)
->outputs[0];
return CreateBaseOp(
graph,
node,
"popart_slice",
{inputs[0], starts_node, ends_node, axes_node, strides_node},
outputs);
}
Node *CreateSplit(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
const std::vector<int64_t> &split,
const int64_t axis) {
if (!outputs.empty()) {
return CreateBaseOp(graph,
node,
"popart_split",
inputs,
outputs,
{{"num_outputs", int64_t(split.size())},
{"axis", int64_t(axis)},
{"split", split}});
} else {
std::vector<Node *> splits_output_nodes;
for (int j = 0; j < split.size(); j++) {
splits_output_nodes.push_back(MakeVarNode(graph, node));
}
return CreateBaseOp(graph,
node,
"popart_split",
inputs,
{splits_output_nodes},
{{"num_outputs", int64_t(split.size())},
{"axis", int64_t(axis)},
{"split", split}});
}
}
} // namespace ipu } // namespace ipu
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -104,6 +104,22 @@ Node *CreateSoftmaxOpset11(Graph *graph, ...@@ -104,6 +104,22 @@ Node *CreateSoftmaxOpset11(Graph *graph,
const std::vector<Node *> &outputs, const std::vector<Node *> &outputs,
int64_t axis); int64_t axis);
Node *CreateSlice(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
const std::vector<int> &starts,
const std::vector<int> &ends,
const std::vector<int> &axes,
const std::vector<int> &strides);
Node *CreateSplit(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
const std::vector<int64_t> &split,
const int64_t axis);
} // namespace ipu } // namespace ipu
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -249,94 +249,57 @@ Node *lookup_table_op_handler(Graph *graph, ...@@ -249,94 +249,57 @@ Node *lookup_table_op_handler(Graph *graph,
{{"value", const_value_}, {{"value", const_value_},
{"dims", const_shape_}, {"dims", const_shape_},
{"dtype", GetOutputVarDType(node)}}); {"dtype", GetOutputVarDType(node)}});
auto axes = CreateConst(graph, if (padding_idx_ == 0) {
node, auto right_slice =
{}, CreateSlice(graph,
{},
{{"value", std::vector<int64_t>{0}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto step = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<int64_t>{1}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto left_start = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<int64_t>{0}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto left_end = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<int64_t>{padding_idx_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto right_start =
CreateConst(graph,
node,
{},
{},
{{"value", std::vector<int64_t>{padding_idx_ + 1}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto right_end = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<int64_t>{table_size_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto left_slice = CreateBaseOp(graph,
node, node,
"popart_slice", {GetInputVarNode("W", node)},
{GetInputVarNode("W", node),
left_start->outputs[0],
left_end->outputs[0],
axes->outputs[0],
step->outputs[0]},
{},
{});
auto right_slice = CreateBaseOp(graph,
node,
"popart_slice",
{GetInputVarNode("W", node),
right_start->outputs[0],
right_end->outputs[0],
axes->outputs[0],
step->outputs[0]},
{}, {},
{}); std::vector<int>{static_cast<int>(padding_idx_) + 1},
std::vector<int>{static_cast<int>(table_size_)},
if (padding_idx_ == 0) { std::vector<int>{0},
std::vector<int>{1});
w_node = CreateBaseOp(graph, w_node = CreateBaseOp(graph,
node, node,
"popart_concat", "popart_concat",
{concat_const->outputs[0], right_slice->outputs[0]}, {concat_const->outputs[0], right_slice->outputs[0]},
{}, {},
{{"axis", int64_t(0)}}); {{"axis", int64_t(0)}});
ClearNode(left_start);
ClearNode(left_end);
ClearNode(left_slice);
} else if (padding_idx_ == table_size_ - 1) { } else if (padding_idx_ == table_size_ - 1) {
auto left_slice =
CreateSlice(graph,
node,
{GetInputVarNode("W", node)},
{},
std::vector<int>{0},
std::vector<int>{static_cast<int>(padding_idx_)},
std::vector<int>{0},
std::vector<int>{1});
w_node = CreateBaseOp(graph, w_node = CreateBaseOp(graph,
node, node,
"popart_concat", "popart_concat",
{left_slice->outputs[0], concat_const->outputs[0]}, {left_slice->outputs[0], concat_const->outputs[0]},
{}, {},
{{"axis", int64_t{0}}}); {{"axis", int64_t{0}}});
ClearNode(right_start);
ClearNode(right_end);
ClearNode(right_slice);
} else { } else {
auto left_slice =
CreateSlice(graph,
node,
{GetInputVarNode("W", node)},
{},
std::vector<int>{0},
std::vector<int>{static_cast<int>(padding_idx_)},
std::vector<int>{0},
std::vector<int>{1});
auto right_slice =
CreateSlice(graph,
node,
{GetInputVarNode("W", node)},
{},
std::vector<int>{static_cast<int>(padding_idx_) + 1},
std::vector<int>{static_cast<int>(table_size_)},
std::vector<int>{0},
std::vector<int>{1});
w_node = CreateBaseOp(graph, w_node = CreateBaseOp(graph,
node, node,
"popart_concat", "popart_concat",
...@@ -441,72 +404,69 @@ Node *shape_handler(Graph *graph, Node *node) { ...@@ -441,72 +404,69 @@ Node *shape_handler(Graph *graph, Node *node) {
Node *slice_handler(Graph *graph, Node *node) { Node *slice_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
Node *starts = nullptr; auto inputs = op->Inputs();
if (!op->HasAttr("starts")) {
starts = GetInputVarNode("StartsTensor", node); auto axes_value = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("axes"));
std::vector<std::vector<int>> slice_values(3);
std::vector<std::string> tensor_names{"Starts", "Ends", "Strides"};
std::vector<std::string> attr_names{"starts", "ends", "strides"};
for (int i = 0; i < 3; i++) {
// Starts and Ends are default keys in inputs, but Strides.
bool is_tensor =
(inputs.find(tensor_names[i] + "TensorList") != inputs.end() &&
!inputs.at(tensor_names[i] + "TensorList").empty()) ||
(inputs.find(tensor_names[i] + "Tensor") != inputs.end() &&
!inputs.at(tensor_names[i] + "Tensor").empty());
if (is_tensor) {
PADDLE_THROW(platform::errors::Unimplemented(
"Do not support starts, ends and strides as tensors."));
} else { } else {
auto starts_ = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("starts")); if (i == 2 && !op->HasAttr("strides")) {
auto dim = int64_t(starts_.size()); slice_values[i] = std::vector<int>(axes_value.size(), 1);
starts = CreateConst(
graph, node, std::vector<int>{starts_}, {dim}, ONNXDataType::INT32);
starts = starts->outputs[0];
}
Node *ends = nullptr;
if (!op->HasAttr("ends")) {
ends = GetInputVarNode("EndsTensor", node);
} else { } else {
auto ends_ = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("ends")); slice_values[i] =
auto dim = int64_t(ends_.size()); PADDLE_GET_CONST(std::vector<int>, op->GetAttr(attr_names[i]));
ends = CreateConst( }
graph, node, std::vector<int>{ends_}, {dim}, ONNXDataType::INT32);
ends = ends->outputs[0];
} }
Node *axes = nullptr;
{
auto axes_ = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("axes"));
auto dim = int64_t(axes_.size());
axes = CreateConst(
graph, node, std::vector<int>{axes_}, {dim}, ONNXDataType::INT32);
} }
auto decrease_axis_ = auto decrease_axis_ =
PADDLE_GET_CONST(std::vector<int>, op->GetAttr("decrease_axis")); PADDLE_GET_CONST(std::vector<int>, op->GetAttr("decrease_axis"));
auto input_shape_ = GetInputVarNode("Input", node)->Var()->GetShape();
auto output_shape_ = GetOutputVarNode("Out", node)->Var()->GetShape();
if (decrease_axis_.size() == 0) { if (decrease_axis_.size() == 0) {
return CreateBaseOp( return CreateSlice(graph,
graph,
node, node,
"popart_slice", {GetInputVarNode("Input", node)},
{GetInputVarNode("Input", node), starts, ends, axes->outputs[0]}, {GetOutputVarNode("Out", node)},
node->outputs); slice_values[0],
} else if (output_shape_ == std::vector<int64_t>{0} || slice_values[1],
input_shape_.size() > output_shape_.size()) { axes_value,
auto slice = CreateBaseOp( slice_values[2]);
graph, } else {
auto *slice = CreateSlice(graph,
node, node,
"popart_slice", {GetInputVarNode("Input", node)},
{GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
{}, {},
{}); slice_values[0],
slice_values[1],
axes_value,
slice_values[2])
->outputs[0];
return CreateBaseOp( return CreateBaseOp(
graph, graph,
node, node,
"popart_squeeze", "popart_squeeze",
{slice->outputs[0]}, {slice},
{GetOutputVarNode("Out", node)}, {GetOutputVarNode("Out", node)},
{{"axes", {{"axes",
std::vector<int64_t>{decrease_axis_.begin(), decrease_axis_.end()}}}); std::vector<int64_t>{decrease_axis_.begin(), decrease_axis_.end()}}});
} else {
return CreateBaseOp(
graph,
node,
"popart_slice",
{GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
node->outputs);
} }
} }
Node *strided_slice_handler(Graph *graph, Node *node) {
return slice_handler(graph, node);
}
Node *expand_handler(Graph *graph, Node *node) { Node *expand_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
if (!op->Input("expand_times_tensor").empty()) { if (!op->Input("expand_times_tensor").empty()) {
...@@ -552,6 +512,10 @@ Node *assign_handler(Graph *graph, Node *node) { ...@@ -552,6 +512,10 @@ Node *assign_handler(Graph *graph, Node *node) {
{}); {});
} }
Node *share_data_handler(Graph *graph, Node *node) {
return assign_handler(graph, node);
}
Node *assign_value_handler(Graph *graph, Node *node) { Node *assign_value_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto dtype_ = PADDLE_GET_CONST(int, op->GetAttr("dtype")); auto dtype_ = PADDLE_GET_CONST(int, op->GetAttr("dtype"));
...@@ -731,15 +695,12 @@ Node *split_handler(Graph *graph, Node *node) { ...@@ -731,15 +695,12 @@ Node *split_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto axis = PADDLE_GET_CONST(int, op->GetAttr("axis")); auto axis = PADDLE_GET_CONST(int, op->GetAttr("axis"));
auto sections = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("sections")); auto sections = PADDLE_GET_CONST(std::vector<int>, op->GetAttr("sections"));
return CreateBaseOp( return CreateSplit(graph,
graph,
node, node,
"popart_split",
{GetInputVarNode("X", node)}, {GetInputVarNode("X", node)},
node->outputs, node->outputs,
{{"num_outputs", int64_t(sections.size())}, std::vector<int64_t>{sections.begin(), sections.end()},
{"axis", int64_t(axis)}, axis);
{"split", std::vector<int64_t>{sections.begin(), sections.end()}}});
} }
Node *dot_handler(Graph *graph, Node *node) { Node *dot_handler(Graph *graph, Node *node) {
...@@ -1116,19 +1077,8 @@ Node *flip_handler(Graph *graph, Node *node) { ...@@ -1116,19 +1077,8 @@ Node *flip_handler(Graph *graph, Node *node) {
auto axis = axes[i]; auto axis = axes[i];
std::vector<int64_t> split; std::vector<int64_t> split;
split.resize(input_shape[axis], 1); split.resize(input_shape[axis], 1);
std::vector<Node *> splits_output_nodes; auto splits_outputs =
for (int j = 0; j < split.size(); j++) { CreateSplit(graph, node, {temp_node}, {}, split, axis)->outputs;
splits_output_nodes.push_back(MakeVarNode(graph, node));
}
auto splits_outputs = CreateBaseOp(graph,
node,
"popart_split",
{temp_node},
{splits_output_nodes},
{{"num_outputs", int64_t(split.size())},
{"axis", int64_t(axis)},
{"split", split}})
->outputs;
std::reverse(splits_outputs.begin(), splits_outputs.end()); std::reverse(splits_outputs.begin(), splits_outputs.end());
if (i != axes.size() - 1) { if (i != axes.size() - 1) {
temp_node = CreateBaseOp(graph, temp_node = CreateBaseOp(graph,
...@@ -1244,6 +1194,70 @@ Node *p_norm_handler(Graph *graph, Node *node) { ...@@ -1244,6 +1194,70 @@ Node *p_norm_handler(Graph *graph, Node *node) {
{GetOutputVarNode("Out", node)}); {GetOutputVarNode("Out", node)});
} }
Node *tile_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto inputs = op->Inputs();
bool is_repeat_tensors =
(inputs.find("RepeatTimes") != inputs.end() &&
!inputs.at("RepeatTimes").empty()) ||
(inputs.find("repeat_times_tensor") != inputs.end() &&
!inputs.at("repeat_times_tensor").empty());
if (is_repeat_tensors) {
PADDLE_THROW(
platform::errors::Unimplemented("Do not support repeats as tensors."));
}
auto repeat_times =
PADDLE_GET_CONST(std::vector<int>, op->GetAttr("repeat_times"));
int nums = repeat_times.size();
std::vector<int> ones(
GetInputVarNode("X", node)->Var()->GetShape().size() - nums, 1);
repeat_times.insert(repeat_times.begin(), ones.begin(), ones.end());
auto *repeat_node = CreateConst(graph,
node,
std::vector<int64_t>{repeat_times.begin(),
repeat_times.end()},
{int64_t(repeat_times.size())},
ONNXDataType::INT64)
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_tile",
{GetInputVarNode("X", node), repeat_node},
{GetOutputVarNode("Out", node)});
}
Node *unstack_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = PADDLE_GET_CONST(int, op->GetAttr("axis"));
if (axis < 0) {
axis += GetInputVarNode("X", node)->Var()->GetShape().size();
}
std::vector<int64_t> split(node->outputs.size(), 1);
auto split_output_nodes =
CreateSplit(graph, node, {GetInputVarNode("X", node)}, {}, split, axis)
->outputs;
Node *output = nullptr;
for (int i = 0; i < split.size(); i++) {
output = CreateBaseOp(graph,
node,
"popart_squeeze",
{split_output_nodes[i]},
{node->outputs[i]},
{{"axes", std::vector<int64_t>{axis}}});
}
return output;
}
Node *where_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph,
node,
"popart_where",
{GetInputVarNode("Condition", node),
GetInputVarNode("X", node),
GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)});
}
} // namespace } // namespace
} // namespace ipu } // namespace ipu
} // namespace platform } // namespace platform
...@@ -1265,6 +1279,7 @@ REGISTER_HANDLER(concat, concat_handler); ...@@ -1265,6 +1279,7 @@ REGISTER_HANDLER(concat, concat_handler);
REGISTER_HANDLER(stack, stack_handler); REGISTER_HANDLER(stack, stack_handler);
REGISTER_HANDLER(shape, shape_handler); REGISTER_HANDLER(shape, shape_handler);
REGISTER_HANDLER(slice, slice_handler); REGISTER_HANDLER(slice, slice_handler);
REGISTER_HANDLER(strided_slice, strided_slice_handler);
REGISTER_HANDLER(expand, expand_handler); REGISTER_HANDLER(expand, expand_handler);
REGISTER_HANDLER(expand_v2, expand_v2_handler); REGISTER_HANDLER(expand_v2, expand_v2_handler);
REGISTER_HANDLER(expand_as_v2, expand_as_v2_handler); REGISTER_HANDLER(expand_as_v2, expand_as_v2_handler);
...@@ -1281,3 +1296,7 @@ REGISTER_HANDLER(dist, dist_handler); ...@@ -1281,3 +1296,7 @@ REGISTER_HANDLER(dist, dist_handler);
REGISTER_HANDLER(flip, flip_handler); REGISTER_HANDLER(flip, flip_handler);
REGISTER_HANDLER(meshgrid, meshgrid_handler); REGISTER_HANDLER(meshgrid, meshgrid_handler);
REGISTER_HANDLER(p_norm, p_norm_handler); REGISTER_HANDLER(p_norm, p_norm_handler);
REGISTER_HANDLER(share_data, share_data_handler);
REGISTER_HANDLER(tile, tile_handler);
REGISTER_HANDLER(unstack, unstack_handler);
REGISTER_HANDLER(where, where_handler);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册