diff --git a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc index a849aa418c633a3e3cbe4a6fdf43e7f76d7ad7c8..b8f9058d8f60285893991cc6ee3fa6bdcb453743 100644 --- a/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc +++ b/tensorflow/compiler/tf2tensorrt/ops/trt_engine_op.cc @@ -33,8 +33,6 @@ namespace tensorflow { // key to cache the instantiated functions for different executor subgraphs. REGISTER_OP("TRTEngineOp") .Attr("serialized_segment: string") - .Attr("input_shapes: list(shape)") - .Attr("output_shapes: list(shape)") .Attr("segment_funcdef_name: string") .Attr("InT: list({int8,float16,float32,int32})") .Attr("OutT: list({int8,float16,float32,int32})") @@ -55,6 +53,8 @@ REGISTER_OP("TRTEngineOp") // Deprecated attributes. .Attr("cached_engine_batches: list(int) >= 0 = []") .Attr("fixed_input_size: bool = true") + .Attr("input_shapes: list(shape)") + .Attr("output_shapes: list(shape)") .Attr("static_engine: bool = true"); } // namespace tensorflow