diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/detection_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/detection_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..159e1dfc1493a5cb0a878e3e6c4cecfa49bb1ce7 --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/detection_ops.cc @@ -0,0 +1,444 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *yolo_box_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto clip_bbox = PADDLE_GET_CONST(bool, op->GetAttr("clip_bbox")); + auto iou_aware = PADDLE_GET_CONST(bool, op->GetAttr("iou_aware")); + auto conf_thresh = PADDLE_GET_CONST(float, op->GetAttr("conf_thresh")); + auto iou_aware_factor = + PADDLE_GET_CONST(float, op->GetAttr("iou_aware_factor")); + auto class_num = PADDLE_GET_CONST(int, op->GetAttr("class_num")); + auto downsample_ratio = + PADDLE_GET_CONST(int, op->GetAttr("downsample_ratio")); + auto scale_x_y = PADDLE_GET_CONST(float, op->GetAttr("scale_x_y")); + auto anchors = PADDLE_GET_CONST(std::vector, op->GetAttr("anchors")); + + // For Slice Op, while value is very large, it equals to the ends. + int max_int = INT_MAX; + int anchor_num = anchors.size() / 2; + + // FP32 or FP16 + auto target_dtype = GetInputVarNode("X", node)->Var()->GetDataType(); + + Node *input_x = GetInputVarNode("X", node); + if (iou_aware) { + input_x = + CreateSlice(graph, + node, + {input_x}, + {}, + std::vector{0, 0, 0, 0}, + std::vector{max_int, anchor_num, max_int, max_int}, + std::vector{0, 1, 2, 3}, + std::vector{1, 1, 1, 1}) + ->outputs[0]; + } + auto nchw = GetInputVarNode("X", node)->Var()->GetShape(); + // Channel `C` = anchor_num * (5 + class_num) + auto *reshaped_x = + CreateReshape( + graph, + node, + {input_x}, + {}, + std::vector{nchw[0], anchor_num, -1, nchw[2], nchw[3]}) + ->outputs[0]; + auto *transposed_x = + CreateBaseOp(graph, + node, + "popart_transpose", + {reshaped_x}, + {}, + {{"perm", std::vector{0, 1, 3, 4, 2}}}) + ->outputs[0]; + + // Build the grid + // grid_x_0 shape is [w] + std::vector grid_x_0(nchw[3]); + std::iota(grid_x_0.begin(), grid_x_0.end(), 0.0f); + // grid_y_0 shape is [h] + std::vector grid_y_0(nchw[2]); + std::iota(grid_y_0.begin(), grid_y_0.end(), 0.0f); + // grid_x_1 shape is [w * h] + std::vector grid_x_1; + for (int i = 0; i < nchw[2]; i++) { + grid_x_1.insert(grid_x_1.end(), grid_x_0.begin(), grid_x_0.end()); + } + auto *grid_x_1_node = CreateConst(graph, + node, + grid_x_1, + {int64_t(grid_x_1.size())}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + // grid_y_1 shape is [h * w] + std::vector grid_y_1; + for (int i = 0; i < nchw[3]; i++) { + grid_y_1.insert(grid_y_1.end(), grid_y_0.begin(), grid_y_0.end()); + } + auto *grid_y_1_node = CreateConst(graph, + node, + grid_y_1, + {int64_t(grid_y_1.size())}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *grid_x_node = CreateReshape(graph, + node, + {grid_x_1_node}, + {}, + std::vector{nchw[2], nchw[3], 1}) + ->outputs[0]; + auto *grid_y_2_node = CreateReshape(graph, + node, + {grid_y_1_node}, + {}, + std::vector{nchw[3], nchw[2], 1}) + ->outputs[0]; + auto *grid_y_node = CreateBaseOp(graph, + node, + "popart_transpose", + {grid_y_2_node}, + {}, + {{"perm", std::vector{1, 0, 2}}}) + ->outputs[0]; + auto *grid_node = CreateBaseOp(graph, + node, + "popart_concat", + {grid_x_node, grid_y_node}, + {}, + {{"axis", int64_t(2)}}) + ->outputs[0]; + + // Generate the positions(x, y) of boxes + // pred_box[:, :, :, :, 0] = (grid_x + sigmoid(pred_box[:, :, :, :, 0]) * + // scale_x_y + bias_x_y) / w pred_box[:, :, :, :, 1] = (grid_y + + // sigmoid(pred_box[:, :, :, :, 1]) * scale_x_y + bias_x_y) / h + auto *pred_box_xy = + CreateSlice(graph, + node, + {transposed_x}, + {}, + std::vector{0, 0, 0, 0, 0}, + std::vector{max_int, max_int, max_int, max_int, 2}, + std::vector{0, 1, 2, 3, 4}, + std::vector{1, 1, 1, 1, 1}) + ->outputs[0]; + auto *scale_x_y_node = CreateConst(graph, + node, + std::vector{scale_x_y}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *bias_x_y_node = + CreateConst(graph, + node, + std::vector{(1.0f - scale_x_y) / 2.0f}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *wh = CreateConst(graph, + node, + std::vector{static_cast(nchw[3]), + static_cast(nchw[2])}, + {int64_t(2)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + pred_box_xy = CreateBaseOp(graph, node, "popart_sigmoid", {pred_box_xy}, {}) + ->outputs[0]; + pred_box_xy = + CreateBaseOp(graph, node, "popart_mul", {pred_box_xy, scale_x_y_node}, {}) + ->outputs[0]; + pred_box_xy = + CreateBaseOp(graph, node, "popart_add", {pred_box_xy, bias_x_y_node}, {}) + ->outputs[0]; + pred_box_xy = + CreateBaseOp(graph, node, "popart_add", {pred_box_xy, grid_node}, {}) + ->outputs[0]; + pred_box_xy = CreateBaseOp(graph, node, "popart_div", {pred_box_xy, wh}, {}) + ->outputs[0]; + + // Generate Width and Height of boxes + // anchors = [(anchors[i], anchors[i + 1]) for i in range(0, len(anchors), 2)] + // anchors_s = np.array( + // [(an_w / input_w, an_h / input_h) for an_w, an_h in anchors]) + // anchor_w = anchors_s[:, 0:1].reshape((1, an_num, 1, 1)) + // anchor_h = anchors_s[:, 1:2].reshape((1, an_num, 1, 1)) + auto *anchors_node = + CreateConst( + graph, + node, + std::vector{anchors.begin(), anchors.begin() + anchor_num * 2}, + {int64_t(anchor_num * 2)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + anchors_node = + CreateReshape( + graph, node, {anchors_node}, {}, std::vector{anchor_num, 2}) + ->outputs[0]; + auto *downsample_node = + CreateConst(graph, + node, + std::vector{static_cast(downsample_ratio)}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *ori_wh = + CreateBaseOp(graph, node, "popart_mul", {wh, downsample_node}, {}) + ->outputs[0]; + anchors_node = + CreateBaseOp(graph, node, "popart_div", {anchors_node, ori_wh}, {}) + ->outputs[0]; + anchors_node = CreateReshape(graph, + node, + {anchors_node}, + {}, + std::vector{1, anchor_num, 1, 1, 2}) + ->outputs[0]; + auto *pred_box_wh = + CreateSlice(graph, + node, + {transposed_x}, + {}, + std::vector{0, 0, 0, 0, 2}, + std::vector{max_int, max_int, max_int, max_int, 4}, + std::vector{0, 1, 2, 3, 4}, + std::vector{1, 1, 1, 1, 1}) + ->outputs[0]; + pred_box_wh = + CreateBaseOp(graph, node, "popart_exp", {pred_box_wh}, {})->outputs[0]; + pred_box_wh = + CreateBaseOp(graph, node, "popart_mul", {pred_box_wh, anchors_node}, {}) + ->outputs[0]; + + // Ignore the boxes whose confidience lower than the threshold + // if iou_aware: + // pred_conf = sigmoid(x[:, :, :, :, 4:5])**( + // 1 - iou_aware_factor) * sigmoid(ioup)**iou_aware_factor + // else: + // pred_conf = sigmoid(x[:, :, :, :, 4:5]) + auto *confidence = + CreateSlice(graph, + node, + {transposed_x}, + {}, + std::vector{0, 0, 0, 0, 4}, + std::vector{max_int, max_int, max_int, max_int, 5}, + std::vector{0, 1, 2, 3, 4}, + std::vector{1, 1, 1, 1, 1}) + ->outputs[0]; + auto *pred_conf = + CreateBaseOp(graph, node, "popart_sigmoid", {confidence}, {})->outputs[0]; + if (iou_aware) { + auto *ioup = + CreateSlice(graph, + node, + {GetInputVarNode("X", node)}, + {}, + std::vector{0, 0, 0, 0}, + std::vector{max_int, anchor_num, max_int, max_int}, + std::vector{0, 1, 2, 3}, + std::vector{1, 1, 1, 1}) + ->outputs[0]; + ioup = CreateBaseOp(graph, + node, + "popart_unsqueeze", + {ioup}, + {}, + {{"axes", std::vector{4}}}) + ->outputs[0]; + ioup = CreateBaseOp(graph, node, "popart_sigmoid", {ioup}, {})->outputs[0]; + auto *power_0 = CreateConst(graph, + node, + std::vector{1.0f - iou_aware_factor}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *power_1 = CreateConst(graph, + node, + std::vector{iou_aware_factor}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + ioup = CreateBaseOp(graph, node, "popart_pow", {ioup, power_1}, {}) + ->outputs[0]; + pred_conf = + CreateBaseOp(graph, node, "popart_pow", {pred_conf, power_0}, {}) + ->outputs[0]; + pred_conf = CreateBaseOp(graph, node, "popart_mul", {pred_conf, ioup}, {}) + ->outputs[0]; + } + // pred_conf[pred_conf < conf_thresh] = 0. + // pred_score = sigmoid(x[:, :, :, :, 5:]) * pred_conf + // pred_box = pred_box * (pred_conf > 0.).astype('float32') + auto *value_2 = CreateConst(graph, + node, + std::vector{2.0f}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *center = + CreateBaseOp(graph, node, "popart_div", {pred_box_wh, value_2}, {}) + ->outputs[0]; + auto *min_xy = + CreateBaseOp(graph, node, "popart_sub", {pred_box_xy, center}, {}) + ->outputs[0]; + auto *max_xy = + CreateBaseOp(graph, node, "popart_add", {pred_box_xy, center}, {}) + ->outputs[0]; + + auto *conf_thresh_node = CreateConst(graph, + node, + std::vector{conf_thresh}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + auto *filter = + CreateBaseOp( + graph, node, "popart_greater", {pred_conf, conf_thresh_node}, {}) + ->outputs[0]; + filter = CreateCast(graph, node, {filter}, {}, target_dtype)->outputs[0]; + pred_conf = CreateBaseOp(graph, node, "popart_mul", {pred_conf, filter}, {}) + ->outputs[0]; + auto *pred_score = + CreateSlice(graph, + node, + {transposed_x}, + {}, + std::vector{0, 0, 0, 0, 5}, + std::vector{max_int, max_int, max_int, max_int, max_int}, + std::vector{0, 1, 2, 3, 4}, + std::vector{1, 1, 1, 1, 1}) + ->outputs[0]; + pred_score = + CreateBaseOp(graph, node, "popart_sigmoid", {pred_score}, {})->outputs[0]; + pred_score = + CreateBaseOp(graph, node, "popart_mul", {pred_score, pred_conf}, {}) + ->outputs[0]; + auto *pred_box = CreateBaseOp(graph, + node, + "popart_concat", + {min_xy, max_xy}, + {}, + {{"axis", int64_t(4)}}) + ->outputs[0]; + pred_box = CreateBaseOp(graph, node, "popart_mul", {pred_box, filter}, {}) + ->outputs[0]; + pred_box = + CreateReshape( + graph, node, {pred_box}, {}, std::vector{nchw[0], -1, 4}) + ->outputs[0]; + + // Clip the boxes to img_size + auto *float_img_size = + CreateCast( + graph, node, {GetInputVarNode("ImgSize", node)}, {}, target_dtype) + ->outputs[0]; + float_img_size = CreateBaseOp(graph, + node, + "popart_unsqueeze", + {float_img_size}, + {}, + {{"axes", std::vector(1)}}) + ->outputs[0]; + auto split_im_hw = + CreateSplit( + graph, node, {float_img_size}, {}, std::vector{1, 1}, 2) + ->outputs; + auto *im_whwh = + CreateBaseOp( + graph, + node, + "popart_concat", + {split_im_hw[1], split_im_hw[0], split_im_hw[1], split_im_hw[0]}, + {}, + {{"axis", int64_t(2)}}) + ->outputs[0]; + if (!clip_bbox) { + auto *out = CreateBaseOp(graph, node, "popart_mul", {pred_box, im_whwh}, {}) + ->outputs[0]; + CreateCast(graph, + node, + {out}, + {GetOutputVarNode("Boxes", node)}, + GetOutputVarNode("Boxes", node)->Var()->GetDataType()); + + } else { + pred_box = CreateBaseOp(graph, node, "popart_mul", {pred_box, im_whwh}, {}) + ->outputs[0]; + auto *im_wh = CreateBaseOp(graph, + node, + "popart_concat", + {split_im_hw[1], split_im_hw[0]}, + {}, + {{"axis", int64_t(2)}}) + ->outputs[0]; + auto *float_value_1 = CreateConst(graph, + node, + std::vector{1.0f}, + {int64_t(1)}, + VarType2OnnxDType(target_dtype)) + ->outputs[0]; + im_wh = CreateBaseOp(graph, node, "popart_sub", {im_wh, float_value_1}, {}) + ->outputs[0]; + auto pred_box_xymin_xymax = + CreateSplit(graph, node, {pred_box}, {}, std::vector{2, 2}, 2) + ->outputs; + pred_box_xymin_xymax[0] = + CreateBaseOp(graph, node, "popart_relu", {pred_box_xymin_xymax[0]}, {}) + ->outputs[0]; + pred_box_xymin_xymax[1] = + CreateBaseOp( + graph, node, "popart_min", {pred_box_xymin_xymax[1], im_wh}, {}) + ->outputs[0]; + auto *out = CreateBaseOp(graph, + node, + "popart_concat", + pred_box_xymin_xymax, + {}, + {{"axis", int64_t(2)}}) + ->outputs[0]; + CreateCast(graph, + node, + {out}, + {GetOutputVarNode("Boxes", node)}, + GetOutputVarNode("Boxes", node)->Var()->GetDataType()); + } + auto *score_out = CreateReshape(graph, + node, + {pred_score}, + {}, + std::vector{nchw[0], -1, class_num}) + ->outputs[0]; + return CreateCast(graph, + node, + {score_out}, + {GetOutputVarNode("Scores", node)}, + GetOutputVarNode("Scores", node)->Var()->GetDataType()); +} + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle + +REGISTER_HANDLER(yolo_box, yolo_box_handler); diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc index e28531f349d1436de8c90ee0270f21012474e29f..a016647efc99749b7b3597174654e6666e10b67b 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc @@ -656,30 +656,14 @@ Node *interp_handler(Graph *graph, Node *node, const std::string &mode) { CreateBaseOp( graph, node, "popart_shape", {GetInputVarNode("X", node)}, {}) ->outputs[0]; - Node *start = CreateConst(graph, - node, - std::vector{0}, - std::vector{1}, - ONNXDataType::INT32) - ->outputs[0]; - Node *end = CreateConst(graph, - node, - std::vector{2}, - std::vector{1}, - ONNXDataType::INT32) - ->outputs[0]; - Node *axes = CreateConst(graph, - node, - std::vector{0}, - std::vector{1}, - ONNXDataType::INT32) - ->outputs[0]; - Node *nc = CreateBaseOp(graph, - node, - "popart_slice", - {input_shape, start, end, axes}, - {}, - {}) + Node *nc = CreateSlice(graph, + node, + {input_shape}, + {}, + std::vector{0}, + std::vector{2}, + std::vector{0}, + std::vector{1}) ->outputs[0]; size = CreateBaseOp(graph, node, diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc index 6badf37d5b334ef5f999ff572e17ef4c53be0941..ceaf21377c2b01130b932f7c775ae3e8f2866465 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc @@ -256,6 +256,69 @@ Node *CreateSoftmaxOpset11(Graph *graph, } } +Node *CreateSlice(Graph *graph, + Node *node, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &starts, + const std::vector &ends, + const std::vector &axes, + const std::vector &strides) { + auto *starts_node = + CreateConst( + graph, node, starts, {int64_t(starts.size())}, ONNXDataType::INT32) + ->outputs[0]; + auto *ends_node = + CreateConst( + graph, node, ends, {int64_t(ends.size())}, ONNXDataType::INT32) + ->outputs[0]; + auto *axes_node = + CreateConst( + graph, node, axes, {int64_t(axes.size())}, ONNXDataType::INT32) + ->outputs[0]; + auto *strides_node = + CreateConst( + graph, node, strides, {int64_t(strides.size())}, ONNXDataType::INT32) + ->outputs[0]; + return CreateBaseOp( + graph, + node, + "popart_slice", + {inputs[0], starts_node, ends_node, axes_node, strides_node}, + outputs); +} + +Node *CreateSplit(Graph *graph, + Node *node, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &split, + const int64_t axis) { + if (!outputs.empty()) { + return CreateBaseOp(graph, + node, + "popart_split", + inputs, + outputs, + {{"num_outputs", int64_t(split.size())}, + {"axis", int64_t(axis)}, + {"split", split}}); + } else { + std::vector splits_output_nodes; + for (int j = 0; j < split.size(); j++) { + splits_output_nodes.push_back(MakeVarNode(graph, node)); + } + return CreateBaseOp(graph, + node, + "popart_split", + inputs, + {splits_output_nodes}, + {{"num_outputs", int64_t(split.size())}, + {"axis", int64_t(axis)}, + {"split", split}}); + } +} + } // namespace ipu } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h index 3071c2a0b90cf0756956b72d2bb6463e39139cfb..ad2a20ae324aeeed2076c9cb6589067fac37f14d 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h @@ -104,6 +104,22 @@ Node *CreateSoftmaxOpset11(Graph *graph, const std::vector &outputs, int64_t axis); +Node *CreateSlice(Graph *graph, + Node *node, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &starts, + const std::vector &ends, + const std::vector &axes, + const std::vector &strides); + +Node *CreateSplit(Graph *graph, + Node *node, + const std::vector &inputs, + const std::vector &outputs, + const std::vector &split, + const int64_t axis); + } // namespace ipu } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc index 95d273bb66f01c3a3b2627903d42e3dacb069a74..9df51d5c42fc94b81372deb0afce8bd1c625d71d 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc @@ -249,94 +249,57 @@ Node *lookup_table_op_handler(Graph *graph, {{"value", const_value_}, {"dims", const_shape_}, {"dtype", GetOutputVarDType(node)}}); - auto axes = CreateConst(graph, - node, - {}, - {}, - {{"value", std::vector{0}}, - {"dims", std::vector{1}}, - {"dtype", ONNXDataType::INT64}}); - auto step = CreateConst(graph, - node, - {}, - {}, - {{"value", std::vector{1}}, - {"dims", std::vector{1}}, - {"dtype", ONNXDataType::INT64}}); - - auto left_start = CreateConst(graph, - node, - {}, - {}, - {{"value", std::vector{0}}, - {"dims", std::vector{1}}, - {"dtype", ONNXDataType::INT64}}); - auto left_end = CreateConst(graph, - node, - {}, - {}, - {{"value", std::vector{padding_idx_}}, - {"dims", std::vector{1}}, - {"dtype", ONNXDataType::INT64}}); - - auto right_start = - CreateConst(graph, - node, - {}, - {}, - {{"value", std::vector{padding_idx_ + 1}}, - {"dims", std::vector{1}}, - {"dtype", ONNXDataType::INT64}}); - auto right_end = CreateConst(graph, - node, - {}, - {}, - {{"value", std::vector{table_size_}}, - {"dims", std::vector{1}}, - {"dtype", ONNXDataType::INT64}}); - - auto left_slice = CreateBaseOp(graph, - node, - "popart_slice", - {GetInputVarNode("W", node), - left_start->outputs[0], - left_end->outputs[0], - axes->outputs[0], - step->outputs[0]}, - {}, - {}); - auto right_slice = CreateBaseOp(graph, - node, - "popart_slice", - {GetInputVarNode("W", node), - right_start->outputs[0], - right_end->outputs[0], - axes->outputs[0], - step->outputs[0]}, - {}, - {}); - if (padding_idx_ == 0) { + auto right_slice = + CreateSlice(graph, + node, + {GetInputVarNode("W", node)}, + {}, + std::vector{static_cast(padding_idx_) + 1}, + std::vector{static_cast(table_size_)}, + std::vector{0}, + std::vector{1}); w_node = CreateBaseOp(graph, node, "popart_concat", {concat_const->outputs[0], right_slice->outputs[0]}, {}, {{"axis", int64_t(0)}}); - ClearNode(left_start); - ClearNode(left_end); - ClearNode(left_slice); } else if (padding_idx_ == table_size_ - 1) { + auto left_slice = + CreateSlice(graph, + node, + {GetInputVarNode("W", node)}, + {}, + std::vector{0}, + std::vector{static_cast(padding_idx_)}, + std::vector{0}, + std::vector{1}); w_node = CreateBaseOp(graph, node, "popart_concat", {left_slice->outputs[0], concat_const->outputs[0]}, {}, {{"axis", int64_t{0}}}); - ClearNode(right_start); - ClearNode(right_end); - ClearNode(right_slice); } else { + auto left_slice = + CreateSlice(graph, + node, + {GetInputVarNode("W", node)}, + {}, + std::vector{0}, + std::vector{static_cast(padding_idx_)}, + std::vector{0}, + std::vector{1}); + auto right_slice = + CreateSlice(graph, + node, + {GetInputVarNode("W", node)}, + {}, + std::vector{static_cast(padding_idx_) + 1}, + std::vector{static_cast(table_size_)}, + std::vector{0}, + std::vector{1}); w_node = CreateBaseOp(graph, node, "popart_concat", @@ -441,72 +404,69 @@ Node *shape_handler(Graph *graph, Node *node) { Node *slice_handler(Graph *graph, Node *node) { auto *op = node->Op(); - Node *starts = nullptr; - if (!op->HasAttr("starts")) { - starts = GetInputVarNode("StartsTensor", node); - } else { - auto starts_ = PADDLE_GET_CONST(std::vector, op->GetAttr("starts")); - auto dim = int64_t(starts_.size()); - starts = CreateConst( - graph, node, std::vector{starts_}, {dim}, ONNXDataType::INT32); - starts = starts->outputs[0]; - } - Node *ends = nullptr; - if (!op->HasAttr("ends")) { - ends = GetInputVarNode("EndsTensor", node); - } else { - auto ends_ = PADDLE_GET_CONST(std::vector, op->GetAttr("ends")); - auto dim = int64_t(ends_.size()); - ends = CreateConst( - graph, node, std::vector{ends_}, {dim}, ONNXDataType::INT32); - ends = ends->outputs[0]; - } - Node *axes = nullptr; - { - auto axes_ = PADDLE_GET_CONST(std::vector, op->GetAttr("axes")); - auto dim = int64_t(axes_.size()); - axes = CreateConst( - graph, node, std::vector{axes_}, {dim}, ONNXDataType::INT32); + auto inputs = op->Inputs(); + + auto axes_value = PADDLE_GET_CONST(std::vector, op->GetAttr("axes")); + + std::vector> slice_values(3); + std::vector tensor_names{"Starts", "Ends", "Strides"}; + std::vector attr_names{"starts", "ends", "strides"}; + for (int i = 0; i < 3; i++) { + // Starts and Ends are default keys in inputs, but Strides. + bool is_tensor = + (inputs.find(tensor_names[i] + "TensorList") != inputs.end() && + !inputs.at(tensor_names[i] + "TensorList").empty()) || + (inputs.find(tensor_names[i] + "Tensor") != inputs.end() && + !inputs.at(tensor_names[i] + "Tensor").empty()); + if (is_tensor) { + PADDLE_THROW(platform::errors::Unimplemented( + "Do not support starts, ends and strides as tensors.")); + } else { + if (i == 2 && !op->HasAttr("strides")) { + slice_values[i] = std::vector(axes_value.size(), 1); + } else { + slice_values[i] = + PADDLE_GET_CONST(std::vector, op->GetAttr(attr_names[i])); + } + } } auto decrease_axis_ = PADDLE_GET_CONST(std::vector, op->GetAttr("decrease_axis")); - auto input_shape_ = GetInputVarNode("Input", node)->Var()->GetShape(); - auto output_shape_ = GetOutputVarNode("Out", node)->Var()->GetShape(); if (decrease_axis_.size() == 0) { - return CreateBaseOp( - graph, - node, - "popart_slice", - {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]}, - node->outputs); - } else if (output_shape_ == std::vector{0} || - input_shape_.size() > output_shape_.size()) { - auto slice = CreateBaseOp( - graph, - node, - "popart_slice", - {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]}, - {}, - {}); + return CreateSlice(graph, + node, + {GetInputVarNode("Input", node)}, + {GetOutputVarNode("Out", node)}, + slice_values[0], + slice_values[1], + axes_value, + slice_values[2]); + } else { + auto *slice = CreateSlice(graph, + node, + {GetInputVarNode("Input", node)}, + {}, + slice_values[0], + slice_values[1], + axes_value, + slice_values[2]) + ->outputs[0]; return CreateBaseOp( graph, node, "popart_squeeze", - {slice->outputs[0]}, + {slice}, {GetOutputVarNode("Out", node)}, {{"axes", std::vector{decrease_axis_.begin(), decrease_axis_.end()}}}); - } else { - return CreateBaseOp( - graph, - node, - "popart_slice", - {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]}, - node->outputs); } } +Node *strided_slice_handler(Graph *graph, Node *node) { + return slice_handler(graph, node); +} + Node *expand_handler(Graph *graph, Node *node) { auto *op = node->Op(); if (!op->Input("expand_times_tensor").empty()) { @@ -552,6 +512,10 @@ Node *assign_handler(Graph *graph, Node *node) { {}); } +Node *share_data_handler(Graph *graph, Node *node) { + return assign_handler(graph, node); +} + Node *assign_value_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto dtype_ = PADDLE_GET_CONST(int, op->GetAttr("dtype")); @@ -731,15 +695,12 @@ Node *split_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto axis = PADDLE_GET_CONST(int, op->GetAttr("axis")); auto sections = PADDLE_GET_CONST(std::vector, op->GetAttr("sections")); - return CreateBaseOp( - graph, - node, - "popart_split", - {GetInputVarNode("X", node)}, - node->outputs, - {{"num_outputs", int64_t(sections.size())}, - {"axis", int64_t(axis)}, - {"split", std::vector{sections.begin(), sections.end()}}}); + return CreateSplit(graph, + node, + {GetInputVarNode("X", node)}, + node->outputs, + std::vector{sections.begin(), sections.end()}, + axis); } Node *dot_handler(Graph *graph, Node *node) { @@ -1116,19 +1077,8 @@ Node *flip_handler(Graph *graph, Node *node) { auto axis = axes[i]; std::vector split; split.resize(input_shape[axis], 1); - std::vector splits_output_nodes; - for (int j = 0; j < split.size(); j++) { - splits_output_nodes.push_back(MakeVarNode(graph, node)); - } - auto splits_outputs = CreateBaseOp(graph, - node, - "popart_split", - {temp_node}, - {splits_output_nodes}, - {{"num_outputs", int64_t(split.size())}, - {"axis", int64_t(axis)}, - {"split", split}}) - ->outputs; + auto splits_outputs = + CreateSplit(graph, node, {temp_node}, {}, split, axis)->outputs; std::reverse(splits_outputs.begin(), splits_outputs.end()); if (i != axes.size() - 1) { temp_node = CreateBaseOp(graph, @@ -1244,6 +1194,70 @@ Node *p_norm_handler(Graph *graph, Node *node) { {GetOutputVarNode("Out", node)}); } +Node *tile_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto inputs = op->Inputs(); + bool is_repeat_tensors = + (inputs.find("RepeatTimes") != inputs.end() && + !inputs.at("RepeatTimes").empty()) || + (inputs.find("repeat_times_tensor") != inputs.end() && + !inputs.at("repeat_times_tensor").empty()); + if (is_repeat_tensors) { + PADDLE_THROW( + platform::errors::Unimplemented("Do not support repeats as tensors.")); + } + auto repeat_times = + PADDLE_GET_CONST(std::vector, op->GetAttr("repeat_times")); + int nums = repeat_times.size(); + std::vector ones( + GetInputVarNode("X", node)->Var()->GetShape().size() - nums, 1); + repeat_times.insert(repeat_times.begin(), ones.begin(), ones.end()); + auto *repeat_node = CreateConst(graph, + node, + std::vector{repeat_times.begin(), + repeat_times.end()}, + {int64_t(repeat_times.size())}, + ONNXDataType::INT64) + ->outputs[0]; + return CreateBaseOp(graph, + node, + "popart_tile", + {GetInputVarNode("X", node), repeat_node}, + {GetOutputVarNode("Out", node)}); +} + +Node *unstack_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto axis = PADDLE_GET_CONST(int, op->GetAttr("axis")); + if (axis < 0) { + axis += GetInputVarNode("X", node)->Var()->GetShape().size(); + } + std::vector split(node->outputs.size(), 1); + auto split_output_nodes = + CreateSplit(graph, node, {GetInputVarNode("X", node)}, {}, split, axis) + ->outputs; + Node *output = nullptr; + for (int i = 0; i < split.size(); i++) { + output = CreateBaseOp(graph, + node, + "popart_squeeze", + {split_output_nodes[i]}, + {node->outputs[i]}, + {{"axes", std::vector{axis}}}); + } + return output; +} + +Node *where_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, + node, + "popart_where", + {GetInputVarNode("Condition", node), + GetInputVarNode("X", node), + GetInputVarNode("Y", node)}, + {GetOutputVarNode("Out", node)}); +} + } // namespace } // namespace ipu } // namespace platform @@ -1265,6 +1279,7 @@ REGISTER_HANDLER(concat, concat_handler); REGISTER_HANDLER(stack, stack_handler); REGISTER_HANDLER(shape, shape_handler); REGISTER_HANDLER(slice, slice_handler); +REGISTER_HANDLER(strided_slice, strided_slice_handler); REGISTER_HANDLER(expand, expand_handler); REGISTER_HANDLER(expand_v2, expand_v2_handler); REGISTER_HANDLER(expand_as_v2, expand_as_v2_handler); @@ -1281,3 +1296,7 @@ REGISTER_HANDLER(dist, dist_handler); REGISTER_HANDLER(flip, flip_handler); REGISTER_HANDLER(meshgrid, meshgrid_handler); REGISTER_HANDLER(p_norm, p_norm_handler); +REGISTER_HANDLER(share_data, share_data_handler); +REGISTER_HANDLER(tile, tile_handler); +REGISTER_HANDLER(unstack, unstack_handler); +REGISTER_HANDLER(where, where_handler);