diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
index 5f0ba745ed3c909909645238e399921be652550f..21c9beade3082ea1cc0e2d92e64a518e14d24fd3 100644
--- a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
@@ -376,11 +376,473 @@ Node *dropout_handler(Graph *graph, Node *node) {
   }
 }
 
+Node *conv2d_transpose_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+
+  auto data_format = BOOST_GET_CONST(std::string, op->GetAttr("data_format"));
+  if (data_format != "NCHW") {
+    PADDLE_THROW(
+        platform::errors::InvalidArgument("Only support NCHW as data_format."));
+  }
+
+  auto *kernel_info = GetInputVarNode("Filter", node);
+  auto kernel_shape = kernel_info->Var()->GetShape();
+
+  auto dilations_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dilations"));
+  auto dilations = std::vector<int64_t>{dilations_.begin(), dilations_.end()};
+  auto strides_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
+  auto strides = std::vector<int64_t>{strides_.begin(), strides_.end()};
+  auto output_padding_ =
+      BOOST_GET_CONST(std::vector<int>, op->GetAttr("output_padding"));
+  auto output_padding =
+      std::vector<int64_t>{output_padding_.begin(), output_padding_.end()};
+  auto group_ = BOOST_GET_CONST(int, op->GetAttr("groups"));
+  auto group = int64_t(group_);
+
+  auto padding_algorithm =
+      BOOST_GET_CONST(std::string, op->GetAttr("padding_algorithm"));
+
+  auto paddings_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("paddings"));
+  if (paddings_.size() == 2) {
+    paddings_.push_back(paddings_[0]);
+    paddings_.push_back(paddings_[1]);
+  } else if (paddings_.size() == 4) {
+    std::swap(paddings_[1], paddings_[2]);
+  }
+  auto paddings = std::vector<int64_t>{paddings_.begin(), paddings_.end()};
+
+  if (padding_algorithm == "SAME") {
+    // Update paddings and dilations based on the sizes of H and W.
+    auto input_shape = GetInputVarNode("Input", node)->Var()->GetShape();
+    for (auto i = 0; i < 2; i++) {
+      auto out_size = (input_shape[i + 2] + strides[i] - 1) / strides[i];
+      auto pad_sum = std::max(
+          (out_size - 1) * strides[i] + kernel_shape[i] - input_shape[i + 2],
+          static_cast<int64_t>(0));
+      auto pad_0 = pad_sum / 2;
+      auto pad_1 = pad_sum - pad_0;
+      paddings[i] = pad_0;
+      paddings[i + 2] = pad_1;
+    }
+    for (auto i = 0; i < dilations.size(); i++) {
+      dilations[i] = 1;
+    }
+  } else if (padding_algorithm == "VALID") {
+    for (auto i = 0; i < paddings.size(); i++) {
+      paddings[i] = 0;
+    }
+  }
+
+  auto attrs = AttributeMap{{"dilations", dilations},
+                            {"group", group},
+                            {"kernel_shape", kernel_shape},
+                            {"output_padding", output_padding},
+                            {"pads", paddings},
+                            {"strides", strides}};
+  if (!op->Input("Bias").empty()) {
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_convtranspose",
+                        {
+                            GetInputVarNode("Input", node),
+                            GetInputVarNode("Filter", node),
+                            GetInputVarNode("Bias", node),
+                        },
+                        node->outputs,
+                        attrs);
+  } else {
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_convtranspose",
+                        {
+                            GetInputVarNode("Input", node),
+                            GetInputVarNode("Filter", node),
+                        },
+                        node->outputs,
+                        attrs);
+  }
+}
+
+Node *affine_channel_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+
+  auto data_layout = BOOST_GET_CONST(std::string, op->GetAttr("data_layout"));
+  if (data_layout != "NCHW") {
+    PADDLE_THROW(
+        platform::errors::InvalidArgument("Only support NCHW as data_format."));
+  }
+
+  auto *scale = GetInputVarNode("Scale", node);
+  auto *bias = GetInputVarNode("Bias", node);
+  auto scale_shape = scale->Var()->GetShape();
+  auto bias_shape = bias->Var()->GetShape();
+  if (scale_shape.size() <= 1 || bias_shape.size() <= 1) {
+    auto attrs = AttributeMap{{"value", std::vector<int64_t>{1, -1, 1, 1}},
+                              {"dims", std::vector<int64_t>{4}},
+                              {"dtype", ONNXDataType::INT64}};
+    auto new_shape_const = CreateConst(graph, node, {}, {}, attrs);
+
+    scale = CreateBaseOp(graph,
+                         node,
+                         "popart_reshape",
+                         {scale, new_shape_const->outputs[0]},
+                         {},
+                         {})
+                ->outputs[0];
+    bias = CreateBaseOp(graph,
+                        node,
+                        "popart_reshape",
+                        {bias, new_shape_const->outputs[0]},
+                        {},
+                        {})
+               ->outputs[0];
+  }
+  auto *out = CreateBaseOp(
+      graph, node, "popart_mul", {GetInputVarNode("X", node), scale}, {});
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_add",
+                      {out->outputs[0], bias},
+                      {GetOutputVarNode("Out", node)});
+}
+
+Node *interp_handler(Graph *graph, Node *node, const std::string &mode) {
+  auto *op = node->Op();
+
+  auto data_layout = BOOST_GET_CONST(std::string, op->GetAttr("data_layout"));
+  if (data_layout != "NCHW") {
+    PADDLE_THROW(
+        platform::errors::InvalidArgument("Only support NCHW as data_format."));
+  }
+
+  auto align_corners = BOOST_GET_CONST(bool, op->GetAttr("align_corners"));
+  auto align_mode = BOOST_GET_CONST(int, op->GetAttr("align_mode"));
+
+  auto paddle_target_dtype = VarType::FP32;
+  auto onnx_target_dtype = ONNXDataType::FLOAT;
+  if (GetInputVarNode("X", node)->Var()->GetDataType() == VarType::FP16) {
+    paddle_target_dtype = VarType::FP16;
+    onnx_target_dtype = ONNXDataType::FLOAT16;
+  }
+
+  std::string coordinate_transformation_mode = "half_pixel";
+  if (align_corners) {
+    coordinate_transformation_mode = "align_corners";
+  } else if (mode == "nearest") {
+    coordinate_transformation_mode = "asymmetric";
+  } else if (align_mode == 1 && mode == "cubic") {
+    coordinate_transformation_mode = "asymmetric";
+  }
+
+  bool has_out_size = node->Op()->Input("OutSize").size() > 0;
+  bool has_size_tensor = node->Op()->Input("SizeTensor").size() > 0;
node->Op()->Input("SizeTensor").size() > 0; + bool has_scale_tensor = node->Op()->Input("Scale").size() > 0; + + Node *size = nullptr; + Node *scale = nullptr; + // Input: Size and Scale + if (has_out_size) { + // Get 'size' from the tensor + size = GetInputVarNode("OutSize", node); + if (size->Var()->GetDataType() != VarType::INT64) { + size = CreateCast(graph, + node, + {GetInputVarNode("OutSize", node)}, + {}, + VarType::INT64) + ->outputs[0]; + } + } else if (has_size_tensor) { + // Get 'size' from multi-tensors + std::vector size_nodes; + for (auto var_name : node->Op()->Input("SizeTensor")) { + Node *size_node = GetInputVarNodeByVarName(var_name, node); + if (size_node->Var()->GetDataType() != VarType::INT64) { + size_node = CreateCast(graph, node, {size_node}, {}, VarType::INT64) + ->outputs[0]; + } + size_nodes.push_back(size_node); + } + size = CreateBaseOp(graph, + node, + "popart_concat", + size_nodes, + {}, + {{"axis", int64_t(0)}}) + ->outputs[0]; + } else if (has_scale_tensor) { + // Get 'scale' from tensor + scale = GetInputVarNode("Scale", node); + if (scale->Var()->GetDataType() != paddle_target_dtype) { + scale = + CreateCast(graph, node, {scale}, {}, paddle_target_dtype)->outputs[0]; + } + auto *padding = CreateConst(graph, + node, + {}, + {}, + {{"value", std::vector{1.0, 1.0}}, + {"dims", std::vector{2}}, + {"dtype", onnx_target_dtype}}) + ->outputs[0]; + scale = CreateBaseOp(graph, + node, + "popart_concat", + {padding, scale}, + {}, + {{"axis", int64_t(0)}}) + ->outputs[0]; + } else { + // Get 'size' or 'scale' from attribute + auto out_d = BOOST_GET_CONST(int, op->GetAttr("out_d")); + auto out_h = BOOST_GET_CONST(int, op->GetAttr("out_h")); + auto out_w = BOOST_GET_CONST(int, op->GetAttr("out_w")); + if (out_d > 0 || out_w > 0 || out_h > 0) { + std::vector out_size; + if (GetInputVarNode("X", node)->Var()->GetShape().size() == 5) { + out_size.push_back(int64_t(out_d)); + out_size.push_back(int64_t(out_h)); + } else if (GetInputVarNode("X", node)->Var()->GetShape().size() == 4) { + out_size.push_back(int64_t(out_h)); + } + out_size.push_back(int64_t(out_w)); + size = + CreateConst(graph, + node, + {}, + {}, + {{"value", out_size}, + {"dims", std::vector{int64_t(out_size.size())}}, + {"dtype", ONNXDataType::INT64}}) + ->outputs[0]; + } else { + auto scale_value = + BOOST_GET_CONST(std::vector, op->GetAttr("scale")); + float padding = 1.0; + scale_value.insert(scale_value.begin(), padding); + scale_value.insert(scale_value.begin(), padding); + scale = CreateConst( + graph, + node, + {}, + {}, + {{"value", scale_value}, + {"dims", std::vector{int64_t(scale_value.size())}}, + {"dtype", onnx_target_dtype}}) + ->outputs[0]; + } + } + + Node *roi = + CreateConst( + graph, + node, + {}, + {}, + {{"value", + std::vector( + GetInputVarNode("X", node)->Var()->GetShape().size() * 2, 1.0)}, + {"dims", + std::vector{int64_t( + GetInputVarNode("X", node)->Var()->GetShape().size() * 2)}}, + {"dtype", onnx_target_dtype}}) + ->outputs[0]; + + if (size != nullptr) { + Node *input_shape = + CreateBaseOp( + graph, node, "popart_shape", {GetInputVarNode("X", node)}, {}) + ->outputs[0]; + Node *start = CreateConst(graph, + node, + std::vector{0}, + std::vector{1}, + ONNXDataType::INT32) + ->outputs[0]; + Node *end = CreateConst(graph, + node, + std::vector{2}, + std::vector{1}, + ONNXDataType::INT32) + ->outputs[0]; + Node *axes = CreateConst(graph, + node, + std::vector{0}, + std::vector{1}, + ONNXDataType::INT32) + ->outputs[0]; + Node *nc = CreateBaseOp(graph, + node, + "popart_slice", 
+                            {input_shape, start, end, axes},
+                            {},
+                            {})
+                   ->outputs[0];
+    size = CreateBaseOp(graph,
+                        node,
+                        "popart_concat",
+                        {nc, size},
+                        {},
+                        {{"axis", int64_t(0)}})
+               ->outputs[0];
+  }
+  auto resize_attrs = AttributeMap{
+      {"coordinate_transformation_mode", coordinate_transformation_mode},
+      {"cubic_coeff_a", float{-0.75}},
+      {"exclude_outside", int64_t{0}},
+      {"extrapolation_value", float{0.0}},
+      {"mode", mode},
+      {"nearest_mode", std::string("round_prefer_floor")}};
+
+  if (mode == "nearest" && coordinate_transformation_mode == "asymmetric") {
+    resize_attrs.at("nearest_mode") = std::string("floor");
+  }
+
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_resize",
+                      {GetInputVarNode("X", node), roi, scale, size},
+                      {GetOutputVarNode("Out", node)},
+                      resize_attrs);
+}
+
+Node *bilinear_interp_v2_handler(Graph *graph, Node *node) {
+  return interp_handler(graph, node, "linear");
+}
+
+Node *nearest_interp_v2_handler(Graph *graph, Node *node) {
+  return interp_handler(graph, node, "nearest");
+}
+
+Node *bicubic_interp_v2_handler(Graph *graph, Node *node) {
+  return interp_handler(graph, node, "cubic");
+}
+
+Node *linear_interp_v2_handler(Graph *graph, Node *node) {
+  return interp_handler(graph, node, "linear");
+}
+
+Node *trilinear_interp_v2_handler(Graph *graph, Node *node) {
+  return interp_handler(graph, node, "linear");
+}
+
+Node *data_norm_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+
+  int slot_dim = -1;
+  if (op->HasAttr("slot_dim")) {
+    slot_dim = BOOST_GET_CONST(int, op->GetAttr("slot_dim"));
+  }
+
+  if (slot_dim > 0) {
+    PADDLE_THROW(
+        platform::errors::InvalidArgument("slot_dim > 0 is not supported."));
+  }
+
+  bool enable_scale_and_shift = false;
+  if (op->HasAttr("enable_scale_and_shift")) {
+    enable_scale_and_shift =
+        BOOST_GET_CONST(bool, op->GetAttr("enable_scale_and_shift"));
+  }
+
+  auto *mean_arr = CreateBaseOp(graph,
+                                node,
+                                "popart_div",
+                                {GetInputVarNode("BatchSum", node),
+                                 GetInputVarNode("BatchSize", node)},
+                                {})
+                       ->outputs[0];
+  auto *scale_arr = CreateBaseOp(graph,
+                                 node,
+                                 "popart_div",
+                                 {GetInputVarNode("BatchSize", node),
+                                  GetInputVarNode("BatchSquareSum", node)},
+                                 {})
+                        ->outputs[0];
+  scale_arr =
+      CreateBaseOp(graph, node, "popart_sqrt", {scale_arr}, {})->outputs[0];
+  auto out =
+      CreateBaseOp(
+          graph, node, "popart_sub", {GetInputVarNode("X", node), mean_arr}, {})
+          ->outputs[0];
+
+  if (enable_scale_and_shift) {
+    auto scale_res = CreateBaseOp(graph,
+                                  node,
+                                  "popart_mul",
+                                  {out, GetInputVarNode("scale_w", node)},
+                                  {})
+                         ->outputs[0];
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_add",
+                        {scale_res, GetInputVarNode("bias", node)},
+                        {GetOutputVarNode("Y", node)});
+  } else {
+    return CreateBaseOp(graph,
+                        node,
+                        "popart_mul",
+                        {out, scale_arr},
+                        {GetOutputVarNode("Y", node)});
+  }
+}
+
+Node *pad_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto mode = BOOST_GET_CONST(std::string, op->GetAttr("mode"));
+  auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
+  auto data_format = BOOST_GET_CONST(std::string, op->GetAttr("data_format"));
+
+  if (data_format == "NDHWC") {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("NDHWC format is not supported."));
+  }
+  if (mode == "replicate" || mode == "circular") {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "circular and replicate modes are not supported."));
+  }
+  if (op->Input("Paddings").size()) {
+    // Paddings -> input tensor
+    // The PopART Pad Op only supports `pad` as a constant
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Do not support `Paddings` as an input tensor."));
+  }
+  // Paddings -> Attr
+  auto paddings = BOOST_GET_CONST(std::vector<int>, op->GetAttr("paddings"));
+  std::vector<int64_t> new_paddings(10, 0);
+  new_paddings[2] = paddings[4];
+  new_paddings[3] = paddings[2];
+  new_paddings[4] = paddings[0];
+  new_paddings[7] = paddings[5];
+  new_paddings[8] = paddings[3];
+  new_paddings[9] = paddings[1];
+
+  auto *paddings_node = CreateConst(graph,
+                                    node,
+                                    new_paddings,
+                                    std::vector<int64_t>{10},
+                                    ONNXDataType::INT64)
+                            ->outputs[0];
+  auto *value_node = CreateConst(graph,
+                                 node,
+                                 std::vector<float>{value},
+                                 std::vector<int64_t>{1},
+                                 ONNXDataType::FLOAT)
+                         ->outputs[0];
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_pad",
+                      {GetInputVarNode("X", node), paddings_node, value_node},
+                      {GetOutputVarNode("Out", node)},
+                      {{"mode", mode}});
+}
+
 }  // namespace
 }  // namespace ipu
 }  // namespace platform
 }  // namespace paddle
 
