Unverified commit 2353db3a authored by Allen Guo, committed by GitHub

[IPU] add activation ops (#43662)

* add argmin and argsort ops (#800)

* add argmin and argsort ops

* Add dot bmm ops (#803)

* add bmm

* add dot op

* clean CreateConst

* clean CreateCast

* add activation ops (#808)

* add activation ops

* fix function-redefined error
Parent 2a795dfa
......@@ -119,6 +119,21 @@ Node *tanh_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_tanh");
}
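// brelu(x) = clip(x, t_min, t_max), lowered to popart_clip with the bounds
// passed in as scalar constants.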
Node *brelu_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto t_min_ = BOOST_GET_CONST(float, op->GetAttr("t_min"));
auto t_max_ = BOOST_GET_CONST(float, op->GetAttr("t_max"));
auto x = GetInputVarNode("X", node);
auto clip_min = CreateConst(graph, node, std::vector<float>{t_min_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_max = CreateConst(graph, node, std::vector<float>{t_max_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
return CreateBaseOp(graph, node, "popart_clip", {x, clip_min, clip_max},
node->outputs);
}
Node *gelu_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto approximate_ = BOOST_GET_CONST(bool, op->GetAttr("approximate"));
......@@ -160,6 +175,245 @@ Node *log_softmax_handler(Graph *graph, Node *node) {
node->outputs);
}
Node *elu_handler(Graph *graph, Node *node) {
auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
return CreateBaseOp(graph, node, "popart_elu", node->inputs, node->outputs,
{
{"alpha", alpha_},
});
}
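// hard_shrink maps onto popart_shrink with lambd = threshold and bias = 0.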
Node *hard_shrink_handler(Graph *graph, Node *node) {
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
{
{"lambd", threshold_},
{"bias", 0.0f},
});
}
Node *hard_sigmoid_handler(Graph *graph, Node *node) {
auto slope_ = BOOST_GET_CONST(float, node->Op()->GetAttr("slope"));
auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
return CreateBaseOp(graph, node, "popart_hardsigmoid", node->inputs,
node->outputs,
{
{"alpha", slope_},
{"beta", offset_},
});
}
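// hard_swish(x) = x * clip(x + offset, 0, threshold) / scale, built from
// popart_add, popart_clip, popart_mul and popart_div.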
Node *hard_swish_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
auto offset_ = BOOST_GET_CONST(float, node->Op()->GetAttr("offset"));
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
auto scale_node =
CreateConst(graph, node, std::vector<float>{scale_}, {1}, GetVarDType(x))
->outputs.front();
auto offset_node =
CreateConst(graph, node, std::vector<float>{offset_}, {1}, GetVarDType(x))
->outputs.front();
auto add_node = CreateBaseOp(graph, node, "popart_add", {x, offset_node}, {})
->outputs.front();
auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_node = CreateBaseOp(graph, node, "popart_clip",
{add_node, clip_min, clip_max}, {})
->outputs.front();
auto mul_node = CreateBaseOp(graph, node, "popart_mul", {x, clip_node}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_div", {mul_node, scale_node},
{GetOutputVarNode("Out", node)});
}
Node *leaky_relu_handler(Graph *graph, Node *node) {
auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
return CreateBaseOp(graph, node, "popart_leakyrelu", node->inputs,
node->outputs,
{
{"alpha", alpha_},
});
}
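// log10(x) = log(x) / ln(10), with ln(10) baked in as a scalar constant.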
Node *log10_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
float ln10 = 2.30258509299404568401;
auto ln10_tensor =
CreateConst(graph, node, std::vector<float>{ln10}, {1}, GetVarDType(x))
->outputs.front();
auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_div", {log, ln10_tensor},
node->outputs);
}
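// log1p(x) = log(1 + x).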
Node *log1p_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto one =
CreateConst(graph, node, std::vector<float>{1.0}, {1}, GetVarDType(x))
->outputs.front();
auto add =
CreateBaseOp(graph, node, "popart_add", {x, one}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_log", {add}, node->outputs);
}
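// log2(x) = log(x) / ln(2).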
Node *log2_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
float ln2 = 0.693147180559945309;
auto ln2_tensor =
CreateConst(graph, node, std::vector<float>{ln2}, {1}, GetVarDType(x))
->outputs.front();
auto log = CreateBaseOp(graph, node, "popart_log", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_div", {log, ln2_tensor},
node->outputs);
}
Node *logsigmoid_handler(Graph *graph, Node *node) {
auto sigmoid = CreateBaseOp(graph, node, "popart_sigmoid",
{GetInputVarNode("X", node)}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_log", {sigmoid}, node->outputs);
}
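// mish(x) = x * tanh(softplus(x)); only the default threshold (20.0) is
// supported.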
Node *mish_handler(Graph *graph, Node *node) {
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
if (!is_float_equal(threshold_, 20.0f)) {
PADDLE_THROW(platform::errors::Unimplemented(
"For mish op, only support threshold = 20.0"));
}
auto x = GetInputVarNode("X", node);
auto softplus =
CreateBaseOp(graph, node, "popart_softplus", {x}, {})->outputs.front();
auto tanh =
CreateBaseOp(graph, node, "popart_tanh", {softplus}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_mul", {x, tanh}, node->outputs);
}
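// If alpha and the input have different ranks, alpha (which must then have
// rank <= 1) is reshaped to (-1, 1, ..., 1) so it broadcasts along the
// channel dimension before popart_prelu.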
Node *prelu_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto alpha = GetInputVarNode("Alpha", node);
auto out = GetOutputVarNode("Out", node);
auto x_rank = x->Var()->GetShape().size();
auto alpha_rank = alpha->Var()->GetShape().size();
if (x_rank != alpha_rank) {
if (alpha_rank > 1) {
PADDLE_THROW(platform::errors::Unimplemented(
"For prelu op, only alpha with rank <= 1 is supported when "
"Rank(alpha) != Rank(input)."));
}
if (x_rank <= 1) {
PADDLE_THROW(platform::errors::Unimplemented(
"For prelu op, rank of input must be greater than 1 when "
"Rank(alpha) != Rank(input)."));
}
auto shape = std::vector<int64_t>(x_rank - 1, 1);
shape[0] = -1;
int64_t size = shape.size();
auto dim = std::vector<int64_t>{size};
auto reshape_const =
CreateConst(graph, node, shape, dim, ONNXDataType::INT64)
->outputs.front();
alpha =
CreateBaseOp(graph, node, "popart_reshape", {alpha, reshape_const}, {})
->outputs.front();
}
return CreateBaseOp(graph, node, "popart_prelu", {x, alpha}, {out});
}
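// relu6(x) = clip(x, 0, threshold).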
Node *relu6_handler(Graph *graph, Node *node) {
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
auto clip_min = CreateConst(graph, node, std::vector<float>{0.0}, {1},
ONNXDataType::FLOAT)
->outputs.front();
auto clip_max = CreateConst(graph, node, std::vector<float>{threshold_}, {1},
ONNXDataType::FLOAT)
->outputs.front();
return CreateBaseOp(graph, node, "popart_clip",
{GetInputVarNode("X", node), clip_min, clip_max},
node->outputs);
}
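// rsqrt(x) = 1 / sqrt(x), composed from popart_sqrt and popart_reciprocal.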
Node *rsqrt_handler(Graph *graph, Node *node) {
auto sqrt =
CreateBaseOp(graph, node, "popart_sqrt", {GetInputVarNode("X", node)}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_reciprocal", {sqrt}, node->outputs);
}
Node *selu_handler(Graph *graph, Node *node) {
auto alpha_ = BOOST_GET_CONST(float, node->Op()->GetAttr("alpha"));
auto scale_ = BOOST_GET_CONST(float, node->Op()->GetAttr("scale"));
return CreateBaseOp(graph, node, "popart_selu", node->inputs, node->outputs,
{
{"alpha", alpha_},
{"gamma", scale_},
});
}
Node *silu_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto sigmoid =
CreateBaseOp(graph, node, "popart_sigmoid", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid}, node->outputs);
}
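// softshrink maps onto popart_shrink with lambd = bias = lambda.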
Node *softshrink_handler(Graph *graph, Node *node) {
auto lambda_ = BOOST_GET_CONST(float, node->Op()->GetAttr("lambda"));
return CreateBaseOp(graph, node, "popart_shrink", node->inputs, node->outputs,
{
{"lambd", lambda_},
{"bias", lambda_},
});
}
Node *square_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
return CreateBaseOp(graph, node, "popart_mul", {x, x}, node->outputs);
}
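// swish(x) = x * sigmoid(beta * x).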
Node *swish_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto out = GetOutputVarNode("Out", node);
auto beta_ = BOOST_GET_CONST(float, node->Op()->GetAttr("beta"));
auto beta_node =
CreateConst(graph, node, std::vector<float>{beta_}, {1}, GetVarDType(x))
->outputs.front();
auto beta_x_node = CreateBaseOp(graph, node, "popart_mul", {x, beta_node}, {})
->outputs.front();
auto sigmoid_node =
CreateBaseOp(graph, node, "popart_sigmoid", {beta_x_node}, {})
->outputs.front();
return CreateBaseOp(graph, node, "popart_mul", {x, sigmoid_node}, {out});
}
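// tanh_shrink(x) = x - tanh(x).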
Node *tanh_shrink_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto tanh =
CreateBaseOp(graph, node, "popart_tanh", {x}, {})->outputs.front();
return CreateBaseOp(graph, node, "popart_sub", {x, tanh}, node->outputs);
}
Node *thresholded_relu_handler(Graph *graph, Node *node) {
auto threshold_ = BOOST_GET_CONST(float, node->Op()->GetAttr("threshold"));
auto x = GetInputVarNode("X", node);
return CreateBaseOp(graph, node, "popart_thresholdedrelu", {x}, node->outputs,
{
{"alpha", threshold_},
});
}
} // namespace
} // namespace ipu
} // namespace platform
......@@ -188,5 +442,26 @@ REGISTER_HANDLER(softsign, softsign_handler);
REGISTER_HANDLER(sqrt, sqrt_handler);
REGISTER_HANDLER(tan, tan_handler);
REGISTER_HANDLER(tanh, tanh_handler);
REGISTER_HANDLER(brelu, brelu_handler);
REGISTER_HANDLER(gelu, gelu_handler);
REGISTER_HANDLER(log_softmax, log_softmax_handler);
REGISTER_HANDLER(elu, elu_handler);
REGISTER_HANDLER(hard_shrink, hard_shrink_handler);
REGISTER_HANDLER(hard_sigmoid, hard_sigmoid_handler);
REGISTER_HANDLER(hard_swish, hard_swish_handler);
REGISTER_HANDLER(leaky_relu, leaky_relu_handler);
REGISTER_HANDLER(log10, log10_handler);
REGISTER_HANDLER(log1p, log1p_handler);
REGISTER_HANDLER(log2, log2_handler);
REGISTER_HANDLER(logsigmoid, logsigmoid_handler);
REGISTER_HANDLER(mish, mish_handler);
REGISTER_HANDLER(prelu, prelu_handler);
REGISTER_HANDLER(relu6, relu6_handler);
REGISTER_HANDLER(rsqrt, rsqrt_handler);
REGISTER_HANDLER(selu, selu_handler);
REGISTER_HANDLER(silu, silu_handler);
REGISTER_HANDLER(softshrink, softshrink_handler);
REGISTER_HANDLER(square, square_handler);
REGISTER_HANDLER(swish, swish_handler);
REGISTER_HANDLER(tanh_shrink, tanh_shrink_handler);
REGISTER_HANDLER(thresholded_relu, thresholded_relu_handler);
......@@ -117,15 +117,20 @@ const bool is_float_equal(float a, float b, float eps) {
return std::fabs(a - b) <= eps;
}
const int GetOutputVarDType(const Node *node, const std::string &output_name) {
auto out_node = GetOutputVarNode(output_name, node);
PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
"Node's out node does not exist."));
auto var = out_node->Var();
const ONNXDataType GetVarDType(const Node *node) {
auto var = node->Var();
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::Unavailable("Node is not a variable."));
auto proto_var_type = var->GetDataType();
return static_cast<int>(VarType2OnnxDType(proto_var_type));
return VarType2OnnxDType(proto_var_type);
}
const ONNXDataType GetOutputVarDType(const Node *node,
const std::string &output_name) {
auto out_node = GetOutputVarNode(output_name, node);
PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable(
"Node's out node does not exist."));
return GetVarDType(out_node);
}
} // namespace ipu
......
......@@ -78,8 +78,9 @@ Node *GetOutputVarNodeByVarName(const std::string &var_name,
const Node *op_node);
const bool is_float_equal(float a, float b, float eps = 1e-8);
const int GetOutputVarDType(const Node *node,
const std::string &output_name = "Out");
const ONNXDataType GetVarDType(const Node *node);
const ONNXDataType GetOutputVarDType(const Node *node,
const std::string &output_name = "Out");
} // namespace ipu
} // namespace platform
......
......@@ -40,10 +40,10 @@ Node *pow_handler(Graph *graph, Node *node) {
} else {
// Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow)
auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor"));
auto attrs =
MakeConstAttrMapFromValue<float>(value_, {1}, GetOutputVarDType(node));
auto new_node_const =
CreateConst(graph, node, std::vector<decltype(value_)>{value_}, {1},
GetOutputVarDType(node));
auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
return CreateBaseOp(
graph, node, "popart_pow",
{GetInputVarNode("X", node), new_node_const->outputs[0]},
......@@ -135,8 +135,9 @@ Node *matmul_handler(Graph *graph, Node *node) {
} else {
auto o_node =
CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {});
auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDType(node));
auto const_node = CreateConst(graph, node, {}, {}, attr);
auto const_node =
CreateConst(graph, node, std::vector<decltype(alpha)>{alpha}, {1},
GetOutputVarDType(node));
return CreateBaseOp(graph, node, "popart_mul",
{o_node->outputs[0], const_node->outputs[0]},
node->outputs);
......@@ -163,8 +164,8 @@ Node *scale_handler(Graph *graph, Node *node) {
BOOST_GET_CONST(bool, op->GetAttr("bias_after_scale"));
auto data_type_ = GetInputVarNode("X", node)->Var()->GetDataType();
auto cast = CreateCast(graph, node, {GetInputVarNode("X", node)}, {},
static_cast<int>(framework::proto::VarType::FP32));
auto cast =
CreateCast(graph, node, {GetInputVarNode("X", node)}, {}, VarType::FP32);
Node *result = nullptr;
if (!op->Input("ScaleTensor").empty()) {
......@@ -232,8 +233,7 @@ Node *scale_handler(Graph *graph, Node *node) {
}
}
auto result_after_cast =
CreateCast(graph, node, result->outputs, node->outputs,
static_cast<int>(data_type_));
CreateCast(graph, node, result->outputs, node->outputs, data_type_);
return result_after_cast;
}
......@@ -241,12 +241,11 @@ Node *cross_entropy2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
Node *new_cast = nullptr;
if (GetInputVarNode("Label", node)->Var()->GetDataType() ==
framework::proto::VarType::INT32) {
if (GetInputVarNode("Label", node)->Var()->GetDataType() == VarType::INT32) {
new_cast = GetInputVarNode("Label", node);
} else {
auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)},
{}, framework::proto::VarType::INT32);
{}, VarType::INT32);
new_cast = new_cast->outputs[0];
}
auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
......@@ -310,12 +309,11 @@ Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) {
"soft_label is not supported yet in IPU"));
}
Node *new_cast = nullptr;
if (GetInputVarNode("Label", node)->Var()->GetDataType() ==
framework::proto::VarType::INT32) {
if (GetInputVarNode("Label", node)->Var()->GetDataType() == VarType::INT32) {
new_cast = GetInputVarNode("Label", node);
} else {
auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)},
{}, framework::proto::VarType::INT32);
{}, VarType::INT32);
new_cast = new_cast->outputs[0];
}
auto softmax_node = CreateSoftmaxOpset11(
......@@ -432,6 +430,12 @@ Node *matmul_v2_handler(Graph *graph, Node *node) {
node->outputs);
}
Node *bmm_handler(Graph *graph, Node *node) {
return CreateBaseOp(graph, node, "popart_matmul",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
node->outputs);
}
Node *arg_max_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = BOOST_GET_CONST(int64_t, op->GetAttr("axis"));
......@@ -441,6 +445,15 @@ Node *arg_max_handler(Graph *graph, Node *node) {
{{"axis", axis}, {"keepdims", int64_t{0}}});
}
Node *arg_min_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto axis = BOOST_GET_CONST(int64_t, op->GetAttr("axis"));
return CreateBaseOp(graph, node, "popart_argmin",
{GetInputVarNode("X", node)},
{GetOutputVarNode("Out", node)},
{{"axis", axis}, {"keepdims", int64_t{0}}});
}
} // namespace
} // namespace ipu
} // namespace platform
......@@ -458,4 +471,6 @@ REGISTER_HANDLER(softmax_with_cross_entropy,
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(cumsum, cumsum_handler);
REGISTER_HANDLER(matmul_v2, matmul_v2_handler);
REGISTER_HANDLER(bmm, bmm_handler);
REGISTER_HANDLER(arg_max, arg_max_handler);
REGISTER_HANDLER(arg_min, arg_min_handler);
......@@ -123,8 +123,9 @@ Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
}
Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs, const int otype) {
auto to = VarType2PopartStr(static_cast<VarType::Type>(otype));
const std::vector<Node *> &outputs,
const VarType::Type otype) {
auto to = VarType2PopartStr(otype);
return CreateBaseOp(graph, node, "popart_cast", inputs, outputs,
{{"to", to}});
}
......
......@@ -24,22 +24,6 @@ namespace paddle {
namespace platform {
namespace ipu {
template <typename T>
AttributeMap MakeConstAttrMap(std::vector<T> value, std::vector<int64_t> dims,
int dtype) {
return AttributeMap{{"value", value}, {"dims", dims}, {"dtype", dtype}};
}
template <typename T>
AttributeMap MakeConstAttrMapFromValue(T v, std::vector<int64_t> dims,
int dtype) {
size_t size = 1;
for (auto &dim : dims) {
size *= dim;
}
return MakeConstAttrMap<T>(std::vector<T>(size, v), dims, dtype);
}
const std::string GenerateVarName();
const std::string CreateOpIdentifyId(Node *node);
......@@ -57,9 +41,16 @@ Node *CreateConst(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
const AttributeMap &attrs);
// Create a constant node directly from a value vector, dims and ONNX dtype.
template <typename T>
Node *CreateConst(Graph *graph, Node *node, const std::vector<T> &value,
const std::vector<int64_t> &dims, ONNXDataType dtype) {
return CreateConst(
graph, node, {}, {},
AttributeMap{{"value", value}, {"dims", dims}, {"dtype", dtype}});
}
Node *CreateCast(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs, const int otype);
const std::vector<Node *> &outputs, const VarType::Type otype);
Node *CreateGemm(Graph *graph, Node *node, const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs, int64_t transA = 0,
......
......@@ -83,7 +83,30 @@ Node *topk_handler(Graph *graph, Node *node) {
{"largest", int64_t{largest}},
{"sorted", int64_t{sorted}}});
return CreateCast(graph, node, {var_i}, {GetOutputVarNode("Indices", node)},
static_cast<int>(framework::proto::VarType::INT32));
VarType::INT32);
}
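// argsort is lowered to popart_topk with k equal to the full size of the
// sorted axis; the largest flag is derived from the descending attribute.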
Node *argsort_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
auto axis_ = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto descending_ = BOOST_GET_CONST(bool, op->GetAttr("descending"));
if (axis_ < 0) {
axis_ = axis_ + x_shape.size();
}
auto *dim_size =
CreateConst(graph, node, std::vector<int64_t>{x_shape[axis_]}, {1},
ONNXDataType::INT64)
->outputs.front();
int64_t largest = descending_ ? 1 : 0;
return CreateBaseOp(
graph, node, "popart_topk", {GetInputVarNode("X", node), dim_size},
{GetOutputVarNode("Out", node), GetOutputVarNode("Indices", node)},
{
{"axis", int64_t{axis_}},
{"largest", int64_t{largest}},
{"sorted", int64_t{0}},
});
}
} // namespace
......@@ -93,3 +116,4 @@ Node *topk_handler(Graph *graph, Node *node) {
REGISTER_HANDLER(top_k, topk_handler);
REGISTER_HANDLER(top_k_v2, topk_handler);
REGISTER_HANDLER(argsort, argsort_handler);
......@@ -179,7 +179,8 @@ Node *squeeze_handler(Graph *graph, Node *node) {
Node *cast_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto otype = BOOST_GET_CONST(int, op->GetAttr("out_dtype"));
auto new_node = CreateCast(graph, node, node->inputs, node->outputs, otype);
auto new_node = CreateCast(graph, node, node->inputs, node->outputs,
static_cast<VarType::Type>(otype));
// Cast op created in mixed-precision has no pipeline attrs
auto &prev_nodes = node->inputs.front()->inputs;
if (!prev_nodes.empty()) {
......@@ -356,8 +357,8 @@ Node *slice_handler(Graph *graph, Node *node) {
} else {
auto starts_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("starts"));
auto dim = int64_t(starts_.size());
auto attr = MakeConstAttrMap<int>(starts_, {dim}, ONNXDataType::INT32);
starts = CreateConst(graph, node, {}, {}, attr);
starts = CreateConst(graph, node, std::vector<int>{starts_}, {dim},
ONNXDataType::INT32);
starts = starts->outputs[0];
}
Node *ends = nullptr;
......@@ -366,16 +367,16 @@ Node *slice_handler(Graph *graph, Node *node) {
} else {
auto ends_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("ends"));
auto dim = int64_t(ends_.size());
auto attr = MakeConstAttrMap<int>(ends_, {dim}, ONNXDataType::INT32);
ends = CreateConst(graph, node, {}, {}, attr);
ends = CreateConst(graph, node, std::vector<int>{ends_}, {dim},
ONNXDataType::INT32);
ends = ends->outputs[0];
}
Node *axes = nullptr;
{
auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axes"));
auto dim = int64_t(axes_.size());
auto attr = MakeConstAttrMap<int>(axes_, {dim}, ONNXDataType::INT32);
axes = CreateConst(graph, node, {}, {}, attr);
axes = CreateConst(graph, node, std::vector<int>{axes_}, {dim},
ONNXDataType::INT32);
}
auto decrease_axis_ =
......@@ -424,9 +425,8 @@ Node *expand_handler(Graph *graph, Node *node) {
auto expand_times_ =
std::vector<int64_t>{expand_times_i32.begin(), expand_times_i32.end()};
auto dim = int64_t(expand_times_.size());
auto attr =
MakeConstAttrMap<int64_t>(expand_times_, {dim}, ONNXDataType::INT64);
expand_times = CreateConst(graph, node, {}, {}, attr);
expand_times = CreateConst(graph, node, std::vector<int64_t>{expand_times_},
{dim}, ONNXDataType::INT64);
}
auto new_node = CreateBaseOp(
graph, node, "popart_tile",
......@@ -593,6 +593,19 @@ Node *split_handler(Graph *graph, Node *node) {
{"split", std::vector<int64_t>{sections.begin(), sections.end()}}});
}
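// dot(x, y) = reduce_sum(x * y) over the last axis.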
Node *dot_handler(Graph *graph, Node *node) {
auto x = GetInputVarNode("X", node);
auto mul_node = CreateBaseOp(graph, node, "popart_mul",
{x, GetInputVarNode("Y", node)}, {})
->outputs.front();
int64_t axes = x->Var()->GetShape().size() - 1;
return CreateBaseOp(graph, node, "popart_reducesum", {mul_node},
{GetOutputVarNode("Out", node)},
{
{"axes", std::vector<int64_t>{axes}},
});
}
} // namespace
} // namespace ipu
} // namespace platform
......@@ -621,3 +634,4 @@ REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler);
REGISTER_HANDLER(split, split_handler);
REGISTER_HANDLER(one_hot, one_hot_handler);
REGISTER_HANDLER(one_hot_v2, one_hot_v2_handler);
REGISTER_HANDLER(dot, dot_handler);
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_test_op()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
def set_test_op(self):
self.op = F.elu
self.op_attrs = {}
def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 10, 10])
self.feed_fp32 = {'in_0': data.astype(np.float32)}
self.feed_fp16 = {'in_0': data.astype(np.float16)}
self.feed_list = list(self.feed_fp32.keys())
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
self.feed_dtype = [x.dtype for x in self.feed_fp32.values()]
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
out = self.op(x, **self.op_attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model()
self.run_model(m)
self.check()
class TestBReluCase0(TestBase):
def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 10, 10]) * 30
self.feed_fp32 = {'in_0': data.astype(np.float32)}
self.feed_fp16 = {'in_0': data.astype(np.float16)}
self.feed_list = list(self.feed_fp32.keys())
def set_test_op(self):
self.op = paddle.fluid.layers.brelu
self.op_attrs = {}
class TestBReluCase1(TestBReluCase0):
def set_test_op(self):
self.op = paddle.fluid.layers.brelu
self.op_attrs = {"t_min": 0.1, 't_max': 10.0}
class TestEluCase1(TestBase):
def set_test_op(self):
self.op = F.elu
self.op_attrs = {"alpha": 0.3}
class TestHardShrinkCase0(TestBase):
def set_test_op(self):
self.op = F.hardshrink
self.op_attrs = {}
class TestHardSigmoidCase0(TestBase):
def set_test_op(self):
self.op = F.hardsigmoid
self.op_attrs = {}
class TestHardSigmoidCase1(TestBase):
def set_test_op(self):
self.op = F.hardsigmoid
self.op_attrs = {
'slope': 0.2,
'offset': 0.33,
}
class TestHardSwishCase0(TestBase):
def set_test_op(self):
self.op = F.hardswish
self.op_attrs = {}
class TestLeakyReluCase0(TestBase):
def set_test_op(self):
self.op = F.leaky_relu
self.op_attrs = {}
class TestLeakyReluCase1(TestBase):
def set_test_op(self):
self.op = F.leaky_relu
self.op_attrs = {'negative_slope': 0.2333}
class TestLog10Case0(TestBase):
def set_test_op(self):
self.op = paddle.log10
self.op_attrs = {}
class TestLog1pCase0(TestBase):
def set_test_op(self):
self.op = paddle.log1p
self.op_attrs = {}
class TestLog2Case0(TestBase):
def set_test_op(self):
self.op = paddle.log2
self.op_attrs = {}
class TestLogSigmoidCase0(TestBase):
def set_test_op(self):
self.op = F.log_sigmoid
self.op_attrs = {}
class TestLogSoftmaxCase0(TestBase):
def set_test_op(self):
self.op = F.log_softmax
self.op_attrs = {}
class TestMishCase0(TestBase):
def set_test_op(self):
self.op = F.mish
self.op_attrs = {}
class TestRelu6Case0(TestBase):
def set_test_op(self):
self.op = F.relu6
self.op_attrs = {}
class TestRsqrtCase0(TestBase):
def set_test_op(self):
self.op = paddle.rsqrt
self.op_attrs = {}
class TestSeluCase0(TestBase):
def set_test_op(self):
self.op = F.selu
self.op_attrs = {}
class TestSiluCase0(TestBase):
def set_test_op(self):
self.op = F.silu
self.op_attrs = {}
class TestSoftShrinkCase0(TestBase):
def set_test_op(self):
self.op = F.softshrink
self.op_attrs = {}
class TestSoftShrinkCase1(TestBase):
def set_test_op(self):
self.op = F.softshrink
self.op_attrs = {'threshold': 0.2333}
class TestSquareCase0(TestBase):
def set_test_op(self):
self.op = paddle.square
self.op_attrs = {}
class TestSwishCase0(TestBase):
def set_test_op(self):
self.op = F.swish
self.op_attrs = {}
class TestTanhShrinkCase0(TestBase):
def set_atol(self):
super().set_atol()
self.atol = 1e-7
def set_test_op(self):
self.op = F.tanhshrink
self.op_attrs = {}
class TestThresholdedReluCase0(TestBase):
def set_test_op(self):
self.op = F.thresholded_relu
self.op_attrs = {}
class TestThresholdedReluCase1(TestBase):
def set_test_op(self):
self.op = F.thresholded_relu
self.op_attrs = {'threshold': 0.2333}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -16,7 +16,6 @@ import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
......@@ -33,10 +32,9 @@ class TestBase(IPUOpTest):
self.set_op_attrs()
def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 10, 10])
self.feed_fp32 = {'in_0': data.astype(np.float32)}
self.feed_fp16 = {'in_0': data.astype(np.float16)}
self.feed_list = list(self.feed_fp32.keys())
data = np.random.uniform(size=[10, 500]).astype(np.float16)
self.feed_fp32 = {"in_0": data.astype(np.float32)}
self.feed_fp16 = {"in_0": data.astype(np.float16)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
......@@ -51,7 +49,7 @@ class TestBase(IPUOpTest):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
out = F.log_softmax(x, **self.attrs)
out = paddle.fluid.layers.argmin(x, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
......@@ -62,13 +60,15 @@ class TestBase(IPUOpTest):
if not self.skip_mode(m):
self.build_model()
self.run_model(m)
for k, v in self.output_dict.items():
self.output_dict[k] = v.astype(np.int32)
self.check()
class TestCase1(TestBase):
def set_attrs(self):
self.attrs = {"axis": 1}
def set_op_attrs(self):
self.attrs = {"axis": 0}
if __name__ == "__main__":
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
def set_data_feed(self):
data = np.random.uniform(size=[1, 2, 3, 3]).astype(np.float16)
self.feed_fp32 = {"in_0": data.astype(np.float32)}
self.feed_fp16 = {"in_0": data.astype(np.float16)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
self.feed_dtype = [x.dtype for x in self.feed_fp32.values()]
def set_op_attrs(self):
self.attrs = {
'axis': -1,
'descending': False,
}
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
out, _ = paddle.fluid.layers.argsort(x, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model()
self.run_model(m)
for k, v in self.output_dict.items():
self.output_dict[k] = v.astype(np.int32)
self.check()
class TestCase1(TestBase):
def set_op_attrs(self):
self.attrs = {
'axis': 0,
'descending': False,
}
class TestCase2(TestBase):
def set_op_attrs(self):
self.attrs = {
'axis': 1,
'descending': True,
}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
def set_data_feed(self):
x = np.random.uniform(size=[4, 2, 3])
y = np.random.uniform(size=[4, 3, 2])
self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_op_attrs(self):
self.attrs = {}
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
y = paddle.static.data(name=self.feed_list[1],
shape=self.feed_shape[1],
dtype='float32')
out = paddle.bmm(x, y, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model()
self.run_model(m)
self.check()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
def set_data_feed(self):
x = np.random.uniform(size=[4, 6])
y = np.random.uniform(size=[4, 6])
self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_op_attrs(self):
self.attrs = {}
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
y = paddle.static.data(name=self.feed_list[1],
shape=self.feed_shape[1],
dtype='float32')
out = paddle.dot(x, y, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model()
self.run_model(m)
self.check()
class TestCase1(TestBase):
def set_data_feed(self):
x = np.random.uniform(size=[6])
y = np.random.uniform(size=[6])
self.feed_fp32 = {"x": x.astype(np.float32), "y": y.astype(np.float32)}
self.feed_fp16 = {"x": x.astype(np.float16), "y": y.astype(np.float16)}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
def set_data_feed(self):
data = np.random.uniform(size=[1, 3, 10, 10])
self.feed_fp32 = {'x': data.astype(np.float32)}
self.feed_fp16 = {'x': data.astype(np.float16)}
self.feed_list = list(self.feed_fp32.keys())
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
self.feed_dtype = [x.dtype for x in self.feed_fp32.values()]
def set_op_attrs(self):
self.attrs = {}
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
array = np.random.uniform(size=[1]).astype(np.float32)
result1 = paddle.zeros(shape=[1], dtype='float32')
weight = paddle.assign(array, result1)
out = F.prelu(x, weight=weight, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
ipu_strategy = paddle.static.IpuStrategy()
ipu_strategy.set_graph_config(is_training=self.is_training)
ipu_strategy.set_options({'onnx_dump_path': 'onnx_dump_path.onnx'})
self.run_op_test(exec_mode, ipu_strategy=ipu_strategy)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model()
self.run_model(m)
self.check()
class TestCase1(TestBase):
@IPUOpTest.static_graph
def build_model(self):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype='float32')
array = np.random.uniform(size=[3]).astype(np.float32)
result1 = paddle.zeros(shape=[3], dtype='float32')
weight = paddle.assign(array, result1)
out = F.prelu(x, weight=weight, **self.attrs)
self.fetch_list = [out.name]
if __name__ == "__main__":
unittest.main()