Unverified commit f1111f3c, authored by Allen Guo, committed by GitHub

[IPU] support more ops 1/N (#44205)

* add authors
Co-authored-by: Allen Guo <alleng@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>

* squash cpp changes 2/N
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>
Parent 5988553f
@@ -376,11 +376,473 @@ Node *dropout_handler(Graph *graph, Node *node) {
}
}
Node *conv2d_transpose_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto data_format = BOOST_GET_CONST(std::string, op->GetAttr("data_format"));
if (data_format != "NCHW") {
PADDLE_THROW(
platform::errors::InvalidArgument("Only support NCHW as data_format."));
}
auto *kernel_info = GetInputVarNode("Filter", node);
auto kernel_shape = kernel_info->Var()->GetShape();
auto dilations_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dilations"));
auto dilations = std::vector<int64_t>{dilations_.begin(), dilations_.end()};
auto strides_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("strides"));
auto strides = std::vector<int64_t>{strides_.begin(), strides_.end()};
auto output_padding_ =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("output_padding"));
auto output_padding =
std::vector<int64_t>{output_padding_.begin(), output_padding_.end()};
auto group_ = BOOST_GET_CONST(int, op->GetAttr("groups"));
auto group = int64_t(group_);
auto padding_algorithm =
BOOST_GET_CONST(std::string, op->GetAttr("padding_algorithm"));
auto paddings_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("paddings"));
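// Normalize the Paddle paddings to the ONNX order [top, left, bottom, right]
// (assuming Paddle stores a 4-element padding as [top, bottom, left, right]);
// a 2-element padding [pad_h, pad_w] is duplicated so both sides of each
// spatial dim get the same padding.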
if (paddings_.size() == 2) {
paddings_.push_back(paddings_[0]);
paddings_.push_back(paddings_[1]);
} else if (paddings_.size() == 4) {
std::swap(paddings_[1], paddings_[2]);
}
auto paddings = std::vector<int64_t>{paddings_.begin(), paddings_.end()};
if (padding_algorithm == "SAME") {
// Update paddings and dilations based on the sizes of H and W.
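// out = ceil(in / stride); total padding = max((out - 1) * stride + kernel
// - in, 0), split between the two sides with the extra unit on the end side.
// Dilations are reset to 1 under SAME padding.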
auto input_shape = GetInputVarNode("Input", node)->Var()->GetShape();
for (auto i = 0; i < 2; i++) {
auto out_size = (input_shape[i + 2] + strides[i] - 1) / strides[i];
auto pad_sum = std::max(
(out_size - 1) * strides[i] + kernel_shape[i] - input_shape[i + 2],
static_cast<int64_t>(0));
auto pad_0 = pad_sum / 2;
auto pad_1 = pad_sum - pad_0;
paddings[i] = pad_0;
paddings[i + 2] = pad_1;
}
for (auto i = 0; i < dilations.size(); i++) {
dilations[i] = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto i = 0; i < paddings.size(); i++) {
paddings[i] = 0;
}
}
auto attrs = AttributeMap{{"dilations", dilations},
{"group", group},
{"kernel_shape", kernel_shape},
{"output_padding", output_padding},
{"pads", paddings},
{"strides", strides}};
if (!op->Input("Bias").empty()) {
return CreateBaseOp(graph,
node,
"popart_convtranspose",
{
GetInputVarNode("Input", node),
GetInputVarNode("Filter", node),
GetInputVarNode("Bias", node),
},
node->outputs,
attrs);
} else {
return CreateBaseOp(graph,
node,
"popart_convtranspose",
{
GetInputVarNode("Input", node),
GetInputVarNode("Filter", node),
},
node->outputs,
attrs);
}
}
Node *affine_channel_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto data_layout = BOOST_GET_CONST(std::string, op->GetAttr("data_layout"));
if (data_layout != "NCHW") {
PADDLE_THROW(
platform::errors::InvalidArgument("Only support NCHW as data_format."));
}
auto *scale = GetInputVarNode("Scale", node);
auto *bias = GetInputVarNode("Bias", node);
auto scale_shape = scale->Var()->GetShape();
auto bias_shape = bias->Var()->GetShape();
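// affine_channel computes Out = X * Scale + Bias per channel; a 1-D (or
// scalar) Scale/Bias is first reshaped to [1, -1, 1, 1] so it broadcasts
// over the NCHW input.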
if (scale_shape.size() <= 1 || bias_shape.size() <= 1) {
auto attrs = AttributeMap{{"value", std::vector<int64_t>{1, -1, 1, 1}},
{"dims", std::vector<int64_t>{4}},
{"dtype", ONNXDataType::INT64}};
auto new_shape_const = CreateConst(graph, node, {}, {}, attrs);
scale = CreateBaseOp(graph,
node,
"popart_reshape",
{scale, new_shape_const->outputs[0]},
{},
{})
->outputs[0];
bias = CreateBaseOp(graph,
node,
"popart_reshape",
{bias, new_shape_const->outputs[0]},
{},
{})
->outputs[0];
}
auto *out = CreateBaseOp(
graph, node, "popart_mul", {GetInputVarNode("X", node), scale}, {});
return CreateBaseOp(graph,
node,
"popart_add",
{out->outputs[0], bias},
{GetOutputVarNode("Out", node)});
}
Node *interp_handler(Graph *graph, Node *node, const std::string &mode) {
auto *op = node->Op();
auto data_layout = BOOST_GET_CONST(std::string, op->GetAttr("data_layout"));
if (data_layout != "NCHW") {
PADDLE_THROW(
platform::errors::InvalidArgument("Only support NCHW as data_format."));
}
auto align_corners = BOOST_GET_CONST(bool, op->GetAttr("align_corners"));
auto align_mode = BOOST_GET_CONST(int, op->GetAttr("align_mode"));
auto paddle_target_dtype = VarType::FP32;
auto onnx_target_dtype = ONNXDataType::FLOAT;
if (GetInputVarNode("X", node)->Var()->GetDataType() == VarType::FP16) {
paddle_target_dtype = VarType::FP16;
onnx_target_dtype = ONNXDataType::FLOAT16;
}
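// Map Paddle's align_corners / align_mode onto the ONNX Resize
// coordinate_transformation_mode: align_corners takes priority, nearest
// uses asymmetric, cubic with align_mode == 1 also uses asymmetric, and
// everything else falls back to half_pixel.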
std::string coordinate_transformation_mode = "half_pixel";
if (align_corners) {
coordinate_transformation_mode = "align_corners";
} else if (mode == "nearest") {
coordinate_transformation_mode = "asymmetric";
} else if (align_mode == 1 && mode == "cubic") {
coordinate_transformation_mode = "asymmetric";
}
bool has_out_size = node->Op()->Input("OutSize").size() > 0;
bool has_size_tensor = node->Op()->Input("SizeTensor").size() > 0;
bool has_scale_tensor = node->Op()->Input("Scale").size() > 0;
Node *size = nullptr;
Node *scale = nullptr;
// Input: Size and Scale
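// Only one of `size` (target spatial sizes) or `scale` (resize factors) is
// populated below; the other is left as nullptr.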
if (has_out_size) {
// Get 'size' from the tensor
size = GetInputVarNode("OutSize", node);
if (size->Var()->GetDataType() != VarType::INT64) {
size = CreateCast(graph,
node,
{GetInputVarNode("OutSize", node)},
{},
VarType::INT64)
->outputs[0];
}
} else if (has_size_tensor) {
// Get 'size' from multi-tensors
std::vector<Node *> size_nodes;
for (auto var_name : node->Op()->Input("SizeTensor")) {
Node *size_node = GetInputVarNodeByVarName(var_name, node);
if (size_node->Var()->GetDataType() != VarType::INT64) {
size_node = CreateCast(graph, node, {size_node}, {}, VarType::INT64)
->outputs[0];
}
size_nodes.push_back(size_node);
}
size = CreateBaseOp(graph,
node,
"popart_concat",
size_nodes,
{},
{{"axis", int64_t(0)}})
->outputs[0];
} else if (has_scale_tensor) {
// Get 'scale' from tensor
scale = GetInputVarNode("Scale", node);
if (scale->Var()->GetDataType() != paddle_target_dtype) {
scale =
CreateCast(graph, node, {scale}, {}, paddle_target_dtype)->outputs[0];
}
auto *padding = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<float>{1.0, 1.0}},
{"dims", std::vector<int64_t>{2}},
{"dtype", onnx_target_dtype}})
->outputs[0];
scale = CreateBaseOp(graph,
node,
"popart_concat",
{padding, scale},
{},
{{"axis", int64_t(0)}})
->outputs[0];
} else {
// Get 'size' or 'scale' from attribute
auto out_d = BOOST_GET_CONST(int, op->GetAttr("out_d"));
auto out_h = BOOST_GET_CONST(int, op->GetAttr("out_h"));
auto out_w = BOOST_GET_CONST(int, op->GetAttr("out_w"));
if (out_d > 0 || out_w > 0 || out_h > 0) {
std::vector<int64_t> out_size;
if (GetInputVarNode("X", node)->Var()->GetShape().size() == 5) {
out_size.push_back(int64_t(out_d));
out_size.push_back(int64_t(out_h));
} else if (GetInputVarNode("X", node)->Var()->GetShape().size() == 4) {
out_size.push_back(int64_t(out_h));
}
out_size.push_back(int64_t(out_w));
size =
CreateConst(graph,
node,
{},
{},
{{"value", out_size},
{"dims", std::vector<int64_t>{int64_t(out_size.size())}},
{"dtype", ONNXDataType::INT64}})
->outputs[0];
} else {
auto scale_value =
BOOST_GET_CONST(std::vector<float>, op->GetAttr("scale"));
float padding = 1.0;
scale_value.insert(scale_value.begin(), padding);
scale_value.insert(scale_value.begin(), padding);
scale = CreateConst(
graph,
node,
{},
{},
{{"value", scale_value},
{"dims", std::vector<int64_t>{int64_t(scale_value.size())}},
{"dtype", onnx_target_dtype}})
->outputs[0];
}
}
Node *roi =
CreateConst(
graph,
node,
{},
{},
{{"value",
std::vector<float>(
GetInputVarNode("X", node)->Var()->GetShape().size() * 2, 1.0)},
{"dims",
std::vector<int64_t>{int64_t(
GetInputVarNode("X", node)->Var()->GetShape().size() * 2)}},
{"dtype", onnx_target_dtype}})
->outputs[0];
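// The ONNX Resize `sizes` input must cover every dim, so the N and C dims
// sliced from the input shape are prepended to the spatial sizes gathered
// above.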
if (size != nullptr) {
Node *input_shape =
CreateBaseOp(
graph, node, "popart_shape", {GetInputVarNode("X", node)}, {})
->outputs[0];
Node *start = CreateConst(graph,
node,
std::vector<int>{0},
std::vector<int64_t>{1},
ONNXDataType::INT32)
->outputs[0];
Node *end = CreateConst(graph,
node,
std::vector<int>{2},
std::vector<int64_t>{1},
ONNXDataType::INT32)
->outputs[0];
Node *axes = CreateConst(graph,
node,
std::vector<int>{0},
std::vector<int64_t>{1},
ONNXDataType::INT32)
->outputs[0];
Node *nc = CreateBaseOp(graph,
node,
"popart_slice",
{input_shape, start, end, axes},
{},
{})
->outputs[0];
size = CreateBaseOp(graph,
node,
"popart_concat",
{nc, size},
{},
{{"axis", int64_t(0)}})
->outputs[0];
}
auto resize_attrs = AttributeMap{
{"coordinate_transformation_mode", coordinate_transformation_mode},
{"cubic_coeff_a", float{-0.75}},
{"exclude_outside", int64_t{0}},
{"extrapolation_value", float{0.0}},
{"mode", mode},
{"nearest_mode", std::string("round_prefer_floor")}};
if (mode == "nearest" && coordinate_transformation_mode == "asymmetric") {
resize_attrs.at("nearest_mode") = std::string("floor");
}
return CreateBaseOp(graph,
node,
"popart_resize",
{GetInputVarNode("X", node), roi, scale, size},
{GetOutputVarNode("Out", node)},
resize_attrs);
}
Node *bilinear_interp_v2_handler(Graph *graph, Node *node) {
return interp_handler(graph, node, "linear");
}
Node *nearest_interp_v2_handler(Graph *graph, Node *node) {
return interp_handler(graph, node, "nearest");
}
Node *bicubic_interp_v2_handler(Graph *graph, Node *node) {
return interp_handler(graph, node, "cubic");
}
Node *linear_interp_v2_handler(Graph *graph, Node *node) {
return interp_handler(graph, node, "linear");
}
Node *trilinear_interp_v2_handler(Graph *graph, Node *node) {
return interp_handler(graph, node, "linear");
}
Node *data_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
int slot_dim = -1;
if (op->HasAttr("slot_dim")) {
slot_dim = BOOST_GET_CONST(int, op->GetAttr("slot_dim"));
}
if (slot_dim > 0) {
PADDLE_THROW(
platform::errors::InvalidArgument("slot_dim > 0 is not supported."));
}
bool enable_scale_and_shift = false;
if (op->HasAttr("enable_scale_and_shift")) {
enable_scale_and_shift =
BOOST_GET_CONST(bool, op->GetAttr("enable_scale_and_shift"));
}
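// Running statistics: mean = BatchSum / BatchSize and
// scale = sqrt(BatchSize / BatchSquareSum).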
auto *mean_arr = CreateBaseOp(graph,
node,
"popart_div",
{GetInputVarNode("BatchSum", node),
GetInputVarNode("BatchSize", node)},
{})
->outputs[0];
auto *scale_arr = CreateBaseOp(graph,
node,
"popart_div",
{GetInputVarNode("BatchSize", node),
GetInputVarNode("BatchSquareSum", node)},
{})
->outputs[0];
scale_arr =
CreateBaseOp(graph, node, "popart_sqrt", {scale_arr}, {})->outputs[0];
auto out =
CreateBaseOp(
graph, node, "popart_sub", {GetInputVarNode("X", node), mean_arr}, {})
->outputs[0];
if (enable_scale_and_shift) {
auto scale_res = CreateBaseOp(graph,
node,
"popart_mul",
{out, GetInputVarNode("scale_w", node)},
{})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_add",
{scale_res, GetInputVarNode("bias", node)},
{GetOutputVarNode("Y", node)});
} else {
return CreateBaseOp(graph,
node,
"popart_mul",
{out, scale_arr},
{GetOutputVarNode("Y", node)});
}
}
Node *pad_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto mode = BOOST_GET_CONST(std::string, op->GetAttr("mode"));
auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
auto data_format = BOOST_GET_CONST(std::string, op->GetAttr("data_format"));
if (data_format == "NDHWC") {
PADDLE_THROW(
platform::errors::Unimplemented("NDHWC format is not supported."));
}
if (mode == "replicate" || mode == "circular") {
PADDLE_THROW(platform::errors::Unimplemented(
"circular and replicate modes are not supported."));
}
if (op->Input("Paddings").size()) {
// Paddings provided as an input tensor.
// The PopART Pad op only supports `pads` given as a constant.
PADDLE_THROW(platform::errors::Unimplemented(
"Paddings as an input tensor is not supported."));
}
// Paddings -> Attr
auto paddings = BOOST_GET_CONST(std::vector<int>, op->GetAttr("paddings"));
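// Reorder Paddle's pad3d paddings [left, right, top, bottom, front, back]
// (W, H, D) into the ONNX NCDHW layout
// [N_begin, C_begin, D_begin, H_begin, W_begin, N_end, C_end, D_end, H_end,
// W_end]; N and C are never padded.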
std::vector<int64_t> new_paddings(10, 0);
new_paddings[2] = paddings[4];
new_paddings[3] = paddings[2];
new_paddings[4] = paddings[0];
new_paddings[7] = paddings[5];
new_paddings[8] = paddings[3];
new_paddings[9] = paddings[1];
auto *paddings_node = CreateConst(graph,
node,
new_paddings,
std::vector<int64_t>{10},
ONNXDataType::INT64)
->outputs[0];
auto *value_node = CreateConst(graph,
node,
std::vector<float>{value},
std::vector<int64_t>{1},
ONNXDataType::FLOAT)
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_pad",
{GetInputVarNode("X", node), paddings_node, value_node},
{GetOutputVarNode("Out", node)},
{{"mode", mode}});
}
} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
REGISTER_HANDLER(affine_channel, affine_channel_handler);
REGISTER_HANDLER(pool2d, pool2d_handler);
REGISTER_HANDLER(max_pool2d_with_index, max_pool2d_with_index_handler);
REGISTER_HANDLER(batch_norm, batch_norm_handler);
@@ -388,4 +850,12 @@ REGISTER_HANDLER(group_norm, group_norm_handler);
REGISTER_HANDLER(instance_norm, instance_norm_handler);
REGISTER_HANDLER(layer_norm, layer_norm_handler);
REGISTER_HANDLER(conv2d, conv2d_handler);
REGISTER_HANDLER(conv2d_transpose, conv2d_transpose_handler);
REGISTER_HANDLER(dropout, dropout_handler);
REGISTER_HANDLER(bilinear_interp_v2, bilinear_interp_v2_handler);
REGISTER_HANDLER(nearest_interp_v2, nearest_interp_v2_handler);
REGISTER_HANDLER(bicubic_interp_v2, bicubic_interp_v2_handler);
REGISTER_HANDLER(linear_interp_v2, linear_interp_v2_handler);
REGISTER_HANDLER(trilinear_interp_v2, trilinear_interp_v2_handler);
REGISTER_HANDLER(data_norm, data_norm_handler);
REGISTER_HANDLER(pad3d, pad_handler);
@@ -33,10 +33,15 @@ Node *fill_constant_handler(Graph *graph, Node *node) {
auto dtype = VarType2OnnxDType(static_cast<VarType::Type>(dtype_));
auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
int size = 1;
for (auto &dim : dims) {
size *= dim;
}
PADDLE_ENFORCE_GT(size,
0,
errors::InvalidArgument(
"IPU doesn't support non-positive dimensions. Please "
"check tensor shape setting."));
Attribute value;
switch (dtype_) {
case VarType::FP16:
@@ -598,10 +603,15 @@ Node *fill_any_like_handler(Graph *graph, Node *node) {
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
auto dtype = static_cast<VarType::Type>(dtype_);
int size = 1;
for (auto &dim : x_shape) {
size *= dim;
}
PADDLE_ENFORCE_GT(size,
0,
errors::InvalidArgument(
"IPU doesn't support non-positive dimensions. Please "
"check tensor shape setting."));
Attribute out_value;
switch (dtype) {
@@ -748,6 +758,491 @@ Node *dot_handler(Graph *graph, Node *node) {
});
}
Node *clip_handler(Graph *graph, Node *node) {
auto *op = node->Op();
// min == -FLT_MAX means there is no lower bound;
// max == FLT_MAX means there is no upper bound.
auto min_value = BOOST_GET_CONST(float, op->GetAttr("min"));
auto max_value = BOOST_GET_CONST(float, op->GetAttr("max"));
bool has_min_tensor = false;
bool has_max_tensor = false;
if (node->Op()->Input("Min").size()) {
has_min_tensor = true;
}
if (node->Op()->Input("Max").size()) {
has_max_tensor = true;
}
bool transfer_input_dtype = false;
Node *input_data = GetInputVarNode("X", node);
if (input_data->Var()->GetDataType() != VarType::FP32 &&
input_data->Var()->GetDataType() != VarType::FP16) {
input_data =
CreateCast(graph, node, {input_data}, {}, VarType::FP32)->outputs[0];
transfer_input_dtype = true;
}
Node *min_tensor = nullptr;
if (has_min_tensor) {
if (GetInputVarNode("Min", node)->Var()->GetDataType() != VarType::FP32) {
min_tensor =
CreateCast(
graph, node, {GetInputVarNode("Min", node)}, {}, VarType::FP32)
->outputs[0];
} else {
min_tensor = GetInputVarNode("Min", node);
}
} else {
min_tensor = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<float>{min_value}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}})
->outputs[0];
}
Node *max_tensor = nullptr;
if (has_max_tensor) {
if (GetInputVarNode("Max", node)->Var()->GetDataType() != VarType::FP32) {
max_tensor =
CreateCast(
graph, node, {GetInputVarNode("Max", node)}, {}, VarType::FP32)
->outputs[0];
} else {
max_tensor = GetInputVarNode("Max", node);
}
} else {
max_tensor = CreateConst(graph,
node,
{},
{},
{{"value", std::vector<float>{max_value}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::FLOAT}})
->outputs[0];
}
if (transfer_input_dtype) {
auto clip_res = CreateBaseOp(
graph, node, "popart_clip", {input_data, min_tensor, max_tensor}, {});
return CreateCast(graph,
node,
clip_res->outputs,
{GetOutputVarNode("Out", node)},
GetInputVarNode("X", node)->Var()->GetDataType());
} else {
return CreateBaseOp(graph,
node,
"popart_clip",
{input_data, min_tensor, max_tensor},
{GetOutputVarNode("Out", node)});
}
}
Node *dist_handler(Graph *graph, Node *node) {
// Negative float infinity built from its IEEE-754 bit pattern (0xFF800000)
union neg_infinity {
int neg_int_inf;
float neg_float_int;
};
neg_infinity neg_inf;
neg_inf.neg_int_inf = 0xFF800000;
float g_NegFloatInfinity = neg_inf.neg_float_int;
auto *op = node->Op();
auto *sub_node =
CreateBaseOp(graph,
node,
"popart_sub",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{})
->outputs[0];
auto *abs_node =
CreateBaseOp(graph, node, "popart_abs", {sub_node}, {})->outputs[0];
auto p = BOOST_GET_CONST(float, op->GetAttr("p"));
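// dist(X, Y) reduces |X - Y| according to p: p == 0 counts the non-zero
// elements, p == +inf takes the maximum, p == -inf the minimum, and any
// other p computes (sum |X - Y|^p)^(1 / p).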
// Reshape to 1-D output
auto target_shape = AttributeMap{{"value", std::vector<int64_t>{-1}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}};
auto *target_shape_node =
CreateBaseOp(graph, node, "popart_constant", {}, {}, target_shape)
->outputs[0];
if (fabs(p) < 1e-6) {
auto *sign_node =
CreateBaseOp(graph, node, "popart_sign", {abs_node}, {})->outputs[0];
auto *sum_node = CreateBaseOp(graph,
node,
"popart_reducesum",
{sign_node},
{},
{{"keepdims", int64_t{0}}})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_reshape",
{sum_node, target_shape_node},
{GetOutputVarNode("Out", node)});
} else if (p == std::numeric_limits<float>::infinity()) {
auto *max_node = CreateBaseOp(graph,
node,
"popart_reducemax",
{abs_node},
{},
{{"keepdims", int64_t{0}}})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_reshape",
{max_node, target_shape_node},
{GetOutputVarNode("Out", node)});
} else if (p == g_NegFloatInfinity) {
auto *min_node = CreateBaseOp(graph,
node,
"popart_reducemin",
{abs_node},
{},
{{"keepdims", int64_t{0}}})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_reshape",
{min_node, target_shape_node},
{GetOutputVarNode("Out", node)});
} else {
auto target_dtype = ONNXDataType::FLOAT;
if (GetInputVarNode("X", node)->Var()->GetDataType() == VarType::FP16) {
target_dtype = ONNXDataType::FLOAT16;
}
auto pow_factor = AttributeMap{{"value", std::vector<float>{p}},
{"dims", std::vector<int64_t>{1}},
{"dtype", target_dtype}};
auto *pow_factor_node =
CreateBaseOp(graph, node, "popart_constant", {}, {}, pow_factor)
->outputs[0];
auto *pow_node =
CreateBaseOp(graph, node, "popart_pow", {abs_node, pow_factor_node}, {})
->outputs[0];
auto *sum_node = CreateBaseOp(graph,
node,
"popart_reducesum",
{pow_node},
{},
{{"keepdims", int64_t{0}}})
->outputs[0];
auto *s_node =
CreateBaseOp(
graph, node, "popart_reshape", {sum_node, target_shape_node}, {})
->outputs[0];
auto *p_1 =
CreateBaseOp(graph, node, "popart_reciprocal", {pow_factor_node}, {})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_pow",
{s_node, p_1},
{GetOutputVarNode("Out", node)});
}
}
Node *expand_as_v2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
Node *shape = nullptr;
auto op_inputs = op->Inputs();
// The PopART Expand op only supports a constant tensor as the `shape` input.
if (op_inputs.find("target_tensor") != op_inputs.end()) {
PADDLE_THROW(platform::errors::Unimplemented(
"Do not support input tensor `target_tensor`. Please use the attribute "
"`target_shape`."));
}
auto input_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto shape_value =
BOOST_GET_CONST(std::vector<int>, op->GetAttr("target_shape"));
// Check the dimensions
int input_shape_index = input_shape.size() - 1;
int target_shape_index = shape_value.size() - 1;
while (input_shape_index >= 0) {
if (input_shape[input_shape_index] !=
int64_t(shape_value[target_shape_index]) &&
input_shape[input_shape_index] != int64_t(1)) {
PADDLE_THROW(platform::errors::Unimplemented(
"For input and `shape`, corresponding dimensions must have the same "
"value or input dim = 1."));
}
target_shape_index--;
input_shape_index--;
}
shape = CreateConst(
graph,
node,
{},
{},
{{"value",
std::vector<int64_t>{shape_value.begin(), shape_value.end()}},
{"dims", std::vector<int64_t>{int64_t(shape_value.size())}},
{"dtype", ONNXDataType::INT64}})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_expand",
{GetInputVarNode("X", node), shape},
{GetOutputVarNode("Out", node)});
}
Node *expand_v2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
// The PopART Expand op only supports a constant tensor as the `shape` input.
if (op->Input("Shape").size()) {
PADDLE_THROW(
platform::errors::Unimplemented("Do not support input tensor `Shape`. "
"Please use the attribute `shape`."));
}
if (op->Input("expand_shapes_tensor").size()) {
PADDLE_THROW(platform::errors::Unimplemented(
"Do not support input tensor `expand_shapes_tensor`. Please use the "
"attribute `shape`."));
}
auto input_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto shape_value = BOOST_GET_CONST(std::vector<int>, op->GetAttr("shape"));
// Check the dimensions
int input_shape_index = input_shape.size() - 1;
int target_shape_index = shape_value.size() - 1;
while (input_shape_index >= 0) {
if (input_shape[input_shape_index] !=
int64_t(shape_value[target_shape_index]) &&
input_shape[input_shape_index] != int64_t(1)) {
PADDLE_THROW(platform::errors::Unimplemented(
"For input and `shape`, corresponding dimensions must have the same "
"value or input dim = 1."));
}
target_shape_index--;
input_shape_index--;
}
auto *shape =
CreateConst(
graph,
node,
{},
{},
{{"value",
std::vector<int64_t>{shape_value.begin(), shape_value.end()}},
{"dims", std::vector<int64_t>{int64_t(shape_value.size())}},
{"dtype", ONNXDataType::INT64}})
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_expand",
{GetInputVarNode("X", node), shape},
{GetOutputVarNode("Out", node)});
}
Node *flatten_contiguous_range_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto start_axis = BOOST_GET_CONST(int, op->GetAttr("start_axis"));
auto stop_axis = BOOST_GET_CONST(int, op->GetAttr("stop_axis"));
auto input_rank = GetInputVarNode("X", node)->Var()->GetShape().size();
if (start_axis < 0) {
start_axis += input_rank;
}
if (stop_axis < 0) {
stop_axis += input_rank;
}
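// The flatten is expressed as a single Reshape: dims outside
// [start_axis, stop_axis] are kept and the flattened range is inferred
// with -1.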
std::vector<int64_t> target_shape;
if (start_axis == 0 && stop_axis == input_rank - 1) {
target_shape.push_back(-1);
} else {
auto input_shape = GetInputVarNode("X", node)->Var()->GetShape();
if (start_axis == 0) {
target_shape.assign(input_shape.begin() + stop_axis + 1,
input_shape.end());
target_shape.insert(target_shape.begin(), -1);
} else if (stop_axis == input_rank - 1) {
target_shape.assign(input_shape.begin(),
input_shape.begin() + start_axis);
target_shape.push_back(-1);
} else {
target_shape.insert(target_shape.begin(),
input_shape.begin(),
input_shape.begin() + start_axis);
target_shape.push_back(-1);
target_shape.insert(target_shape.end(),
input_shape.begin() + stop_axis + 1,
input_shape.end());
}
}
auto *unknown_dim_node = CreateConst(graph,
node,
target_shape,
{int64_t(target_shape.size())},
ONNXDataType::INT64)
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_reshape",
{GetInputVarNode("X", node), unknown_dim_node},
{GetOutputVarNode("Out", node)},
{});
}
Node *flip_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axes = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axis"));
auto input_shape = GetInputVarNode("X", node)->Var()->GetShape();
for (auto it = axes.begin(); it != axes.end();) {
if (*it < 0) {
*it += input_shape.size();
}
// Skip axes whose dimension is 1; flipping them is a no-op.
if (input_shape[*it] == 1) {
it = axes.erase(it);
} else {
it++;
}
}
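// Flip along each remaining axis by splitting the tensor into size-1 slices
// on that axis, reversing the slice order, and concatenating them back.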
auto *temp_node = GetInputVarNode("X", node);
for (auto i = 0; i < axes.size(); i++) {
auto axis = axes[i];
std::vector<int64_t> split;
split.resize(input_shape[axis], 1);
std::vector<Node *> splits_output_nodes;
for (int j = 0; j < split.size(); j++) {
splits_output_nodes.push_back(MakeVarNode(graph, node));
}
auto splits_outputs = CreateBaseOp(graph,
node,
"popart_split",
{temp_node},
{splits_output_nodes},
{{"num_outputs", int64_t(split.size())},
{"axis", int64_t(axis)},
{"split", split}})
->outputs;
std::reverse(splits_outputs.begin(), splits_outputs.end());
temp_node = CreateBaseOp(graph,
node,
"popart_concat",
splits_outputs,
{},
{{"axis", int64_t(axis)}})
->outputs[0];
}
// Covers the case where `axes` ends up empty; the Identity op is removed by later passes.
return CreateBaseOp(graph,
node,
"popart_identity",
{temp_node},
{GetOutputVarNode("Out", node)},
{});
}
Node *meshgrid_handler(Graph *graph, Node *node) {
Node *res = nullptr;
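// For N 1-D inputs, the i-th output is the i-th input reshaped so that its
// length sits at dim i (all other dims are 1) and then expanded to the full
// grid shape formed by all input lengths.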
// All inputs are 1-D tensors
std::vector<int64_t> out_shape;
for (auto input : node->inputs) {
auto input_shape = input->Var()->GetShape();
out_shape.push_back(input_shape[0]);
}
// Expand Op only allows a const tensor as `shape`
auto *out_shape_node = CreateConst(graph,
node,
out_shape,
{int64_t(out_shape.size())},
ONNXDataType::INT64)
->outputs[0];
for (int i = 0; i < node->inputs.size(); i++) {
// Reshape the i-th input to rank node->inputs.size(), keeping its length at dim i and filling the other dims with 1
std::vector<int64_t> target_shape(node->inputs.size(), 1);
target_shape[i] = node->inputs[i]->Var()->GetShape()[0];
auto *target_shape_node = CreateConst(graph,
node,
target_shape,
{int64_t(target_shape.size())},
ONNXDataType::INT64)
->outputs[0];
auto *t_reshaped = CreateBaseOp(graph,
node,
"popart_reshape",
{node->inputs[i], target_shape_node},
{},
{})
->outputs[0];
res = CreateBaseOp(graph,
node,
"popart_expand",
{t_reshaped, out_shape_node},
{node->outputs[i]});
}
return res;
}
Node *p_norm_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto keepdim = BOOST_GET_CONST(bool, op->GetAttr("keepdim"));
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto porder = BOOST_GET_CONST(float, op->GetAttr("porder"));
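// p-norm along `axis`: ||x||_p = (sum_i |x_i|^p)^(1 / p), implemented as
// abs -> pow(p) -> reducesum -> pow(1 / p).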
auto target_dtype = ONNXDataType::FLOAT;
if (GetInputVarNode("X", node)->Var()->GetDataType() == VarType::FP16) {
target_dtype = ONNXDataType::FLOAT16;
}
auto *pnode = CreateConst(graph,
node,
std::vector<float>{porder},
std::vector<int64_t>{1},
target_dtype)
->outputs[0];
auto *abs_node =
CreateBaseOp(graph, node, "popart_abs", {GetInputVarNode("X", node)}, {})
->outputs[0];
auto *pow_node =
CreateBaseOp(graph, node, "popart_pow", {abs_node, pnode}, {})
->outputs[0];
auto *reducesum_node = CreateBaseOp(graph,
node,
"popart_reducesum",
{pow_node},
{},
{{"axes", std::vector<int64_t>{axis}},
{"keepdims", int64_t(keepdim)}})
->outputs[0];
auto *pnode1 =
CreateConst(graph,
node,
std::vector<float>{static_cast<float>(1.0 / porder)},
std::vector<int64_t>{1},
target_dtype)
->outputs[0];
return CreateBaseOp(graph,
node,
"popart_pow",
{reducesum_node, pnode1},
{GetOutputVarNode("Out", node)});
}
} // namespace
} // namespace ipu
} // namespace platform
@@ -759,6 +1254,7 @@ REGISTER_HANDLER(uniform_random, uniform_random_handler);
REGISTER_HANDLER(transpose2, transpose_handler);
REGISTER_HANDLER(reshape2, reshape_handler);
REGISTER_HANDLER(flatten2, flatten2_handler);
REGISTER_HANDLER(flatten_contiguous_range, flatten_contiguous_range_handler);
REGISTER_HANDLER(gather, gather_handler);
REGISTER_HANDLER(squeeze2, squeeze_handler);
REGISTER_HANDLER(cast, cast_handler);
@@ -769,6 +1265,8 @@ REGISTER_HANDLER(stack, stack_handler);
REGISTER_HANDLER(shape, shape_handler);
REGISTER_HANDLER(slice, slice_handler);
REGISTER_HANDLER(expand, expand_handler);
REGISTER_HANDLER(expand_v2, expand_v2_handler);
REGISTER_HANDLER(expand_as_v2, expand_as_v2_handler);
REGISTER_HANDLER(assign, assign_handler);
REGISTER_HANDLER(assign_value, assign_value_handler);
REGISTER_HANDLER(fill_any_like, fill_any_like_handler);
@@ -777,3 +1275,8 @@ REGISTER_HANDLER(split, split_handler);
REGISTER_HANDLER(one_hot, one_hot_handler);
REGISTER_HANDLER(one_hot_v2, one_hot_v2_handler);
REGISTER_HANDLER(dot, dot_handler);
REGISTER_HANDLER(clip, clip_handler);
REGISTER_HANDLER(dist, dist_handler);
REGISTER_HANDLER(flip, flip_handler);
REGISTER_HANDLER(meshgrid, meshgrid_handler);
REGISTER_HANDLER(p_norm, p_norm_handler);