+REGISTER_HANDLER(affine_channel, affine_channel_handler);
 REGISTER_HANDLER(pool2d, pool2d_handler);
 REGISTER_HANDLER(max_pool2d_with_index, max_pool2d_with_index_handler);
 REGISTER_HANDLER(batch_norm, batch_norm_handler);
@@ -388,4 +850,12 @@ REGISTER_HANDLER(group_norm, group_norm_handler);
 REGISTER_HANDLER(instance_norm, instance_norm_handler);
 REGISTER_HANDLER(layer_norm, layer_norm_handler);
 REGISTER_HANDLER(conv2d, conv2d_handler);
+REGISTER_HANDLER(conv2d_transpose, conv2d_transpose_handler);
 REGISTER_HANDLER(dropout, dropout_handler);
+REGISTER_HANDLER(bilinear_interp_v2, bilinear_interp_v2_handler);
+REGISTER_HANDLER(nearest_interp_v2, nearest_interp_v2_handler);
+REGISTER_HANDLER(bicubic_interp_v2, bicubic_interp_v2_handler);
+REGISTER_HANDLER(linear_interp_v2, linear_interp_v2_handler);
+REGISTER_HANDLER(trilinear_interp_v2, trilinear_interp_v2_handler);
+REGISTER_HANDLER(data_norm, data_norm_handler);
+REGISTER_HANDLER(pad3d, pad_handler);
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
index 9b7fb7b835235a5b95b218a756f650f7436a71eb..0bf0335db0f34e58ed69b11c9da2c3e1b5fbbff4 100644
--- a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
@@ -33,10 +33,15 @@ Node *fill_constant_handler(Graph *graph, Node *node) {
   auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
   auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
   auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
-  size_t size = 1;
+  int size = 1;
   for (auto &dim : dims) {
     size *= dim;
   }
+  PADDLE_ENFORCE_GT(size,
+                    0,
+                    errors::InvalidArgument(
+                        "IPU doesn't support non-positive dimensions. Please "
+                        "check tensor shape setting."));
   Attribute value;
   switch (dtype_) {
     case VarType::FP16:
@@ -598,10 +603,15 @@ Node *fill_any_like_handler(Graph *graph, Node *node) {
   auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
   auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
   auto dtype = static_cast<VarType::Type>(dtype_);
-  size_t size = 1;
+  int size = 1;
   for (auto &dim : x_shape) {
     size *= dim;
   }
+  PADDLE_ENFORCE_GT(size,
+                    0,
+                    errors::InvalidArgument(
+                        "IPU doesn't support non-positive dimensions. Please "
+                        "check tensor shape setting."));
Please " + "check tensor shape setting.")); Attribute out_value; switch (dtype) { @@ -748,6 +758,491 @@ Node *dot_handler(Graph *graph, Node *node) { }); } +Node *clip_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + // if (min_value == -FLT_MAX) then means no min_value + // if (max_value == FLT_MAX) then means no max_value + auto min_value = BOOST_GET_CONST(float, op->GetAttr("min")); + auto max_value = BOOST_GET_CONST(float, op->GetAttr("max")); + + bool has_min_tensor = false; + bool has_max_tensor = false; + if (node->Op()->Input("Min").size()) { + has_min_tensor = true; + } + if (node->Op()->Input("Max").size()) { + has_max_tensor = true; + } + + bool transfer_input_dtype = false; + Node *input_data = GetInputVarNode("X", node); + if (input_data->Var()->GetDataType() != VarType::FP32 && + input_data->Var()->GetDataType() != VarType::FP16) { + input_data = + CreateCast(graph, node, {input_data}, {}, VarType::FP32)->outputs[0]; + transfer_input_dtype = true; + } + + Node *min_tensor = nullptr; + if (has_min_tensor) { + if (GetInputVarNode("Min", node)->Var()->GetDataType() != VarType::FP32) { + min_tensor = + CreateCast( + graph, node, {GetInputVarNode("Min", node)}, {}, VarType::FP32) + ->outputs[0]; + } else { + min_tensor = GetInputVarNode("Min", node); + } + } else { + min_tensor = CreateConst(graph, + node, + {}, + {}, + {{"value", std::vector{min_value}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}) + ->outputs[0]; + } + + Node *max_tensor = nullptr; + if (has_max_tensor) { + if (GetInputVarNode("Max", node)->Var()->GetDataType() != VarType::FP32) { + max_tensor = + CreateCast( + graph, node, {GetInputVarNode("Max", node)}, {}, VarType::FP32) + ->outputs[0]; + } else { + max_tensor = GetInputVarNode("Max", node); + } + } else { + max_tensor = CreateConst(graph, + node, + {}, + {}, + {{"value", std::vector{max_value}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}) + ->outputs[0]; + } + + if (transfer_input_dtype) { + auto clip_res = CreateBaseOp( + graph, node, "popart_clip", {input_data, min_tensor, max_tensor}, {}); + return CreateCast(graph, + node, + clip_res->outputs, + {GetOutputVarNode("Out", node)}, + GetInputVarNode("X", node)->Var()->GetDataType()); + } else { + return CreateBaseOp(graph, + node, + "popart_clip", + {input_data, min_tensor, max_tensor}, + {GetOutputVarNode("Out", node)}); + } +} + +Node *dist_handler(Graph *graph, Node *node) { + // Minimum negative float + union neg_infinity { + int neg_int_inf; + float neg_float_int; + }; + neg_infinity neg_inf; + neg_inf.neg_int_inf = 0xFF800000; + float g_NegFloatInfinity = neg_inf.neg_float_int; + + auto *op = node->Op(); + auto *sub_node = + CreateBaseOp(graph, + node, + "popart_sub", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {}) + ->outputs[0]; + auto *abs_node = + CreateBaseOp(graph, node, "popart_abs", {sub_node}, {})->outputs[0]; + + auto p = BOOST_GET_CONST(float, op->GetAttr("p")); + + // Reshape to 1-D output + auto target_shape = AttributeMap{{"value", std::vector{-1}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::INT64}}; + auto *target_shape_node = + CreateBaseOp(graph, node, "popart_constant", {}, {}, target_shape) + ->outputs[0]; + + if (fabs(p) < 1e-6) { + auto *sign_node = + CreateBaseOp(graph, node, "popart_sign", {abs_node}, {})->outputs[0]; + auto *sum_node = CreateBaseOp(graph, + node, + "popart_reducesum", + {sign_node}, + {}, + {{"keepdims", int64_t{0}}}) + ->outputs[0]; + return CreateBaseOp(graph, + node, + 
"popart_reshape", + {sum_node, target_shape_node}, + {GetOutputVarNode("Out", node)}); + } else if (p == std::numeric_limits::infinity()) { + auto *max_node = CreateBaseOp(graph, + node, + "popart_reducemax", + {abs_node}, + {}, + {{"keepdims", int64_t{0}}}) + ->outputs[0]; + return CreateBaseOp(graph, + node, + "popart_reshape", + {max_node, target_shape_node}, + {GetOutputVarNode("Out", node)}); + } else if (p == g_NegFloatInfinity) { + auto *min_node = CreateBaseOp(graph, + node, + "popart_reducemin", + {abs_node}, + {}, + {{"keepdims", int64_t{0}}}) + ->outputs[0]; + return CreateBaseOp(graph, + node, + "popart_reshape", + {min_node, target_shape_node}, + {GetOutputVarNode("Out", node)}); + } else { + auto target_dtype = ONNXDataType::FLOAT; + if (GetInputVarNode("X", node)->Var()->GetDataType() == VarType::FP16) { + target_dtype = ONNXDataType::FLOAT16; + } + + auto pow_factor = AttributeMap{{"value", std::vector{p}}, + {"dims", std::vector{1}}, + {"dtype", target_dtype}}; + auto *pow_factor_node = + CreateBaseOp(graph, node, "popart_constant", {}, {}, pow_factor) + ->outputs[0]; + auto *pow_node = + CreateBaseOp(graph, node, "popart_pow", {abs_node, pow_factor_node}, {}) + ->outputs[0]; + auto *sum_node = CreateBaseOp(graph, + node, + "popart_reducesum", + {pow_node}, + {}, + {{"keepdims", int64_t{0}}}) + ->outputs[0]; + auto *s_node = + CreateBaseOp( + graph, node, "popart_reshape", {sum_node, target_shape_node}, {}) + ->outputs[0]; + auto *p_1 = + CreateBaseOp(graph, node, "popart_reciprocal", {pow_factor_node}, {}) + ->outputs[0]; + return CreateBaseOp(graph, + node, + "popart_pow", + {s_node, p_1}, + {GetOutputVarNode("Out", node)}); + } +} + +Node *expand_as_v2_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + Node *shape = nullptr; + auto op_inputs = op->Inputs(); + // PopART Expand Op only support the constant tensor as the input `shape`. + if (op_inputs.find("target_tensor") != op_inputs.end()) { + PADDLE_THROW(platform::errors::Unimplemented( + "Do not support input tensor `target_tensor`. Please use the attribute " + "`target_shape`.")); + } + auto input_shape = GetInputVarNode("X", node)->Var()->GetShape(); + auto shape_value = + BOOST_GET_CONST(std::vector, op->GetAttr("target_shape")); + // Check the dimensions + int input_shape_index = input_shape.size() - 1; + int target_shape_index = shape_value.size() - 1; + while (input_shape_index >= 0) { + if (input_shape[input_shape_index] != + int64_t(shape_value[target_shape_index]) && + input_shape[input_shape_index] != int64_t(1)) { + PADDLE_THROW(platform::errors::Unimplemented( + "For input and `shape`, corresponding dimensions must have the same " + "value or input dim = 1.")); + } + target_shape_index--; + input_shape_index--; + } + shape = CreateConst( + graph, + node, + {}, + {}, + {{"value", + std::vector{shape_value.begin(), shape_value.end()}}, + {"dims", std::vector{int64_t(shape_value.size())}}, + {"dtype", ONNXDataType::INT64}}) + ->outputs[0]; + return CreateBaseOp(graph, + node, + "popart_expand", + {GetInputVarNode("X", node), shape}, + {GetOutputVarNode("Out", node)}); +} + +Node *expand_v2_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + + // PopART Expand Op only support the constant tensor as the input `shape`. + if (op->Input("Shape").size()) { + PADDLE_THROW( + platform::errors::Unimplemented("Do not support input tensor `Shape`. 
" + "Please use the attribute `shape`.")); + } + if (op->Input("expand_shapes_tensor").size()) { + PADDLE_THROW(platform::errors::Unimplemented( + "Do not support input tensor `expand_shapes_tensor`. Please use the " + "attribute `shape`.")); + } + auto input_shape = GetInputVarNode("X", node)->Var()->GetShape(); + auto shape_value = BOOST_GET_CONST(std::vector, op->GetAttr("shape")); + // Check the dimensions + int input_shape_index = input_shape.size() - 1; + int target_shape_index = shape_value.size() - 1; + while (input_shape_index >= 0) { + if (input_shape[input_shape_index] != + int64_t(shape_value[target_shape_index]) && + input_shape[input_shape_index] != int64_t(1)) { + PADDLE_THROW(platform::errors::Unimplemented( + "For input and `shape`, corresponding dimensions must have the same " + "value or input dim = 1.")); + } + target_shape_index--; + input_shape_index--; + } + + auto *shape = + CreateConst( + graph, + node, + {}, + {}, + {{"value", + std::vector{shape_value.begin(), shape_value.end()}}, + {"dims", std::vector{int64_t(shape_value.size())}}, + {"dtype", ONNXDataType::INT64}}) + ->outputs[0]; + + return CreateBaseOp(graph, + node, + "popart_expand", + {GetInputVarNode("X", node), shape}, + {GetOutputVarNode("Out", node)}); +} + +Node *flatten_contiguous_range_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto start_axis = BOOST_GET_CONST(int, op->GetAttr("start_axis")); + auto stop_axis = BOOST_GET_CONST(int, op->GetAttr("stop_axis")); + auto input_rank = GetInputVarNode("X", node)->Var()->GetShape().size(); + + if (start_axis < 0) { + start_axis += input_rank; + } + if (stop_axis < 0) { + stop_axis += input_rank; + } + + std::vector target_shape; + if (start_axis == 0 && stop_axis == input_rank - 1) { + target_shape.push_back(-1); + } else { + auto input_shape = GetInputVarNode("X", node)->Var()->GetShape(); + if (start_axis == 0) { + target_shape.assign(input_shape.begin() + stop_axis + 1, + input_shape.end()); + target_shape.insert(target_shape.begin(), -1); + } else if (stop_axis == input_rank - 1) { + target_shape.assign(input_shape.begin(), + input_shape.begin() + start_axis); + target_shape.push_back(-1); + } else { + target_shape.insert(target_shape.begin(), + input_shape.begin(), + input_shape.begin() + start_axis); + target_shape.push_back(-1); + target_shape.insert(target_shape.end(), + input_shape.begin() + stop_axis + 1, + input_shape.end()); + } + } + auto *unknown_dim_node = CreateConst(graph, + node, + target_shape, + {int64_t(target_shape.size())}, + ONNXDataType::INT64) + ->outputs[0]; + return CreateBaseOp(graph, + node, + "popart_reshape", + {GetInputVarNode("X", node), unknown_dim_node}, + {GetOutputVarNode("Out", node)}, + {}); +} + +Node *flip_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto axes = BOOST_GET_CONST(std::vector, op->GetAttr("axis")); + auto input_shape = GetInputVarNode("X", node)->Var()->GetShape(); + for (auto it = axes.begin(); it != axes.end();) { + if (*it < 0) { + *it += input_shape.size(); + } + // Remove input_shape[axis] == 1 + if (input_shape[*it] == 1) { + it = axes.erase(it); + } else { + it++; + } + } + auto *temp_node = GetInputVarNode("X", node); + for (auto i = 0; i < axes.size(); i++) { + auto axis = axes[i]; + std::vector split; + split.resize(input_shape[axis], 1); + std::vector splits_output_nodes; + for (int j = 0; j < split.size(); j++) { + splits_output_nodes.push_back(MakeVarNode(graph, node)); + } + auto splits_outputs = CreateBaseOp(graph, + node, + "popart_split", + 
+                                       {temp_node},
+                                       {splits_output_nodes},
+                                       {{"num_outputs", int64_t(split.size())},
+                                        {"axis", int64_t(axis)},
+                                        {"split", split}})
+                              ->outputs;
+    std::reverse(splits_outputs.begin(), splits_outputs.end());
+    if (i != axes.size() - 1) {
+      temp_node = CreateBaseOp(graph,
+                               node,
+                               "popart_concat",
+                               splits_outputs,
+                               {},
+                               {{"axis", int64_t(axis)}})
+                      ->outputs[0];
+    } else {
+      temp_node = CreateBaseOp(graph,
+                               node,
+                               "popart_concat",
+                               splits_outputs,
+                               {},
+                               {{"axis", int64_t(axis)}})
+                      ->outputs[0];
+    }
+  }
+  // In case `axis` is empty; the Identity Op will be removed by later passes.
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_identity",
+                      {temp_node},
+                      {GetOutputVarNode("Out", node)},
+                      {});
+}
+
+Node *meshgrid_handler(Graph *graph, Node *node) {
+  Node *res = nullptr;
+  // All inputs are 1-D tensors
+  std::vector<int64_t> out_shape;
+  for (auto input : node->inputs) {
+    auto input_shape = input->Var()->GetShape();
+    out_shape.push_back(input_shape[0]);
+  }
+  // The Expand Op only allows a const tensor as `shape`
+  auto *out_shape_node = CreateConst(graph,
+                                     node,
+                                     out_shape,
+                                     {int64_t(out_shape.size())},
+                                     ONNXDataType::INT64)
+                             ->outputs[0];
+
+  for (int i = 0; i < node->inputs.size(); i++) {
+    // Reshape the i-th input to rank node->inputs.size(), with the other
+    // dimensions set to 1
+    std::vector<int64_t> target_shape(node->inputs.size(), 1);
+    target_shape[i] = node->inputs[i]->Var()->GetShape()[0];
+    auto *target_shape_node = CreateConst(graph,
+                                          node,
+                                          target_shape,
+                                          {int64_t(target_shape.size())},
+                                          ONNXDataType::INT64)
+                                  ->outputs[0];
+    auto *t_reshaped = CreateBaseOp(graph,
+                                    node,
+                                    "popart_reshape",
+                                    {node->inputs[i], target_shape_node},
+                                    {},
+                                    {})
+                           ->outputs[0];
+    res = CreateBaseOp(graph,
+                       node,
+                       "popart_expand",
+                       {t_reshaped, out_shape_node},
+                       {node->outputs[i]});
+  }
+  return res;
+}
+
+Node *p_norm_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto keepdim = BOOST_GET_CONST(bool, op->GetAttr("keepdim"));
+  auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
+  auto porder = BOOST_GET_CONST(float, op->GetAttr("porder"));
+
+  auto target_dtype = ONNXDataType::FLOAT;
+  if (GetInputVarNode("X", node)->Var()->GetDataType() == VarType::FP16) {
+    target_dtype = ONNXDataType::FLOAT16;
+  }
+
+  auto *pnode = CreateConst(graph,
+                            node,
+                            std::vector<float>{porder},
+                            std::vector<int64_t>{1},
+                            target_dtype)
+                    ->outputs[0];
+  auto *abs_node =
+      CreateBaseOp(graph, node, "popart_abs", {GetInputVarNode("X", node)}, {})
+          ->outputs[0];
+  auto *pow_node =
+      CreateBaseOp(graph, node, "popart_pow", {abs_node, pnode}, {})
+          ->outputs[0];
+  auto *reducesum_node = CreateBaseOp(graph,
+                                      node,
+                                      "popart_reducesum",
+                                      {pow_node},
+                                      {},
+                                      {{"axes", std::vector<int64_t>{axis}},
+                                       {"keepdims", int64_t(keepdim)}})
+                             ->outputs[0];
+  auto *pnode1 =
+      CreateConst(graph,
+                  node,
+                  std::vector<float>{static_cast<float>(1.0 / porder)},
+                  std::vector<int64_t>{1},
+                  target_dtype)
+          ->outputs[0];
+  return CreateBaseOp(graph,
+                      node,
+                      "popart_pow",
+                      {reducesum_node, pnode1},
+                      {GetOutputVarNode("Out", node)});
+}
+
 }  // namespace
 }  // namespace ipu
 }  // namespace platform
