diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index cef7d1961c45c0bf6af85e33dd9a7f5225fdaacb..288a63ecc3d9e9a169262d3255edaa53aaa7a528 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -625,6 +625,7 @@ cc_library( "//tensorflow/c:tf_status", "//tensorflow/c/eager:c_api", "//tensorflow/core:framework", + "//tensorflow/core:lib", "//tensorflow/stream_executor", "//tensorflow/stream_executor/lib", "@llvm//:support", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc index 0ef0072390d16b048285ccebb0589b3325dff3ad..11eafdede08e6763e2a871c36f7297a2f1a95f44 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/c/eager/c_api.h" #include "tensorflow/c/tf_status.h" #include "tensorflow/compiler/mlir/tensorflow/utils/eval_util.h" +#include "tensorflow/core/platform/mutex.h" namespace mlir { namespace TF { @@ -59,6 +60,10 @@ LogicalResult ConstantFoldFallbackHook( inputs.push_back(input.cast()); } + // Avoid overlapping folds with the same context. + // TODO(jpienaar): Avoid using global context & mutex here. + static auto* mu = new tensorflow::mutex(); + tensorflow::mutex_lock l(*mu); return tensorflow::EvaluateOperation(inst, inputs, ctx, &results); }