Commit 3b4eda4f authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Remove check for input constant before running tf2xla op from MLIR

This check can be removed because tf2xla can run ops with non-constant inputs, even when the CompileTimeConstant attribute is set, with the help of value inference.

PiperOrigin-RevId: 564851049
Parent c952951f
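Note (not part of the commit): a minimal TensorFlow sketch of the behavior the change relies on. XLA requires the `size` operand of `tf.slice` to be a compile-time constant, but value inference can resolve it even when it is computed rather than written as a literal constant. The function name `take_prefix` and the specific op choice are illustrative assumptions, not code from this commit.

```python
import tensorflow as tf


@tf.function(jit_compile=True)
def take_prefix(x):
  # `size` is derived from the input's static shape rather than written as a
  # literal constant; value inference evaluates `tf.shape(x)[:1] - 1` to the
  # constant [2] at compile time, so XLA still sees a compile-time constant.
  size = tf.shape(x)[:1] - 1
  return tf.slice(x, begin=[0], size=size)


print(take_prefix(tf.constant([1, 2, 3])))  # tf.Tensor([1 2], ...)
```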
@@ -178,7 +178,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
}
func.func @non_const_inputs(%arg0: tensor<2x2xf64>, %arg1: tensor<f64>, %arg2: tensor<2xi32>, %arg3: tensor<2xi32>, %arg4: tensor<2xi32>) -> tensor<6x5xf64> {
// expected-remark@+1 {{lowering requires operand #2 to be a constant}}
// expected-remark@+1 {{compilation to HLO failed: INVALID_ARGUMENT: Input 2 to node `tf.XlaPad` with op XlaPad must be a compile-time constant.}}
%0 = "tf.XlaPad"(%arg0, %arg1, %arg2, %arg3, %arg4) : (tensor<2x2xf64>, tensor<f64>, tensor<2xi32>, tensor<2xi32>, tensor<2xi32>) -> tensor<6x5xf64>
func.return %0 : tensor<6x5xf64>
}
......
@@ -277,11 +277,6 @@ LogicalResult Tf2XlaRewriter::PrepareKernelInputs(
tensorflow::XlaExpression expr = GetExprForOperand(operand, op_, idx);
tensorflow::XlaExpression::Kind kind = expr.kind();
if (kind == tensorflow::XlaExpression::Kind::kInvalid) return failure();
if (required_consts.count(idx) &&
kind != tensorflow::XlaExpression::Kind::kConstant) {
return op_->emitRemark()
<< "lowering requires operand #" << idx << " to be a constant";
}
expressions.push_back(expr);
if (!tensorflow::DataTypeCanUseMemcpy(expr.dtype())) {
......
@@ -286,6 +286,19 @@ TEST_F(Tf2XlaRewriterTest, InsertsConstantParameters) {
TF_ASSERT_OK(LegalizeModule(kModuleWithConstParam));
}
TEST_F(Tf2XlaRewriterTest, DoesntEnforceCompileTimeConstantCheck) {
static constexpr char kModuleWithNonConstParam[] = R"(
module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1610 : i32}} {
func.func @main(%arg0: tensor<3x3x10xbf16>, %arg1: tensor<3xi32>) -> tensor<1x?x4xbf16> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0"}} {
%cst = "tf.Const"() {value = dense<[1, -1, 4]> : tensor<3xi32>} : () -> tensor<3xi32>
%0 = "tf.Slice"(%arg0, %arg1, %cst) {_XlaHasReferenceVars = false, _xla_inferred_shapes = [#tf_type.shape<1x?x4>], device = "/job:localhost/replica:0/task:0/device:TPU:0"} : (tensor<3x3x10xbf16>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x?x4xbf16>
return %0 : tensor<1x?x4xbf16>
}
})";
TF_ASSERT_OK(LegalizeModule(kModuleWithNonConstParam));
}
TEST_F(Tf2XlaRewriterTest, ErrorsWithInvalidNumberOfParametersToArgs) {
XlaBuilder builder("test_builder");
XlaComputation to_apply;
......
@@ -20,7 +20,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
@@ -119,13 +118,7 @@ class CompilerIrTest(xla_test.XLATestCase):
args = [ops.convert_to_tensor([1, 2, 3, 4])]
args_spec = nest.map_structure(tensor.TensorSpec.from_tensor, args)
concrete_fn = f2.get_concrete_function(*args_spec)
if test_util.is_mlir_bridge_enabled():
with self.assertRaisesRegex(
ValueError, 'TF to XLA legalization failed'
):
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
else:
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
_ = compiler_ir.from_concrete_function(concrete_fn)(stage='hlo')
def test_make_handledata_tensor_specs(self):
with ops.device('device:{}:0'.format(self.device)):
@@ -174,9 +167,6 @@ self._compareTwoMethodsCompilerIROutput(f4, [], kwargs)
self._compareTwoMethodsCompilerIROutput(f4, [], kwargs)
def test_capture_variable_2(self):
if not test_util.is_mlir_bridge_enabled():
self.skipTest('Non_milr_bridge will fail here.')
if 'gpu' in self.device.lower():
self.skipTest('Skip test on GPU')
......