@@ -759,6 +1254,7 @@ REGISTER_HANDLER(uniform_random, uniform_random_handler);
 REGISTER_HANDLER(transpose2, transpose_handler);
 REGISTER_HANDLER(reshape2, reshape_handler);
 REGISTER_HANDLER(flatten2, flatten2_handler);
+REGISTER_HANDLER(flatten_contiguous_range, flatten_contiguous_range_handler);
 REGISTER_HANDLER(gather, gather_handler);
 REGISTER_HANDLER(squeeze2, squeeze_handler);
 REGISTER_HANDLER(cast, cast_handler);
@@ -769,6 +1265,8 @@ REGISTER_HANDLER(stack, stack_handler);
 REGISTER_HANDLER(shape, shape_handler);
 REGISTER_HANDLER(slice, slice_handler);
 REGISTER_HANDLER(expand, expand_handler);
+REGISTER_HANDLER(expand_v2, expand_v2_handler);
+REGISTER_HANDLER(expand_as_v2, expand_as_v2_handler);
 REGISTER_HANDLER(assign, assign_handler);
 REGISTER_HANDLER(assign_value, assign_value_handler);
 REGISTER_HANDLER(fill_any_like, fill_any_like_handler);
@@ -777,3 +1275,8 @@ REGISTER_HANDLER(split, split_handler);
 REGISTER_HANDLER(one_hot, one_hot_handler);
 REGISTER_HANDLER(one_hot_v2, one_hot_v2_handler);
 REGISTER_HANDLER(dot, dot_handler);
+REGISTER_HANDLER(clip, clip_handler);
+REGISTER_HANDLER(dist, dist_handler);
+REGISTER_HANDLER(flip, flip_handler);
+REGISTER_HANDLER(meshgrid, meshgrid_handler);
+REGISTER_HANDLER(p_norm, p_norm_handler);