Unverified commit 050fd168, authored by Allen Guo, committed by GitHub

[IPU] add more ops (#38831)

* support more ops

* Co-authored-by: Xiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: Allen Guo <alleng@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Haicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: Han Zhao <hanzhao@graphcore.ai>

* add authors
Co-authored-by: Xiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: Allen Guo <alleng@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Haicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: Han Zhao <hanzhao@graphcore.ai>

* update date
Co-authored-by: Xiaobing Wang <xiaobingw@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Haicheng Jiang <haichengj@graphcore.ai>
Co-authored-by: Han Zhao <hanzhao@graphcore.ai>
Parent 50609214
@@ -16,7 +16,6 @@
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
- #include "paddle/fluid/platform/device/ipu/popart_canonicalization/post_canonicalization.h"
namespace paddle {
namespace framework {
......
@@ -48,7 +48,37 @@ Node *sqrt_handler(Graph *graph, Node *node) {
}

Node *gelu_handler(Graph *graph, Node *node) {
- return activation_op_handler(graph, node, "popart_gelu_v2");
auto *op = node->Op();
auto approximate_ = BOOST_GET_CONST(bool, op->GetAttr("approximate"));
if (approximate_) {
return activation_op_handler(graph, node, "popart_gelu_v2");
} else {
auto sqrt2 = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{1.4142135623730951}},
{"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}});
auto zero_point_five =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0.5}},
{"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}});
auto one =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{1}},
{"dims", std::vector<int64_t>{1}},
{"dtype", GetOutputVarDtype(node)}});
auto div =
CreateBaseOp(graph, node, "popart_div",
{GetInputVarNode("X", node), sqrt2->outputs[0]}, {}, {});
auto erf =
CreateBaseOp(graph, node, "popart_erf", {div->outputs[0]}, {}, {});
auto add = CreateBaseOp(graph, node, "popart_add",
{erf->outputs[0], one->outputs[0]}, {}, {});
auto mul1 =
CreateBaseOp(graph, node, "popart_mul",
{GetInputVarNode("X", node), add->outputs[0]}, {}, {});
return CreateBaseOp(graph, node, "popart_mul",
{mul1->outputs[0], zero_point_five->outputs[0]},
{GetOutputVarNode("Out", node)}, {});
}
}
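For reference, the non-approximate branch above stitches popart primitives (div, erf, add, mul) into the exact, erf-based GELU definition:

gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))

where sqrt(2) is the 1.4142135623730951 constant created at the top of the branch; the approximate (tanh) form is delegated to popart_gelu_v2 instead.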
Node *log_softmax_handler(Graph *graph, Node *node) {
......
@@ -180,6 +180,17 @@ const bool is_float_equal(float a, float b, float eps) {
return std::fabs(a - b) <= eps;
}
const int GetOutputVarDtype(const Node *node, const std::string &output_name) {
auto out_node = GetOutputVarNode(output_name, node);
PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
"Node's out node does not exist."));
auto var = out_node->Var();
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::Unavailable("Node is not a variable."));
auto proto_var_type = var->GetDataType();
return VarType2OnnxDtype(proto_var_type);
}
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -23,10 +23,6 @@ namespace paddle {
namespace platform {
namespace ipu {

- using framework::ir::Graph;
- using framework::ir::Node;
- using framework::OpDesc;

#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::platform::ipu::RegisterHandler(#name, func)
@@ -58,6 +54,8 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name,
const Node *op_node);

const bool is_float_equal(float a, float b, float eps = 1e-8);
const int GetOutputVarDtype(const Node *node,
const std::string &output_name = "Out");
} // namespace ipu
} // namespace platform
......
@@ -28,7 +28,21 @@ Node *equal_handler(Graph *graph, Node *node) {
return new_node;
}
Node *logical_not_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_logical_not",
{GetInputVarNode("X", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *greater_than_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_greater",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{GetOutputVarNode("Out", node)}, {});
}
REGISTER_HANDLER(equal, equal_handler);
REGISTER_HANDLER(logical_not, logical_not_handler);
REGISTER_HANDLER(greater_than, greater_than_handler);
} // namespace
} // namespace ipu
......
@@ -41,7 +41,8 @@ Node *pow_handler(Graph *graph, Node *node) {
// Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
auto attrs =
- MakeConstAttrMapFromValue<float>(value_, {1}, ONNXDataType::FLOAT);
MakeConstAttrMapFromValue<float>(value_, {1}, GetOutputVarDtype(node));
auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
return CreateBaseOp(graph, node, "popart_pow", {GetInputVarNode("X", node),
new_node_const->outputs[0]},
@@ -122,16 +123,16 @@ Node *matmul_handler(Graph *graph, Node *node) {
y_node = y_node->outputs[0];
}
if (is_float_equal(alpha, 1.0)) {
return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node},
node->outputs);
} else {
auto o_node =
CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {});
- auto attr = MakeConstAttrMapFromValue(alpha, {1}, ONNXDataType::FLOAT);
auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDtype(node));
auto const_node = CreateConst(graph, node, {}, {}, attr);
return CreateBaseOp(graph, node, "popart_mul",
{o_node->outputs[0], const_node->outputs[0]},
node->outputs);
}
}
@@ -141,7 +142,10 @@ Node *sum_handler(Graph *graph, Node *node) {
Node *softmax_handler(Graph *graph, Node *node) {
auto *op = node->Op();
- auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
int axis = -1;
if (op->HasAttr("axis")) {
axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
}
return CreateSoftmaxOpset11(graph, node, node->inputs, node->outputs, axis);
}
@@ -153,42 +157,72 @@ Node *scale_handler(Graph *graph, Node *node) {
BOOST_GET_CONST(bool, op->GetAttr("bias_after_scale"));
auto data_type_ = GetInputVarNode("X", node)->Var()->GetDataType();
- auto new_node_bias_var =
- CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{bias_}},
- {"dims", std::vector<int64_t>{1}},
- {"dtype", ONNXDataType::FLOAT}});
- new_node_bias_var = new_node_bias_var->outputs[0];
- Node *new_node_scale_var = nullptr;
- if (op->HasInput("ScaleTensor") && !op->Input("ScaleTensor").empty()) {
- new_node_scale_var = GetInputVarNode("ScaleTensor", node);
- } else {
- new_node_scale_var =
- CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{scale_}},
- {"dims", std::vector<int64_t>{1}},
- {"dtype", ONNXDataType::FLOAT}});
- new_node_scale_var = new_node_scale_var->outputs[0];
- }
- // convert to float32
- auto new_node_cast =
- CreateCast(graph, node, {GetInputVarNode("X", node)}, {},
- static_cast<int>(framework::proto::VarType::FP32));
- if (bias_after_scale_) {
- auto new_node_mul =
- CreateBaseOp(graph, node, "popart_mul",
- {new_node_cast->outputs[0], new_node_scale_var}, {}, {});
- result =
- CreateBaseOp(graph, node, "popart_add",
- {new_node_mul->outputs[0], new_node_bias_var}, {}, {});
- } else {
- auto new_node_add =
- CreateBaseOp(graph, node, "popart_add",
- {new_node_cast->outputs[0], new_node_bias_var}, {}, {});
- result =
- CreateBaseOp(graph, node, "popart_mul",
- {new_node_add->outputs[0], new_node_scale_var}, {}, {});
- }
auto cast = CreateCast(graph, node, {GetInputVarNode("X", node)}, {},
static_cast<int>(framework::proto::VarType::FP32));
Node *result = nullptr;
if (op->HasInput("ScaleTensor") && !op->Input("ScaleTensor").empty()) {
auto scale = GetInputVarNode("ScaleTensor", node);
if (is_float_equal(bias_, 0.0)) {
result = CreateBaseOp(graph, node, "popart_mul",
{cast->outputs[0], scale}, {}, {});
} else {
auto bias = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{bias_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}});
bias = bias->outputs[0];
if (bias_after_scale_) {
auto mul = CreateBaseOp(graph, node, "popart_mul",
{cast->outputs[0], scale}, {}, {});
result = CreateBaseOp(graph, node, "popart_add",
{mul->outputs[0], bias}, {}, {});
} else {
auto add = CreateBaseOp(graph, node, "popart_add",
{cast->outputs[0], bias}, {}, {});
result = CreateBaseOp(graph, node, "popart_mul",
{add->outputs[0], scale}, {}, {});
}
}
} else {
if (is_float_equal(bias_, 0.0) && is_float_equal(scale_, 1.0)) {
return CreateBaseOp(graph, node, "popart_identity",
{GetInputVarNode("X", node)}, node->outputs, {});
} else if (is_float_equal(scale_, 1.0)) {
auto bias = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{bias_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}});
result = CreateBaseOp(graph, node, "popart_add",
{cast->outputs[0], bias->outputs[0]}, {}, {});
} else if (is_float_equal(bias_, 0.0)) {
auto scale = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{scale_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}});
result = CreateBaseOp(graph, node, "popart_mul",
{cast->outputs[0], scale->outputs[0]}, {}, {});
} else {
auto bias = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{bias_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}});
auto scale = CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{scale_}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}});
if (bias_after_scale_) {
auto mul = CreateBaseOp(graph, node, "popart_mul",
{cast->outputs[0], scale->outputs[0]}, {}, {});
result = CreateBaseOp(graph, node, "popart_add",
{mul->outputs[0], bias->outputs[0]}, {}, {});
} else {
auto add = CreateBaseOp(graph, node, "popart_add",
{cast->outputs[0], bias->outputs[0]}, {}, {});
result = CreateBaseOp(graph, node, "popart_mul",
{add->outputs[0], scale->outputs[0]}, {}, {});
}
}
}
auto result_after_cast =
CreateCast(graph, node, result->outputs, node->outputs,
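For reference, the rewritten scale handler computes the two Paddle scale formulas on a float32 copy of X (casting back to the original dtype afterwards), with the identity/add/mul shortcuts above taken when scale_ == 1.0 or bias_ == 0.0:

out = scale * x + bias      (bias_after_scale = true)
out = scale * (x + bias)    (bias_after_scale = false)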
...@@ -199,16 +233,27 @@ Node *scale_handler(Graph *graph, Node *node) { ...@@ -199,16 +233,27 @@ Node *scale_handler(Graph *graph, Node *node) {
Node *cross_entropy2_handler(Graph *graph, Node *node) { Node *cross_entropy2_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index")); auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)}, {}, Node *new_cast = nullptr;
framework::proto::VarType::INT32); if (GetInputVarNode("Label", node)->Var()->GetDataType() ==
framework::proto::VarType::INT32) {
new_cast = GetInputVarNode("Label", node);
} else {
auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)},
{}, framework::proto::VarType::INT32);
new_cast = new_cast->outputs[0];
}
auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
- if (label_shape_.size() == 1) {
- return CreateBaseOp(graph, node, "popart_nllloss",
- {GetInputVarNode("X", node), new_cast->outputs[0]},
- {GetOutputVarNode("Y", node)},
- {
- {"ignoreIndex", ignoreIndex},
- });
if (label_shape_[label_shape_.size() - 1] != 1) {
auto log = CreateBaseOp(graph, node, "popart_log",
{GetInputVarNode("X", node)}, {}, {});
return CreateBaseOp(
graph, node, "popart_nllloss_v2", {log->outputs[0], new_cast},
{GetOutputVarNode("Y", node)},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
} else {
std::vector<int64_t> new_shape_{label_shape_[0]};
auto const_before_loss = CreateBaseOp(
@@ -218,15 +263,19 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
- auto reshape_before_loss = CreateBaseOp(
- graph, node, "popart_reshape",
- {new_cast->outputs[0], const_before_loss->outputs[0]}, {}, {});
auto reshape_before_loss =
CreateBaseOp(graph, node, "popart_reshape",
{new_cast, const_before_loss->outputs[0]}, {}, {});
auto log = CreateBaseOp(graph, node, "popart_log",
{GetInputVarNode("X", node)}, {}, {});
auto nllloss = CreateBaseOp(
- graph, node, "popart_nllloss",
- {GetInputVarNode("X", node), reshape_before_loss->outputs[0]}, {},
graph, node, "popart_nllloss_v2",
{log->outputs[0], reshape_before_loss->outputs[0]}, {},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
auto const_after_loss = CreateBaseOp(
@@ -244,6 +293,73 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
}
}
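In both branches the handler now decomposes cross_entropy2 into popart_log followed by popart_nllloss_v2 with reduction = NoReduction and inputIsLogProbability = true, i.e. for a row of probabilities p (the op's softmax-normalized input) and a label y:

loss = -log(p[y])    (terms with y == ignoreIndex are excluded)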
Node *cumsum_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive"));
int64_t popart_exclusive = exclusive ? 1 : 0;
auto reverse = BOOST_GET_CONST(bool, op->GetAttr("reverse"));
int64_t popart_reverse = reverse ? 1 : 0;
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto axis_node =
CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{axis}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
return CreateBaseOp(
graph, node, "popart_cumsum",
{GetInputVarNode("X", node), axis_node->outputs[0]},
{GetOutputVarNode("Out", node)},
{{"exclusive", popart_exclusive}, {"reverse", popart_reverse}});
}
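Assuming popart_cumsum follows standard ONNX CumSum semantics, a quick worked example for x = [1, 2, 3] along axis 0: the plain prefix sum is [1, 3, 6], exclusive = true gives [0, 1, 3], and reverse = true gives [6, 5, 3].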
Node *matmul_v2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("trans_x"));
auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("trans_y"));
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape();
std::vector<int64_t> perm;
int x_rank = x_shape.size();
if (x_rank == 1) {
perm = std::vector<int64_t>{0};
} else if (x_rank == 2) {
perm = std::vector<int64_t>{1, 0};
} else if (x_rank == 3) {
perm = std::vector<int64_t>{0, 2, 1};
} else if (x_rank == 4) {
perm = std::vector<int64_t>{0, 1, 3, 2};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"op matmul with input rank == %d", x_rank));
}
Node *x_node = GetInputVarNode("X", node);
Node *y_node = GetInputVarNode("Y", node);
if (transpose_x) {
x_node = CreateBaseOp(graph, node, "popart_transpose",
{GetInputVarNode("X", node)}, {}, {{"perm", perm}});
x_node = x_node->outputs[0];
}
if (transpose_y) {
y_node = CreateBaseOp(graph, node, "popart_transpose",
{GetInputVarNode("Y", node)}, {}, {{"perm", perm}});
y_node = y_node->outputs[0];
}
return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node},
node->outputs);
}
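Each perm above only swaps the last two axes for the given rank, so for a rank-3 operand the inserted popart_transpose maps a shape (B, M, K) to (B, K, M) before the popart_matmul; rank-1 inputs are passed through with perm = {0}.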
Node *arg_max_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = BOOST_GET_CONST(int64_t, op->GetAttr("axis"));
return CreateBaseOp(graph, node, "popart_argmax",
{GetInputVarNode("X", node)},
{GetOutputVarNode("Out", node)},
{{"axis", axis}, {"keepdims", int64_t{0}}});
}
REGISTER_HANDLER(mean, mean_handler);
REGISTER_HANDLER(pow, pow_handler);
REGISTER_HANDLER(mul, mul_handler);
@@ -252,6 +368,9 @@ REGISTER_HANDLER(sum, sum_handler);
REGISTER_HANDLER(softmax, softmax_handler);
REGISTER_HANDLER(scale, scale_handler);
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(cumsum, cumsum_handler);
REGISTER_HANDLER(matmul_v2, matmul_v2_handler);
REGISTER_HANDLER(arg_max, arg_max_handler);
} // namespace
} // namespace ipu
......
@@ -22,7 +22,7 @@ namespace ipu {
namespace {

Node *conv2d_handler(Graph *graph, Node *node) {
- OpDesc *op = node->Op();
auto *op = node->Op();
auto dilations_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dilations"));
auto dilations = std::vector<int64_t>{dilations_.begin(), dilations_.end()};
auto group_ = BOOST_GET_CONST(int, op->GetAttr("groups"));
@@ -193,6 +193,21 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto begin_norm_axis_ = BOOST_GET_CONST(int, op->GetAttr("begin_norm_axis"));
auto input_shape_ = GetInputVarNode("X", node)->Var()->GetShape();
auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
int64_t groups_ = 1;
auto groupnorm_attrs_ =
AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups_}};
if (input_shape_.size() == 2) {
return CreateBaseOp(
graph, node, "popart_groupnormalization_v2",
{GetInputVarNode("X", node), GetInputVarNode("Scale", node),
GetInputVarNode("Bias", node)},
{GetOutputVarNode("Y", node), GetOutputVarNode("Mean", node),
GetOutputVarNode("Variance", node)},
groupnorm_attrs_);
}
std::vector<int64_t> norm_shape_{1, 1};
for (int i = 0; i < input_shape_.size(); i++) {
@@ -213,10 +228,6 @@ Node *layer_norm_handler(Graph *graph, Node *node) {
graph, node, "popart_reshape",
{GetInputVarNode("X", node), reshape1_const->outputs[0]}, {}, {});
- auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
- int64_t groups_ = 1;
- auto groupnorm_attrs_ =
- AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups_}};
auto out_Y_ = MakeVarNode(graph, node);
CreateBaseOp(graph, node, "popart_groupnormalization_v2",
{new_node_reshape1->outputs[0], GetInputVarNode("Scale", node),
@@ -262,7 +273,7 @@ Node *dropout_handler(Graph *graph, Node *node) {
CreateConst(graph, node, {}, {},
{{"value", std::vector<float>{1 - dropout_prob_}},
{"dims", std::vector<int64_t>{1}},
- {"dtype", ONNXDataType::FLOAT}});
{"dtype", GetOutputVarDtype(node)}});
return CreateBaseOp(graph, node, "popart_mul",
{GetInputVarNode("X", node), scale->outputs[0]},
{GetOutputVarNode("Out", node)}, {});
......
@@ -31,15 +31,31 @@ const std::string GenerateOpName() {
}
const std::string CreateOpIdentifyId(Node *node) {
- // format: op_type|out_var0|out_var1|...|_gen_*
// format:
// if has custom op_namescope:
// {op_namescope}/op_type/_gen_*
// else:
// {op_type}/{out_var0}/{out_var1}/.../_gen_*
// this name will be used as op name when exporting onnx model from popart
auto op_type = node->Name();
- std::string op_out = "";
- for (auto *out_node : node->outputs) {
- op_out += "|";
- op_out += out_node->Name();
- }
- return {op_type + op_out + "|" + GenerateOpName()};
std::string op_namescope;
if (node->Op()->HasAttr("op_namescope")) {
op_namescope =
BOOST_GET_CONST(std::string, node->Op()->GetAttr("op_namescope"));
} else {
op_namescope = "/";
}

if (op_namescope != "/") {
return {op_namescope + op_type + "/" + GenerateOpName()};
} else {
std::string op_out = "";
for (auto *out_node : node->outputs) {
op_out += "/";
op_out += out_node->Name();
}
return {op_type + op_out + "/" + GenerateOpName()};
}
}
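As a purely illustrative example (the namescope and variable names below are hypothetical), the two formats produce identifiers such as:

op_namescope = "encoder/"      ->  "encoder/matmul/_gen_*"
op_namescope = "/" (default)   ->  "matmul/tmp_0/_gen_*"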
Node *MakeVarNode(Graph *graph, Node *node) {
@@ -100,6 +116,12 @@ Node *CreateBaseOp(Graph *graph, Node *node, const std::string &type,
if (!new_node->Op()->HasAttr(sIpuStageAttr)) {
CopyOpAttr(sIpuStageAttr, node->Op(), new_node->Op());
}
if (node->Op()->HasAttr(sMatmulSerializeFactor)) {
CopyOpAttr(sMatmulSerializeFactor, node->Op(), new_node->Op());
}
if (node->Op()->HasAttr(sMatmulSerializeMode)) {
CopyOpAttr(sMatmulSerializeMode, node->Op(), new_node->Op());
}
{
new_node->Op()->SetAttr(sOpIdentifyIdAttr, CreateOpIdentifyId(node));
new_node->Op()->Flush();
......
@@ -14,15 +14,16 @@
#pragma once

- #include "paddle/fluid/platform/device/ipu/common.h"
#include "paddle/fluid/platform/device/ipu/ipu_names.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"

- using paddle::framework::AttributeMap;
- using paddle::framework::Attribute;

namespace paddle {
namespace platform {
namespace ipu {
using paddle::framework::AttributeMap;
template <typename T>
AttributeMap MakeConstAttrMap(std::vector<T> value, std::vector<int64_t> dims,
int dtype) {
@@ -56,7 +57,7 @@ Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
const AttributeMap &attrs);

- // otype is proto::VarType::Type
// otype is framework::proto::VarType::Type
Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs, const int otype);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
namespace {
Node *custom_op_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto attrs = op->GetAttrMap();
attrs.insert({"__op_type", node->Op()->Type()});
auto new_node = CreateBaseOp(graph, node, "popart_custom_op", node->inputs,
node->outputs, attrs);
return new_node;
}
Node *print_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto print_phase = BOOST_GET_CONST(std::string, op->GetAttr("print_phase"));
int64_t print_gradient = 0;
if (print_phase != "forward") {
print_gradient = 1;
}
auto title = BOOST_GET_CONST(std::string, op->GetAttr("message"));
if (title.empty()) {
title = GetInputVarNode("In", node)->Var()->Name();
}
auto attrs =
AttributeMap{{"print_gradient", print_gradient}, {"title", title}};
return CreateBaseOp(graph, node, "popart_printtensor", node->inputs,
node->outputs, attrs);
}
Node *popart_optimizer_handler(Graph *graph, Node *node) { return nullptr; }
Node *checkpointoutput_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_checkpointoutput", node->inputs,
node->outputs);
}
REGISTER_HANDLER(custom_op, custom_op_handler);
REGISTER_HANDLER(print, print_handler);
REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler);
REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler);
} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
@@ -21,24 +21,24 @@ namespace platform {
namespace ipu {
namespace {

- Node *topK_op_handler(Graph *graph, Node *node) {
Node *topk_handler(Graph *graph, Node *node) {
VLOG(10) << "[topK_op_handler] entering to handler ...";
auto *op = node->Op();
auto attrs = AttributeMap{};
- int axis_32INT = -1;
int axis_ = -1;
if (op->HasAttr("axis")) {
- axis_32INT = BOOST_GET_CONST(int, op->GetAttr("axis"));
axis_ = BOOST_GET_CONST(int, op->GetAttr("axis"));
}
- if (axis_32INT == -1) {
if (axis_ == -1) {
auto shape = GetInputVarNode("X", node)->Var()->GetShape();
int rank = shape.size();
if (rank < 1) {
PADDLE_THROW(platform::errors::InvalidArgument(
"The dimension of the shape of topK input should be larger than 1"));
}
- axis_32INT = rank - 1;
axis_ = rank - 1;
}
- int64_t axis = int64_t{axis_32INT};
int64_t axis = int64_t{axis_};
attrs.emplace("axis", axis);
bool largest = true;
@@ -63,45 +63,31 @@ Node *topK_op_handler(Graph *graph, Node *node) {
attrs.emplace("sorted", 0);
}
- std::vector<paddle::framework::ir::Node *> inputs = node->inputs;
- if (node->inputs.size() == 2) {
- // Input X tensor and K const tensor
- VLOG(10) << "[topK_op_handler] get 2 input tensors.";
- inputs[0] = node->inputs[1]; // K_t
- VLOG(10) << "[topK_op_handler] input node(" << inputs[0]->Var()->Name()
- << ")";
- inputs[1] = node->inputs[0]; // X
- VLOG(10) << "[topK_op_handler] input node(" << inputs[1]->Var()->Name()
- << ")";
- } else if (node->inputs.size() == 1) {
- // Input X tensor with k integer
- VLOG(10) << "[topK_op_handler] get 1 input tensor.";
- int k_32INT = BOOST_GET_CONST(int, op->GetAttr("k"));
- int64_t k = int64_t{k_32INT};
- attrs.emplace("k", k);
- }
- // show output node dtype
- for (auto *o_node : node->outputs) {
- auto *var = o_node->Var();
- // see framework.pb.h
- // VarType_Type_INT64 = 3,
- // VarType_Type_FP32 = 5,
- auto dtype = var->GetDataType();
- if (dtype == 3) {
- // poplar does not support int64_t
- var->SetDataType(framework::proto::VarType::INT32);
- }
- std::string name = var->Name();
- VLOG(10) << "[topK_op_handler] output node(" << name
- << ") dtype : " << dtype;
- }
- VLOG(10) << "[topK_op_handler] leave the handler.";
- return CreateBaseOp(graph, node, "popart_topk", inputs,
- {node->outputs[1], node->outputs[0]}, attrs);
Node *var_x = GetInputVarNode("X", node);
Node *var_k = nullptr;
if (op->HasInput("K") && !op->Input("K").empty()) {
var_k = GetInputVarNode("K", node);
} else {
auto k = BOOST_GET_CONST(int, op->GetAttr("k"));
auto *op_k =
CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{k}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
var_k = op_k->outputs[0];
}

auto *var_i = MakeVarNode(graph, node);
CreateBaseOp(graph, node, "popart_topk", {var_x, var_k},
{GetOutputVarNode("Out", node), var_i},
{{"axis", int64_t{axis}},
{"largest", int64_t{largest}},
{"sorted", int64_t{sorted}}});
// poplar does not support int64_t indices, so cast them down to int32
return CreateCast(graph, node, {var_i}, {GetOutputVarNode("Indices", node)},
static_cast<int>(framework::proto::VarType::INT32));
}
- REGISTER_HANDLER(top_k, topK_op_handler);
- REGISTER_HANDLER(top_k_v2, topK_op_handler);
REGISTER_HANDLER(top_k, topk_handler);
REGISTER_HANDLER(top_k_v2, topk_handler);
} // namespace
} // namespace ipu
......
@@ -21,9 +21,6 @@ namespace platform {
namespace ipu {
namespace {

- using framework::Attribute;
- using framework::AttributeMap;
Node *fill_constant_handler(Graph *graph, Node *node) {
auto *op = node->Op();
if (op->HasInput("ShapeTensor") && !op->Input("ShapeTensor").empty()) {
@@ -133,6 +130,14 @@ Node *reshape_handler(Graph *graph, Node *node) {
return new_node_reshape;
}
Node *flatten2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
return CreateBaseOp(
graph, node, "popart_flatten", {GetInputVarNode("X", node)},
{GetOutputVarNode("Out", node)}, {{"axis", int64_t(axis)}});
}
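Assuming popart_flatten follows the ONNX Flatten convention, axis = 1 collapses all trailing dimensions into one, e.g. an X of shape (2, 3, 4) is flattened to (2, 12).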
Node *gather_handler(Graph *graph, Node *node) {
auto new_node_gather =
CreateBaseOp(graph, node, "popart_gather",
@@ -169,7 +174,8 @@ Node *cast_handler(Graph *graph, Node *node) {
return new_node_cast;
}

- Node *lookup_table_handler(Graph *graph, Node *node) {
Node *lookup_table_op_handler(Graph *graph, Node *node,
const std::string &type) {
auto *op = node->Op();
auto padding_idx_ = BOOST_GET_CONST(int64_t, op->GetAttr("padding_idx"));
auto w_shape_ = GetInputVarNode("W", node)->Var()->GetShape();
@@ -183,7 +189,7 @@ Node *lookup_table_handler(Graph *graph, Node *node) {
auto concat_const =
CreateConst(graph, node, {}, {}, {{"value", const_value_},
{"dims", const_shape_},
- {"dtype", ONNXDataType::FLOAT}});
{"dtype", GetOutputVarDtype(node)}});
auto axes =
CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{0}},
{"dims", std::vector<int64_t>{1}},
@@ -247,16 +253,28 @@ Node *lookup_table_handler(Graph *graph, Node *node) {
w_node = GetInputVarNode("W", node);
}
- auto squeeze = CreateBaseOp(graph, node, "popart_squeeze",
- {GetInputVarNode("Ids", node)}, {},
- {{"axes", std::vector<int64_t>{-1}}});
- auto gather =
- CreateBaseOp(graph, node, "popart_gather", {w_node, squeeze->outputs[0]},
- {GetOutputVarNode("Out", node)}, {});
// lookup_table and lookup_table_v2
auto ids = GetInputVarNode("Ids", node);
if (type == "v1") {
ids = CreateBaseOp(graph, node, "popart_squeeze",
{GetInputVarNode("Ids", node)}, {},
{{"axes", std::vector<int64_t>{-1}}});
ids = ids->outputs[0];
}

auto gather = CreateBaseOp(graph, node, "popart_gather", {w_node, ids},
{GetOutputVarNode("Out", node)}, {});
return gather;
}
Node *lookup_table_handler(Graph *graph, Node *node) {
return lookup_table_op_handler(graph, node, "v1");
}
Node *lookup_table_v2_handler(Graph *graph, Node *node) {
return lookup_table_op_handler(graph, node, "v2");
}
Node *unsqueeze_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axes"));
@@ -336,11 +354,32 @@ Node *slice_handler(Graph *graph, Node *node) {
auto attr = MakeConstAttrMap<int>(axes_, {dim}, ONNXDataType::INT32);
axes = CreateConst(graph, node, {}, {}, attr);
}
- auto new_node = CreateBaseOp(
- graph, node, "popart_slice",
- {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
- node->outputs);
- return new_node;
auto decrease_axis_ =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("decrease_axis"));
auto input_shape_ = GetInputVarNode("Input", node)->Var()->GetShape();
auto output_shape_ = GetOutputVarNode("Out", node)->Var()->GetShape();
if (decrease_axis_.size() == 0) {
return CreateBaseOp(
graph, node, "popart_slice",
{GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
node->outputs);
} else if (output_shape_ == std::vector<int64_t>{0} ||
input_shape_.size() > output_shape_.size()) {
auto slice = CreateBaseOp(
graph, node, "popart_slice",
{GetInputVarNode("Input", node), starts, ends, axes->outputs[0]}, {},
{});
return CreateBaseOp(graph, node, "popart_squeeze", {slice->outputs[0]},
{GetOutputVarNode("Out", node)},
{{"axes", std::vector<int64_t>{decrease_axis_.begin(),
decrease_axis_.end()}}});
} else {
return CreateBaseOp(
graph, node, "popart_slice",
{GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
node->outputs);
}
}

Node *expand_handler(Graph *graph, Node *node) {
@@ -373,11 +412,94 @@ Node *expand_handler(Graph *graph, Node *node) {
return new_node;
}
Node *assign_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_identity",
{GetInputVarNode("X", node)},
{GetOutputVarNode("Out", node)}, {});
}
Node *fill_any_like_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto x_dtype = static_cast<framework::proto::VarType::Type>(dtype);
size_t size = 1;
for (auto &dim : x_shape) {
size *= dim;
}
Attribute out_value;
switch (x_dtype) {
case framework::proto::VarType::FP32:
out_value = std::vector<float>(size, value);
break;
case framework::proto::VarType::FP64:
out_value = std::vector<double>(size, value);
break;
case framework::proto::VarType::INT32:
out_value = std::vector<int>(size, value);
break;
case framework::proto::VarType::INT64:
out_value = std::vector<int64_t>(size, value);
break;
case framework::proto::VarType::BOOL:
out_value = std::vector<int64_t>(size, value);
break;
default:
PADDLE_THROW(
platform::errors::Unimplemented("fill_any_like dtype: %d", x_dtype));
}
return CreateConst(graph, node, node->inputs, node->outputs,
AttributeMap{
{"value", out_value},
{"dims", x_shape},
{"dtype", VarType2OnnxDtype(dtype)},
});
}
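As a worked example of the handler above: with x_shape = {2, 3}, dtype mapping to FP32 and value = 1.0, size is 6, so the op is lowered to a constant holding a six-element float vector of ones with dims {2, 3}.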
Node *one_hot_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto depth = BOOST_GET_CONST(int, op->GetAttr("depth"));
auto allow_out_of_range =
BOOST_GET_CONST(bool, op->GetAttr("allow_out_of_range"));
if (allow_out_of_range) {
PADDLE_THROW(platform::errors::Unimplemented(
"Do not support allow_out_of_range=True"));
} else {
auto depth_tensor = CreateConst(graph, node, {}, {},
{{"value", std::vector<int64_t>{depth}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
auto value_tensor =
CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}},
{"dims", std::vector<int64_t>{2}},
{"dtype", ONNXDataType::FLOAT}});
return CreateBaseOp(graph, node, "popart_onehot",
{GetInputVarNode("X", node), depth_tensor->outputs[0],
value_tensor->outputs[0]},
{GetOutputVarNode("Out", node)},
{{"axis", int64_t{-1}}});
}
}
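Assuming popart_onehot follows the ONNX OneHot convention (the two-element value tensor supplies the off/on values {0, 1}), depth = 4 and an input label of 2 would expand to the one-hot row [0, 0, 1, 0] along the new last axis (axis = -1).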
Node *split_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto sections = BOOST_GET_CONST(std::vector<int>, op->GetAttr("sections"));
return CreateBaseOp(
graph, node, "popart_split", {GetInputVarNode("X", node)}, node->outputs,
{{"num_outputs", int64_t(sections.size())},
{"axis", int64_t(axis)},
{"split", std::vector<int64_t>{sections.begin(), sections.end()}}});
}
REGISTER_HANDLER(fill_constant, fill_constant_handler);
REGISTER_HANDLER(gaussian_random, gaussian_random_handler);
REGISTER_HANDLER(uniform_random, uniform_random_handler);
REGISTER_HANDLER(transpose2, transpose_handler);
REGISTER_HANDLER(reshape2, reshape_handler);
REGISTER_HANDLER(flatten2, flatten2_handler);
REGISTER_HANDLER(gather, gather_handler);
REGISTER_HANDLER(squeeze2, squeeze_handler);
REGISTER_HANDLER(cast, cast_handler);
@@ -388,6 +510,11 @@ REGISTER_HANDLER(stack, stack_handler);
REGISTER_HANDLER(shape, shape_handler);
REGISTER_HANDLER(slice, slice_handler);
REGISTER_HANDLER(expand, expand_handler);
REGISTER_HANDLER(assign, assign_handler);
REGISTER_HANDLER(fill_any_like, fill_any_like_handler);
REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler);
REGISTER_HANDLER(split, split_handler);
REGISTER_HANDLER(one_hot, one_hot_handler);
} // namespace
} // namespace ipu
......