未验证 提交 60531231 编写于 作者: G G. Ramalingam 提交者: GitHub

Simplify function definition of context-dependent functions (#3882)

* Simplify function definition of context-dependent functions
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Add missing parenthesis
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Fix errors in function defs
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Eliminate unused variable
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>

* Add int64 type specifier to literal
Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
Co-authored-by: Ashwini Khade <askhade@microsoft.com>
上级 0a894222
......@@ -14,6 +14,7 @@
#include "onnx/common/status.h"
#include "onnx/defs/schema.h"
#include "tensor_proto_util.h"
#include "onnx/defs/parser.h"
namespace ONNX_NAMESPACE {
// Helper function to expand a function node given the function proto
......@@ -106,4 +107,83 @@ class FunctionBodyHelper {
}
};
class FunctionBuilder {
public:
FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
FunctionBuilder& Add(const char* nodes_txt) {
OnnxParser parser(nodes_txt);
auto& nodes = *funProto.mutable_node();
while (!parser.EndOfInput()) {
auto status = parser.Parse(*nodes.Add());
if (!status.IsOK())
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
return *this;
}
FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
OnnxParser parser(node_txt);
auto& node = *funProto.add_node();
auto status = parser.Parse(node);
if (!status.IsOK()) {
ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
}
if (!parser.EndOfInput()) {
ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
}
*node.add_attribute() = attr;
return *this;
}
template <typename T>
FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, T attr_value) {
return Add(node_txt, MakeAttribute(attr_name, attr_value));
}
// Creates a scalar constant (a tensor of rank zero).
template <typename T>
FunctionBuilder& Const(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
return Add (constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
}
// Creates a 1D tensor constant consisting of a single value.
template <typename T>
FunctionBuilder& Const1D(const std::string& name, T const_value) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(const_value);
tensor.add_dims(1);
return Add (constant_op.c_str(), MakeAttribute("value", tensor));
}
// Creates a 1D tensor constant consisting of zero or more values.
template <typename T>
FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
std::string constant_op(name);
constant_op += " = Constant()";
auto tensor = ToTensor(values);
tensor.add_dims(values.size()); // Treat as 1D tensor.
return Add (constant_op.c_str(), MakeAttribute("value", tensor));
}
FunctionBuilder& AddOpset(const char* domain, int version) {
auto* opset = funProto.add_opset_import();
opset->set_domain(domain);
opset->set_version(version);
return *this;
}
private:
FunctionProto& funProto;
};
} // namespace ONNX_NAMESPACE
......@@ -892,36 +892,11 @@ ONNX_OPERATOR_SET_SCHEMA(
auto dtype = ctx.getAttribute("dtype") != nullptr
? static_cast<TensorProto_DataType>(ctx.getAttribute("dtype")->i())
: input_type;
auto seed_attr = ctx.getAttribute("seed");
std::vector<FunctionBodyHelper::NodeDef> body{
// nodes: {outputs, op, inputs, attributes}
// clang-format off
{
{"X_greater"},
"Greater",
{"X_random", "input"}
},
{
{"output"},
"Cast",
{"X_greater"},
{MakeAttribute("to", (int64_t)(dtype))}
}
// clang-format on
};
if (seed_attr != nullptr) {
float seed = seed_attr->f();
body.insert(body.begin(), {{"X_random"}, "RandomUniformLike", {"input"}, {MakeAttribute("high", 1.0f), MakeAttribute("low", 0.f), MakeAttribute("seed", (float)(seed)), MakeAttribute("dtype", (int64_t)(input_type))}});
} else {
body.insert(body.begin(), {{"X_random"}, "RandomUniformLike", {"input"}, {MakeAttribute("high", 1.0f), MakeAttribute("low", 0.f), MakeAttribute("dtype", (int64_t)(input_type))}});
}
auto func_nodes = FunctionBodyHelper::BuildNodes(body);
for (const auto& node : func_nodes) {
auto new_node = functionProto.add_node();
new_node->CopyFrom(node);
}
FunctionBuilder builder(functionProto);
builder
.Add("X_random = RandomUniformLike <low = 0.0, high = 1.0, seed = @seed> (input)", "dtype", int64_t(input_type))
.Add("X_greater = Greater (X_random, input)")
.Add("output = Cast (X_greater)", "to", int64_t(dtype));
schema.BuildFunction(functionProto);
return true;
}));
......
此差异已折叠。
......@@ -196,11 +196,10 @@ ONNX_OPERATOR_SET_SCHEMA(
return false;
}
auto target_elt_type = target_type->tensor_type().elem_type();
std::vector<FunctionBodyHelper::NodeDef> body{
// nodes: {outputs, op, inputs, attributes}
{ {"output"}, "Cast", {"input"}, {MakeAttribute("to", (int64_t)(target_elt_type))} }
};
return FunctionBodyHelper::BuildFunctionProto(functionProto, schema, body, {});
FunctionBuilder builder(functionProto);
builder.Add("output = Cast (input)", "to", (int64_t)(target_elt_type));
schema.BuildFunction(functionProto);
return true;
}));
static const char* Reshape_ver14_doc = R"DOC(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册