From 3e2c10870815b3bb4106734bbe1344ac3115be8e Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 18 Dec 2019 17:31:10 -0800 Subject: [PATCH] Avoid using constant folding fallback hook concurrently Ran into case where failed constant folds resulted in non-deterministic test behavior (e.g., failed to fold unrelated op). Need to dig into this a bit more. PiperOrigin-RevId: 286299132 Change-Id: I290dade5fba0fd5b08dae4d696878b9196fc516d --- tensorflow/compiler/mlir/tensorflow/BUILD | 1 + .../compiler/mlir/tensorflow/transforms/constant_fold.cc | 5 +++++ 2 files changed, 6 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index cef7d1961c4..288a63ecc3d 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -625,6 +625,7 @@ cc_library( "//tensorflow/c:tf_status", "//tensorflow/c/eager:c_api", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/stream_executor", "//tensorflow/stream_executor/lib", "@llvm//:support", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 0ef0072390d..11eafdede08 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" +#include "tensorflow/core/platform/mutex.h" namespace mlir { namespace TF { @@ -59,6 +60,10 @@ LogicalResult ConstantFoldFallbackHook( inputs.push_back(input.cast()); } + // Avoid overlapping folds with the same context. + // TODO(jpienaar): Avoid using global context & mutex here. + static auto* mu = new tensorflow::mutex(); + tensorflow::mutex_lock l(*mu); return tensorflow::EvaluateOperation(inst, inputs, ctx, &results); } -- GitLab