diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..92362ebf5be7d5426e4149e987c537691caedba4
--- /dev/null
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc
@@ -0,0 +1,36 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+namespace {
+
+Node *equal_handler(Graph *graph, Node *node) {
+  auto new_node = CreateBaseOp(
+      graph, node, "popart_equal",
+      {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, node->outputs);
+  return new_node;
+}
+
+REGISTER_HANDLER(equal, equal_handler);
+
+}  // namespace
+}  // namespace ipu
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..af7e4d0c7dbe9de77357c40f6298d1908495ac1a
--- /dev/null
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc
@@ -0,0 +1,259 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
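+
+// Handlers that canonicalize Paddle math ops (mean, pow, mul, matmul, sum,
+// softmax, scale, cross_entropy2) into their PopART counterparts. Each
+// handler consumes one op node and returns the replacement op node.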
+
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+namespace {
+
+Node *mean_handler(Graph *graph, Node *node) {
+  return CreateBaseOp(graph, node, "popart_reducemean",
+                      {GetInputVarNode("X", node)},
+                      {GetOutputVarNode("Out", node)},
+                      {
+                          {"keepdims", int64_t{0}},
+                      });
+}
+
+Node *pow_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  if (op->HasInput("FactorTensor") && !op->Input("FactorTensor").empty()) {
+    return CreateBaseOp(
+        graph, node, "popart_pow",
+        {GetInputVarNode("X", node), GetInputVarNode("FactorTensor", node)},
+        node->outputs);
+  } else {
+    // Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
+    auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
+    auto attrs = MakeConstAttrMapFromValue(value_, {1}, ONNXDataType::FLOAT);
+    auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
+    return CreateBaseOp(
+        graph, node, "popart_pow",
+        {GetInputVarNode("X", node), new_node_const->outputs[0]},
+        node->outputs);
+  }
+}
+
+Node *mul_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto x_num_col_dims = BOOST_GET_CONST(int, op->GetAttr("x_num_col_dims"));
+  auto y_num_col_dims = BOOST_GET_CONST(int, op->GetAttr("y_num_col_dims"));
+  auto x_shape_ = GetInputVarNode("X", node)->Var()->GetShape();
+  auto y_shape_ = GetInputVarNode("Y", node)->Var()->GetShape();
+
+  // build the output shape for the final reshape, e.g. X: [2, 3, 4, 5] with
+  // x_num_col_dims = 2 and Y: [4, 5, 6] with y_num_col_dims = 2 give the
+  // output shape [2, 3, 6]
+  std::vector<int64_t> reshape_shape_{};
+  for (int left = 0; left < x_num_col_dims; left++) {
+    reshape_shape_.push_back(int64_t(x_shape_[left]));
+  }
+  for (size_t right = y_num_col_dims; right < y_shape_.size(); right++) {
+    reshape_shape_.push_back(int64_t(y_shape_[right]));
+  }
+  auto x_flatten =
+      CreateBaseOp(graph, node, "popart_flatten", {GetInputVarNode("X", node)},
+                   {}, {{"axis", int64_t(x_num_col_dims)}});
+  auto y_flatten =
+      CreateBaseOp(graph, node, "popart_flatten", {GetInputVarNode("Y", node)},
+                   {}, {{"axis", int64_t(y_num_col_dims)}});
+  auto matmul =
+      CreateBaseOp(graph, node, "popart_matmul",
+                   {x_flatten->outputs[0], y_flatten->outputs[0]}, {}, {});
+
+  auto reshape_const = CreateConst(
+      graph, node, {}, {},
+      {{"value", reshape_shape_},
+       {"dims", std::vector<int64_t>{int64_t(reshape_shape_.size())}},
+       {"dtype", ONNXDataType::INT64}});
+  return CreateBaseOp(graph, node, "popart_reshape",
+                      {matmul->outputs[0], reshape_const->outputs[0]},
+                      node->outputs, {});
+}
+
+Node *matmul_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X"));
+  auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y"));
+  auto alpha = BOOST_GET_CONST(float, op->GetAttr("alpha"));
+  auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
+  auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape();
+
+  int x_rank = x_shape.size();
+  std::vector<int64_t> perm;
+  if (x_rank == 1) {
+    perm = std::vector<int64_t>{0};
+  } else if (x_rank == 2) {
+    return CreateGemm(graph, node,
+                      {GetInputVarNode("X", node), GetInputVarNode("Y", node)},
+                      node->outputs, transpose_x, transpose_y, alpha);
+  } else if (x_rank == 3) {
+    perm = std::vector<int64_t>{0, 2, 1};
+  } else if (x_rank == 4) {
+    perm = std::vector<int64_t>{0, 1, 3, 2};
+  } else {
+    PADDLE_THROW(platform::errors::InvalidArgument(
+        "op matmul does not support input of rank %d", x_rank));
+  }
+
+  Node *x_node = GetInputVarNode("X", node);
+  Node *y_node = GetInputVarNode("Y", node);
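+  // There is no gemm for rank-3/4 inputs, so transposed operands are handled
+  // with an explicit popart_transpose before the matmul: `perm` swaps only
+  // the last two axes and keeps the leading batch axes, e.g. a [2, 3, 4]
+  // input with transpose_X becomes [2, 4, 3].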
+  if (transpose_x) {
+    x_node = CreateBaseOp(graph, node, "popart_transpose",
+                          {GetInputVarNode("X", node)}, {}, {{"perm", perm}});
+    x_node = x_node->outputs[0];
+  }
+  if (transpose_y) {
+    y_node = CreateBaseOp(graph, node, "popart_transpose",
+                          {GetInputVarNode("Y", node)}, {}, {{"perm", perm}});
+    y_node = y_node->outputs[0];
+  }
+  if (!is_float_equal(alpha, 1.0)) {
+    // scale the matmul result only when alpha != 1.0
+    auto o_node =
+        CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {});
+    auto attr = MakeConstAttrMapFromValue(alpha, {1}, ONNXDataType::FLOAT);
+    auto const_node = CreateConst(graph, node, {}, {}, attr);
+    return CreateBaseOp(graph, node, "popart_mul",
+                        {o_node->outputs[0], const_node->outputs[0]},
+                        node->outputs);
+  } else {
+    return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node},
+                        node->outputs);
+  }
+}
+
+Node *sum_handler(Graph *graph, Node *node) {
+  return CreateBaseOp(graph, node, "popart_sum", node->inputs, node->outputs);
+}
+
+Node *softmax_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
+  return CreateSoftmaxOpset11(graph, node, node->inputs, node->outputs, axis);
+}
+
+Node *scale_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto scale_ = BOOST_GET_CONST(float, op->GetAttr("scale"));
+  auto bias_ = BOOST_GET_CONST(float, op->GetAttr("bias"));
+  auto bias_after_scale_ =
+      BOOST_GET_CONST(bool, op->GetAttr("bias_after_scale"));
+  auto data_type_ = GetInputVarNode("X", node)->Var()->GetDataType();
+
+  auto new_node_bias_var =
+      CreateConst(graph, node, {}, {},
+                  {{"value", std::vector<float>{bias_}},
+                   {"dims", std::vector<int64_t>{1}},
+                   {"dtype", ONNXDataType::FLOAT}});
+  new_node_bias_var = new_node_bias_var->outputs[0];
+
+  Node *new_node_scale_var = nullptr;
+  if (op->HasInput("ScaleTensor") && !op->Input("ScaleTensor").empty()) {
+    new_node_scale_var = GetInputVarNode("ScaleTensor", node);
+  } else {
+    new_node_scale_var =
+        CreateConst(graph, node, {}, {},
+                    {{"value", std::vector<float>{scale_}},
+                     {"dims", std::vector<int64_t>{1}},
+                     {"dtype", ONNXDataType::FLOAT}});
+    new_node_scale_var = new_node_scale_var->outputs[0];
+  }
+
+  // compute in float32, then cast the result back to the input dtype
+  auto new_node_cast =
+      CreateCast(graph, node, {GetInputVarNode("X", node)}, {},
+                 static_cast<int>(framework::proto::VarType::FP32));
+  Node *result = nullptr;
+  if (bias_after_scale_) {
+    // Out = X * scale + bias
+    auto new_node_mul =
+        CreateBaseOp(graph, node, "popart_mul",
+                     {new_node_cast->outputs[0], new_node_scale_var}, {}, {});
+    result =
+        CreateBaseOp(graph, node, "popart_add",
+                     {new_node_mul->outputs[0], new_node_bias_var}, {}, {});
+  } else {
+    // Out = (X + bias) * scale
+    auto new_node_add =
+        CreateBaseOp(graph, node, "popart_add",
+                     {new_node_cast->outputs[0], new_node_bias_var}, {}, {});
+    result =
+        CreateBaseOp(graph, node, "popart_mul",
+                     {new_node_add->outputs[0], new_node_scale_var}, {}, {});
+  }
+  auto result_after_cast =
+      CreateCast(graph, node, result->outputs, node->outputs,
+                 static_cast<int>(data_type_));
+  return result_after_cast;
+}
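+
+// cross_entropy2 receives already-softmaxed probabilities X plus an integer
+// label tensor. PopART's nllloss expects an INT32 label vector, so the label
+// is cast first; a label of shape [N, 1] is also flattened to [N] before the
+// loss, and the result is reshaped back to the original label shape.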
+Node *cross_entropy2_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
+  auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)}, {},
+                             framework::proto::VarType::INT32);
+  auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
+  if (label_shape_.size() == 1) {
+    return CreateBaseOp(graph, node, "popart_nllloss",
+                        {GetInputVarNode("X", node), new_cast->outputs[0]},
+                        {GetOutputVarNode("Y", node)},
+                        {
+                            {"ignoreIndex", ignoreIndex},
+                        });
+  } else {
+    std::vector<int64_t> new_shape_{label_shape_[0]};
+    auto const_before_loss = CreateBaseOp(
+        graph, node, "popart_constant", {}, {},
+        {{"value", new_shape_},
+         {"dims",
+          std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
+         {"dtype", ONNXDataType::INT64}});
+
+    auto reshape_before_loss = CreateBaseOp(
+        graph, node, "popart_reshape",
+        {new_cast->outputs[0], const_before_loss->outputs[0]}, {}, {});
+
+    auto nllloss = CreateBaseOp(
+        graph, node, "popart_nllloss",
+        {GetInputVarNode("X", node), reshape_before_loss->outputs[0]}, {},
+        {
+            {"ignoreIndex", ignoreIndex},
+        });
+
+    auto const_after_loss = CreateBaseOp(
+        graph, node, "popart_constant", {}, {},
+        {{"value", label_shape_},
+         {"dims",
+          std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
+         {"dtype", ONNXDataType::INT64}});
+
+    auto reshape_after_loss =
+        CreateBaseOp(graph, node, "popart_reshape",
+                     {nllloss->outputs[0], const_after_loss->outputs[0]},
+                     {GetOutputVarNode("Y", node)}, {});
+    return reshape_after_loss;
+  }
+}
+
+REGISTER_HANDLER(mean, mean_handler);
+REGISTER_HANDLER(pow, pow_handler);
+REGISTER_HANDLER(mul, mul_handler);
+REGISTER_HANDLER(matmul, matmul_handler);
+REGISTER_HANDLER(sum, sum_handler);
+REGISTER_HANDLER(softmax, softmax_handler);
+REGISTER_HANDLER(scale, scale_handler);
+REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
+
+}  // namespace
+}  // namespace ipu
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
new file mode 100644
index 0000000000000000000000000000000000000000..58f3e42b7387a778a553cc1187f912b1138d59b3
--- /dev/null
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc
@@ -0,0 +1,301 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
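+
+// Handlers that canonicalize Paddle NN ops (conv2d, pool2d, the *_norm
+// family and dropout) into PopART ops, translating Paddle attribute
+// spellings into their ONNX equivalents along the way.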
+
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+namespace {
+
+Node *conv2d_handler(Graph *graph, Node *node) {
+  OpDesc *op = node->Op();
+  auto dilations_ =
+      BOOST_GET_CONST(std::vector<int>, op->GetAttr("dilations"));
+  auto dilations = std::vector<int64_t>{dilations_.begin(), dilations_.end()};
+  auto group_ = BOOST_GET_CONST(int, op->GetAttr("groups"));
+  auto pads_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("paddings"));
+  if (pads_.size() == 2) {
+    // expand [pad_h, pad_w] to the 4-element onnx form
+    pads_.push_back(pads_[0]);
+    pads_.push_back(pads_[1]);
+  }
+  auto pads = std::vector<int64_t>{pads_.begin(), pads_.end()};
+  auto stride_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
+  auto stride = std::vector<int64_t>{stride_.begin(), stride_.end()};
+  if (op->HasInput("Bias") && !op->Input("Bias").empty()) {
+    return CreateConv(graph, node,
+                      {
+                          GetInputVarNode("Input", node),
+                          GetInputVarNode("Filter", node),
+                          GetInputVarNode("Bias", node),
+                      },
+                      node->outputs, dilations, group_, {}, pads, stride);
+  } else {
+    return CreateConv(graph, node,
+                      {
+                          GetInputVarNode("Input", node),
+                          GetInputVarNode("Filter", node),
+                      },
+                      node->outputs, dilations, group_, {}, pads, stride);
+  }
+}
+
+Node *batch_norm_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  std::vector<Node *> inputs;
+  inputs.push_back(GetInputVarNode("X", node));
+  inputs.push_back(GetInputVarNode("Scale", node));
+  inputs.push_back(GetInputVarNode("Bias", node));
+  inputs.push_back(GetInputVarNode("Mean", node));
+  inputs.push_back(GetInputVarNode("Variance", node));
+  int64_t num_outputs = 1;
+  std::vector<Node *> outputs;
+  // `is_test` may be stored as either an int or a bool attribute
+  auto is_test_type = op->GetAttrType("is_test");
+  bool is_test;
+  if (is_test_type == 0) {
+    // int
+    is_test = BOOST_GET_CONST(int, op->GetAttr("is_test"));
+  } else {
+    // bool
+    is_test = BOOST_GET_CONST(bool, op->GetAttr("is_test"));
+  }
+  outputs.push_back(GetOutputVarNode("Y", node));
+  if (!is_test) {
+    outputs.push_back(GetOutputVarNode("MeanOut", node));
+    outputs.push_back(GetOutputVarNode("VarianceOut", node));
+    outputs.push_back(GetOutputVarNode("SavedMean", node));
+    outputs.push_back(GetOutputVarNode("SavedVariance", node));
+    num_outputs = 5;
+  }
+  // the ReserveSpace output and the data_layout attribute are not handled:
+  // outputs.push_back(GetOutputVarNode("ReserveSpace", node));
+  auto momentum = BOOST_GET_CONST(float, op->GetAttr("momentum"));
+  auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+  return CreateBaseOp(graph, node, "popart_batchnormalization", inputs,
+                      outputs,
+                      {
+                          {"momentum", momentum},
+                          {"epsilon", epsilon},
+                          {"num_outputs", num_outputs},
+                      });
+}
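+
+// pool2d lowers to popart_globalmaxpool / popart_globalaveragepool when
+// global_pooling is set, and otherwise to popart_maxpool / popart_averagepool
+// with explicit kernel_shape, pads and strides. Only the EXPLICIT padding
+// algorithm is supported; SAME/VALID would require recomputing pads here.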
"EXPLICIT") { + PADDLE_THROW(platform::errors::InvalidArgument( + "op pool2d with unkonwn padding_algorithm: %s", padding_algorithm)); + } + } + + auto ksize = BOOST_GET_CONST(std::vector, op->GetAttr("ksize")); + auto kernel_shape = std::vector{ksize.begin(), ksize.end()}; + auto ceil_mode_ = BOOST_GET_CONST(bool, op->GetAttr("ceil_mode")); + auto ceil_mode = int64_t(ceil_mode_ ? 1 : 0); + auto paddings = BOOST_GET_CONST(std::vector, op->GetAttr("paddings")); + auto pads = std::vector{paddings.begin(), paddings.end()}; + if (pads.size() == 2) { + pads.push_back(paddings[0]); + pads.push_back(paddings[1]); + } + auto strides_ = BOOST_GET_CONST(std::vector, op->GetAttr("strides")); + auto strides = std::vector{strides_.begin(), strides_.end()}; + if (pooling_type == "max") { + int64_t num_outputs = 1; + auto dilations = std::vector{}; + int64_t storage_order = 0; + return CreateBaseOp(graph, node, "popart_maxpool", node->inputs, + node->outputs, { + {"num_outputs", num_outputs}, + {"kernel_shape", kernel_shape}, + {"ceil_mode", ceil_mode}, + {"dilations", dilations}, + {"pads", pads}, + {"storage_order", storage_order}, + {"strides", strides}, + }); + } else if (pooling_type == "avg") { + int64_t count_include_pad = 0; + return CreateBaseOp(graph, node, "popart_averagepool", node->inputs, + node->outputs, + { + {"kernel_shape", kernel_shape}, + {"ceil_mode", ceil_mode}, + {"count_include_pad", count_include_pad}, + {"pads", pads}, + {"strides", strides}, + }); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op pool2d with unkonwn pooling_type: %s", pooling_type)); + } +} + +Node *group_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon")); + auto groups_ = BOOST_GET_CONST(int, op->GetAttr("groups")); + auto groups = int64_t{groups_}; + auto attrs_ = AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups}}; + + std::vector inputs_ = {GetInputVarNode("X", node), + GetInputVarNode("Scale", node), + GetInputVarNode("Bias", node)}; + std::vector outputs_ = {GetOutputVarNode("Y", node), + GetOutputVarNode("Mean", node), + GetOutputVarNode("Variance", node)}; + return CreateBaseOp(graph, node, "popart_groupnormalization_v2", inputs_, + outputs_, attrs_); +} + +Node *instance_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon")); + auto attrs_ = AttributeMap{{"epsilon", epsilon_}}; + + std::vector inputs_ = {GetInputVarNode("X", node), + GetInputVarNode("Scale", node), + GetInputVarNode("Bias", node)}; + std::vector outputs_ = {GetOutputVarNode("Y", node)}; + return CreateBaseOp(graph, node, "popart_instancenormalization", inputs_, + outputs_, attrs_); +} + +Node *layer_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto begin_norm_axis_ = BOOST_GET_CONST(int, op->GetAttr("begin_norm_axis")); + auto input_shape_ = GetInputVarNode("X", node)->Var()->GetShape(); + + std::vector norm_shape_{1, 1}; + for (int i = 0; i < input_shape_.size(); i++) { + if (i < begin_norm_axis_) { + norm_shape_[0] *= input_shape_[i]; + } else { + norm_shape_[1] *= input_shape_[i]; + } + } + + auto attrs1 = AttributeMap{ + {"value", norm_shape_}, + {"dims", std::vector{static_cast(norm_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}; + auto reshape1_const = + CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs1); + auto new_node_reshape1 = CreateBaseOp( + graph, node, "popart_reshape", + {GetInputVarNode("X", node), 
+Node *layer_norm_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto begin_norm_axis_ = BOOST_GET_CONST(int, op->GetAttr("begin_norm_axis"));
+  auto input_shape_ = GetInputVarNode("X", node)->Var()->GetShape();
+
+  std::vector<int64_t> norm_shape_{1, 1};
+  for (int i = 0; i < static_cast<int>(input_shape_.size()); i++) {
+    if (i < begin_norm_axis_) {
+      norm_shape_[0] *= input_shape_[i];
+    } else {
+      norm_shape_[1] *= input_shape_[i];
+    }
+  }
+
+  auto attrs1 = AttributeMap{
+      {"value", norm_shape_},
+      {"dims", std::vector<int64_t>{static_cast<int64_t>(norm_shape_.size())}},
+      {"dtype", ONNXDataType::INT64}};
+  auto reshape1_const =
+      CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs1);
+  auto new_node_reshape1 = CreateBaseOp(
+      graph, node, "popart_reshape",
+      {GetInputVarNode("X", node), reshape1_const->outputs[0]}, {}, {});
+
+  auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
+  int64_t groups_ = 1;
+  auto groupnorm_attrs_ =
+      AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups_}};
+  auto out_Y_ = MakeVarNode(graph, node);
+  CreateBaseOp(graph, node, "popart_groupnormalization_v2",
+               {new_node_reshape1->outputs[0], GetInputVarNode("Scale", node),
+                GetInputVarNode("Bias", node)},
+               {out_Y_, GetOutputVarNode("Mean", node),
+                GetOutputVarNode("Variance", node)},
+               groupnorm_attrs_);
+
+  auto attrs2 = AttributeMap{
+      {"value", input_shape_},
+      {"dims",
+       std::vector<int64_t>{static_cast<int64_t>(input_shape_.size())}},
+      {"dtype", ONNXDataType::INT64}};
+  auto reshape2_const =
+      CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs2);
+  auto new_node_reshape2 = CreateBaseOp(graph, node, "popart_reshape",
+                                        {out_Y_, reshape2_const->outputs[0]},
+                                        {GetOutputVarNode("Y", node)}, {});
+  return new_node_reshape2;
+}
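+
+// Paddle's dropout has two implementations. "upscale_in_train" scales
+// activations by 1/(1 - p) during training, so inference is an identity;
+// "downgrade_in_infer" leaves activations untouched during training and
+// multiplies by (1 - p) at inference, so only its inference path can be
+// expressed here.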
+Node *dropout_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto dropout_prob_ = BOOST_GET_CONST(float, op->GetAttr("dropout_prob"));
+  auto dropout_implementation_ =
+      BOOST_GET_CONST(std::string, op->GetAttr("dropout_implementation"));
+  // `is_test` may be stored as either an int or a bool attribute
+  auto is_test_type_ = op->GetAttrType("is_test");
+  bool is_test_;
+  if (is_test_type_ == 0) {
+    // int
+    is_test_ = BOOST_GET_CONST(int, op->GetAttr("is_test"));
+  } else {
+    // bool
+    is_test_ = BOOST_GET_CONST(bool, op->GetAttr("is_test"));
+  }
+
+  if (is_test_) {
+    if (dropout_implementation_ == "upscale_in_train") {
+      return CreateBaseOp(graph, node, "popart_identity",
+                          {GetInputVarNode("X", node)},
+                          {GetOutputVarNode("Out", node)}, {});
+    } else if (dropout_implementation_ == "downgrade_in_infer") {
+      auto scale =
+          CreateConst(graph, node, {}, {},
+                      {{"value", std::vector<float>{1 - dropout_prob_}},
+                       {"dims", std::vector<int64_t>{1}},
+                       {"dtype", ONNXDataType::FLOAT}});
+      return CreateBaseOp(graph, node, "popart_mul",
+                          {GetInputVarNode("X", node), scale->outputs[0]},
+                          {GetOutputVarNode("Out", node)}, {});
+    } else {
+      PADDLE_THROW(
+          platform::errors::InvalidArgument("Invalid dropout_implementation"));
+    }
+  } else {
+    if (dropout_implementation_ == "upscale_in_train") {
+      auto attrs_ =
+          AttributeMap{{"num_outputs", (int64_t)1}, {"ratio", dropout_prob_}};
+      return CreateBaseOp(graph, node, "popart_dropout",
+                          {GetInputVarNode("X", node)},
+                          {GetOutputVarNode("Out", node)}, attrs_);
+    } else if (dropout_implementation_ == "downgrade_in_infer") {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Do not support downgrade_in_infer with training"));
+    } else {
+      PADDLE_THROW(
+          platform::errors::InvalidArgument("Invalid dropout_implementation"));
+    }
+  }
+}
+
+REGISTER_HANDLER(pool2d, pool2d_handler);
+REGISTER_HANDLER(batch_norm, batch_norm_handler);
+REGISTER_HANDLER(group_norm, group_norm_handler);
+REGISTER_HANDLER(instance_norm, instance_norm_handler);
+REGISTER_HANDLER(layer_norm, layer_norm_handler);
+REGISTER_HANDLER(conv2d, conv2d_handler);
+REGISTER_HANDLER(dropout, dropout_handler);
+
+}  // namespace
+}  // namespace ipu
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc
new file mode 100644
index 0000000000000000000000000000000000000000..b7a3a8ca7c60f59f45dec5a1630fe3fd48a6f868
--- /dev/null
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc
@@ -0,0 +1,195 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
+
+#include <memory>
+#include <numeric>
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+
+// file-local counters used to generate unique var/op names
+static int var_count = 0;
+static int op_count = 0;
+
+const std::string GenerateVarName() {
+  return std::string("_gen_var_") + std::to_string(var_count++);
+}
+
+const std::string GenerateOpName() {
+  return std::string("_gen_op_") + std::to_string(op_count++);
+}
+
+const std::string CreateOpIdentifyId(Node *node) {
+  // format: op_type|out_var0|out_var1|...|_gen_*
+  // this name is used as the op name when exporting an onnx model from popart
+  auto op_type = node->Name();
+  std::string op_out = "";
+  for (auto *out_node : node->outputs) {
+    op_out += "|";
+    op_out += out_node->Name();
+  }
+  return {op_type + op_out + "|" + GenerateOpName()};
+}
+
+Node *MakeVarNode(Graph *graph, Node *node) {
+  auto var_name = GenerateVarName();
+  auto var_desc = std::make_unique<framework::VarDesc>(var_name);
+
+  auto var = graph->CreateVarNode(var_desc.get());
+  return var;
+}
+
+Node *MakeOpNode(Graph *graph, Node *node, const std::string &type,
+                 const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs) {
+  auto op_desc = std::make_unique<framework::OpDesc>();
+  op_desc->SetType(type);
+  auto op = graph->CreateOpNode(op_desc.get());
+
+  for (auto *in : inputs) {
+    ConnectNodes(in, op);
+  }
+  if (outputs.empty()) {
+    auto var = MakeVarNode(graph, node);
+    ConnectNodes(op, var);
+  } else {
+    for (auto *out : outputs) {
+      ConnectNodes(op, out);
+    }
+  }
+
+  // i/o
+  std::vector<std::string> input_names;
+  for (auto *in_node : op->inputs) {
+    input_names.push_back(in_node->Name());
+  }
+  op->Op()->SetInput("__inputs__", input_names);
+  std::vector<std::string> output_names;
+  for (auto *out_node : op->outputs) {
+    output_names.push_back(out_node->Name());
+  }
+  op->Op()->SetOutput("__outputs__", output_names);
+  op->Op()->Flush();
+
+  return op;
+}
+
+Node *CreateBaseOp(Graph *graph, Node *node, const std::string &type,
+                   const std::vector<Node *> &inputs,
+                   const std::vector<Node *> &outputs,
+                   const AttributeMap &attrs) {
+  auto new_node = MakeOpNode(graph, node, type, inputs, outputs);
+  if (!attrs.empty()) {
+    new_node->Op()->SetAttrMap(attrs);
+  }
+  // carry over the ipu-specific scheduling attributes
+  if (!new_node->Op()->HasAttr(sIpuIndexAttr)) {
+    CopyOpAttr(sIpuIndexAttr, node->Op(), new_node->Op());
+  }
+  if (!new_node->Op()->HasAttr(sIpuStageAttr)) {
+    CopyOpAttr(sIpuStageAttr, node->Op(), new_node->Op());
+  }
+  {
+    new_node->Op()->SetAttr(sOpIdentifyIdAttr, CreateOpIdentifyId(node));
+    new_node->Op()->Flush();
+  }
+
+  return new_node;
+}
+
+Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                  const std::vector<Node *> &outputs,
+                  const AttributeMap &attrs) {
+  return CreateBaseOp(graph, node, "popart_constant", inputs, outputs, attrs);
+}
+
+Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs, const int otype) {
+  auto to = VarType2PopStr(otype);
+  return CreateBaseOp(graph, node, "popart_cast", inputs, outputs,
+                      {{"to", to}});
+}
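+
+// Thin wrapper over popart_gemm, which follows ONNX Gemm semantics:
+// Y = alpha * op(A) * op(B) + beta * C, where transA/transB choose whether
+// op() transposes its operand. matmul_handler uses this for rank-2 inputs.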
+Node *CreateGemm(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs, int64_t transA,
+                 int64_t transB, float alpha, float beta) {
+  return CreateBaseOp(graph, node, "popart_gemm", inputs, outputs,
+                      {
+                          {"alpha", alpha},
+                          {"beta", beta},
+                          {"transA", transA},
+                          {"transB", transB},
+                      });
+}
+
+Node *CreateReshape(Graph *graph, Node *node,
+                    const std::vector<Node *> &inputs,
+                    const std::vector<Node *> &outputs,
+                    const std::vector<int64_t> &oshape) {
+  auto attr = AttributeMap{
+      {"value", oshape},
+      {"dims", std::vector<int64_t>{static_cast<int64_t>(oshape.size())}},
+      {"dtype", ONNXDataType::INT64}};
+  auto new_node_const =
+      CreateBaseOp(graph, node, "popart_constant", {}, {}, attr);
+  auto new_node_reshape =
+      CreateBaseOp(graph, node, "popart_reshape",
+                   {inputs[0], new_node_const->outputs[0]}, outputs);
+  return new_node_reshape;
+}
+
+Node *CreateConv(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs,
+                 const std::vector<int64_t> &dilations, int64_t group,
+                 const std::vector<int64_t> &kernel_shape,
+                 const std::vector<int64_t> &pads,
+                 const std::vector<int64_t> &strides) {
+  auto attrs = AttributeMap{
+      {"dilations", dilations},       {"group", group},
+      {"kernel_shape", kernel_shape}, {"pads", pads},
+      {"strides", strides},
+  };
+  return CreateBaseOp(graph, node, "popart_conv", inputs, outputs, attrs);
+}
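+
+// ONNX softmax (opset 11) normalizes over the flattened trailing dims from
+// `axis`, so softmax over an inner axis is implemented by swapping that axis
+// with the last one, applying softmax with axis = -1, and swapping back. The
+// same `perm` works for both transposes because a two-element swap is its
+// own inverse.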
+Node *CreateSoftmaxOpset11(Graph *graph, Node *node,
+                           const std::vector<Node *> &inputs,
+                           const std::vector<Node *> &outputs, int64_t axis) {
+  PADDLE_ENFORCE_EQ(inputs.size(), 1,
+                    platform::errors::InvalidArgument(
+                        "The softmax op only supports a single input"));
+  auto x_shape = inputs[0]->Var()->GetShape();
+  int x_rank = x_shape.size();
+  if (axis < 0) {
+    axis = axis + x_rank;
+  }
+  if (axis == x_rank - 1) {
+    return CreateBaseOp(graph, node, "popart_softmax", inputs, outputs,
+                        {{"axis", int64_t{-1}}});
+  } else {
+    auto perm = std::vector<int64_t>(x_rank);
+    std::iota(perm.begin(), perm.end(), 0);
+    perm[x_rank - 1] = axis;
+    perm[axis] = x_rank - 1;
+    auto new_transpose_pre = CreateBaseOp(graph, node, "popart_transpose",
+                                          inputs, {}, {{"perm", perm}});
+    auto new_softmax =
+        CreateBaseOp(graph, node, "popart_softmax",
+                     new_transpose_pre->outputs, {}, {{"axis", int64_t{-1}}});
+    return CreateBaseOp(graph, node, "popart_transpose",
+                        new_softmax->outputs, outputs, {{"perm", perm}});
+  }
+}
+
+}  // namespace ipu
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h
new file mode 100644
index 0000000000000000000000000000000000000000..7e70e56ef9166cc50a5018b0984994391856ab02
--- /dev/null
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h
@@ -0,0 +1,85 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/platform/device/ipu/common.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+
+using paddle::framework::AttributeMap;
+
+template <typename T>
+AttributeMap MakeConstAttrMap(std::vector<T> value, std::vector<int64_t> dims,
+                              int dtype) {
+  return AttributeMap{{"value", value}, {"dims", dims}, {"dtype", dtype}};
+}
+
+template <typename T>
+AttributeMap MakeConstAttrMapFromValue(T v, std::vector<int64_t> dims,
+                                       int dtype) {
+  size_t size = 1;
+  for (auto &dim : dims) {
+    size *= dim;
+  }
+  return MakeConstAttrMap<T>(std::vector<T>(size, v), dims, dtype);
+}
+
+const std::string GenerateVarName();
+const std::string CreateOpIdentifyId(Node *node);
+
+Node *MakeVarNode(Graph *graph, Node *node);
+Node *MakeOpNode(Graph *graph, Node *node, const std::string &type,
+                 const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs);
+
+Node *CreateBaseOp(Graph *graph, Node *node, const std::string &type,
+                   const std::vector<Node *> &inputs,
+                   const std::vector<Node *> &outputs,
+                   const AttributeMap &attrs = {});
+
+Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                  const std::vector<Node *> &outputs,
+                  const AttributeMap &attrs);
+
+// otype is proto::VarType::Type
+Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs, const int otype);
+
+Node *CreateGemm(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs, int64_t transA = 0,
+                 int64_t transB = 0, float alpha = 1.0f, float beta = 1.0f);
+
+Node *CreateReshape(Graph *graph, Node *node,
+                    const std::vector<Node *> &inputs,
+                    const std::vector<Node *> &outputs,
+                    const std::vector<int64_t> &oshape);
+
+Node *CreateConv(Graph *graph, Node *node, const std::vector<Node *> &inputs,
+                 const std::vector<Node *> &outputs,
+                 const std::vector<int64_t> &dilations = {1, 1},
+                 int64_t group = 1,
+                 const std::vector<int64_t> &kernel_shape = {},
+                 const std::vector<int64_t> &pads = {0, 0, 0, 0},
+                 const std::vector<int64_t> &strides = {1, 1});
+
+Node *CreateSoftmaxOpset11(Graph *graph, Node *node,
+                           const std::vector<Node *> &inputs,
+                           const std::vector<Node *> &outputs, int64_t axis);
+
+}  // namespace ipu
+}  // namespace platform
+}  // namespace paddle
diff --git a/paddle/fluid/platform/device/ipu/supported_ops_autogen.h b/paddle/fluid/platform/device/ipu/supported_ops_autogen.h
index 4cd7f928f6e22b2aa483c4448f73cd792cd154c6..763c5a46abe287e76a3d7fe32ce9d4c8c0cd2c62 100644
--- a/paddle/fluid/platform/device/ipu/supported_ops_autogen.h
+++ b/paddle/fluid/platform/device/ipu/supported_ops_autogen.h
@@ -195,3 +195,5 @@ OP_DECL(popart_sqrt, aiOnnxOpset.sqrt, NONE) // NOLINT
 OP_DECL(popart_tanh, aiOnnxOpset.tanh, NONE) // NOLINT
 OP_DECL(popart_tile, aiOnnxOpset.tile, NONE) // NOLINT
 OP_DECL(popart_transpose, aiOnnxOpset.transpose, ARG(INT_VEC,perm) ) // NOLINT
+
+// clang-format on