提交 89a8b192 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Fuse sgnn projection op

PiperOrigin-RevId: 328181545
Change-Id: I4e5d52e775a060fe686ee623986b423787171e81
上级 f55a491b
......@@ -3435,4 +3435,19 @@ func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_
}
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
\ No newline at end of file
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
%1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
%2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
%3 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64>
%4 = "tf.Reshape"(%2, %3) : (tensor<10x?xf64>, tensor<1xi64>) -> tensor<?x10xf64>
return %4 : tensor<?x10xf64>
}
// CHECK: func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor<?x!tf.string>, tensor<?xi64>) -> tensor<?x10xf64>
// CHECK: return %0 : tensor<?x10xf64>
......@@ -47,6 +47,7 @@ namespace {
constexpr char kNgrams[] = "tftext:Ngrams";
constexpr char kWhitespaceTokenizer[] = "tftext:WhitespaceTokenizer";
constexpr char kCustomSgnnProjection[] = "tftext:custom:SgnnProjection";
constexpr char kTFImplements[] = "tf._implements";
using mlir::TF::FuncAttr;
......@@ -269,6 +270,85 @@ LogicalResult ConvertNgrams(FuncOp func, llvm::StringRef api, FuncAttr attr) {
return success();
}
LogicalResult VerifySgnnProjection(FuncOp func, FuncAttr attr) {
if (func.getType().getNumInputs() != 2 ||
func.getType().getNumResults() != 1) {
return func.emitError() << "Mismatched number of inputs and outputs.";
}
auto values_type = GetInputType(func, 0);
if (!values_type || !values_type.getElementType().isa<StringType>()) {
return func.emitError() << "First input should be a string tensor";
}
auto row_splits_type = GetInputType(func, 1);
if (!row_splits_type ||
!row_splits_type.getElementType().isa<IntegerType>()) {
return func.emitError() << "Second input should be an integer tensor";
}
auto hash_seed =
attr.GetAttrs().get("hash_seed").dyn_cast_or_null<ArrayAttr>();
if (!hash_seed) {
return func.emitError()
<< "'hash_seed' attribute is not set or not an array";
}
auto output_type = GetResultType(func, 0);
if (!output_type || !output_type.getElementType().isa<FloatType>() ||
!RankEquals(output_type, 2)) {
return func.emitError() << "Output should be a 2D float tensor.";
}
if (output_type.getDimSize(1) != hash_seed.size()) {
return func.emitError()
<< "Output 2nd dimension should be the num of hash seeds.";
}
auto buckets = attr.GetAttrs().get("buckets").dyn_cast_or_null<IntegerAttr>();
if (!buckets) {
return func.emitError() << "'buckets' attribute is not set or not int";
}
return success();
}
LogicalResult CreateSgnnProjectionCustomOption(
FuncOp func, DictionaryAttr attrs, std::string& custom_option_buffer) {
flexbuffers::Builder fbb;
size_t start_map = fbb.StartMap();
auto hash_seed = attrs.get("hash_seed").dyn_cast_or_null<ArrayAttr>();
auto vector_start = fbb.StartVector("hash_seed");
for (int i = 0; i < hash_seed.size(); i++) {
fbb.Add(static_cast<int32_t>(
(hash_seed.getValue().data() + i)->dyn_cast<IntegerAttr>().getInt()));
}
fbb.EndVector(vector_start, /*typed=*/true, /*fixed=*/false);
auto buckets = attrs.get("buckets").dyn_cast_or_null<IntegerAttr>();
fbb.Int("buckets", buckets.getInt());
fbb.EndMap(start_map);
fbb.Finish();
custom_option_buffer.assign(fbb.GetBuffer().begin(), fbb.GetBuffer().end());
return success();
}
LogicalResult ConvertSgnnProjection(FuncOp func, llvm::StringRef api,
FuncAttr attr) {
// See more details in tensorflow_models/sequence_projection/sgnn/sgnn.py
func.eraseBody();
func.addEntryBlock();
func.setAttr(kTFImplements, attr);
OpBuilder builder(func.getBody());
std::string custom_option_buffer;
if (failed(CreateSgnnProjectionCustomOption(func, attr.GetAttrs(),
custom_option_buffer))) {
return failure();
}
auto op = builder.create<CustomOp>(
func.getLoc(), func.getType().getResults(), func.getArguments(), api,
CustomOption(&builder, custom_option_buffer));
builder.create<ReturnOp>(func.getLoc(), op.getResults());
return success();
}
} // namespace
LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
......@@ -281,6 +361,10 @@ LogicalResult ConvertTFTextAPI(FuncOp func, llvm::StringRef api,
if (succeeded(VerifyNgrams(func))) {
return ConvertNgrams(func, api, attr);
}
} else if (api.str() == kCustomSgnnProjection) {
if (succeeded(VerifySgnnProjection(func, attr))) {
return ConvertSgnnProjection(func, api, attr);
}
}
return failure();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册