From 1bcae5d84b937ea17b70ff25824ea292b8d95f4f Mon Sep 17 00:00:00 2001 From: Trevor Morris Date: Mon, 26 Nov 2018 09:31:56 -0800 Subject: [PATCH] StridedSlice op + some unit tests Fix typo Refactor. Add Ok unit tests Improve unit tests, comments. --- .../contrib/tensorrt/convert/convert_graph.cc | 65 ++-- .../contrib/tensorrt/convert/convert_nodes.cc | 251 ++++++++++++++- .../tensorrt/convert/convert_nodes_test.cc | 302 +++++++++++++++++- 3 files changed, 582 insertions(+), 36 deletions(-) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index ae211a93c32..623cd79f32e 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -89,51 +89,52 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) { // TODO(laigd): move this set to TrtNodeValidator where it should belong. // LINT.IfChange static const std::set candidate_ops = { - "Identity", - "Snapshot", - "Const", - "Conv2D", - "MaxPool", - "BiasAdd", - "Relu", - "Sigmoid", - "Tanh", + "Abs", "Add", - "Mul", - "Sub", - "Rsqrt", - "Pad", - "Mean", "AvgPool", + "BatchMatMul", + "BiasAdd", "ConcatV2", + "Const", + "Conv2D", "DepthwiseConv2dNative", - "FusedBatchNorm", - "FusedBatchNormV2", "Div", - "RealDiv", - "Rsqrt", - "Reciprocal", "Exp", + "ExpandDims", + "FusedBatchNorm", + "FusedBatchNormV2", + "Identity", "Log", - "Sqrt", - "Abs", - "Neg", - "Transpose", - "Reshape", "MatMul", - "BatchMatMul", - "Softmax", - "Minimum", - "Maximum", - "TopKV2", - "Sum", - "Prod", "Max", + "MaxPool", + "Maximum", + "Mean", "Min", + "Minimum", + "Mul", + "Neg", + "Pad", + "Prod", + "RealDiv", + "Reciprocal", + "Relu", "Relu6", + "Reshape", + "Rsqrt", + "Rsqrt", + "Sigmoid", + "Snapshot", + "Softmax", + "Sqrt", "Square", - "ExpandDims", "Squeeze", + "StridedSlice", + "Sub", + "Sum", + "Tanh", + "TopKV2", + "Transpose", }; bool is_supported_op_type = (candidate_ops.count(node->type_string()) || diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 561ea37daeb..fdecfe59285 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -632,6 +632,11 @@ bool TFAttrs::get(const string& key) const { return this->at(key)->b(); } +template <> +int TFAttrs::get(const string& key) const { + return this->at(key)->i(); +} + // TODO(jie): reorder4 & reorder2 should be merged? // TODO(aaroey): fix the order of parameters. template @@ -2028,6 +2033,245 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) { return tensorflow::Status::OK(); } +tensorflow::Status GetStridedSliceBound( + const std::vector& input_dims, + const TRT_ShapedWeights& bound_weights, + string bound_name, + string node_name, + std::vector& output_bound) { + const int* weights_ptr = + static_cast(const_cast(bound_weights.GetValues())); + output_bound = std::vector(weights_ptr, + weights_ptr + bound_weights.count()); + if (output_bound.size() != input_dims.size()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice \"", bound_name, "\" specified ", + std::to_string(output_bound.size()), " dimensions, but input rank is ", + std::to_string(input_dims.size()), ", at ", node_name); + } + for (int i = 0; i < output_bound.size(); i++) { + // Make sure bound is valid. + if ((output_bound[i] < -input_dims[i]) || + (output_bound[i] > input_dims[i])) { + return tensorflow::errors::InvalidArgument( + bound_name, " for StridedSlice is invalid, must be in the range " + "[-rank(input), rank(input)], at ", node_name); + } + // Convert negative values to their positive equivalent. + if (output_bound[i] < 0) { + output_bound[i] += input_dims[i]; + } + } + return tensorflow::Status::OK(); +} + +tensorflow::Status ConvertStridedSlice(OpConverterParams* params) { + const auto& inputs = params->inputs; + const auto& node_def = params->node_def; + if (inputs.size() != 4) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects 4 inputs, at ", node_def.name()); + } + if (!inputs.at(1).is_weights() || + !inputs.at(2).is_weights() || + !inputs.at(3).is_weights()) { + return tensorflow::errors::InvalidArgument( + "StridedSlice expects weights for begin, end, and strides, at ", + node_def.name()); + } + if (!inputs.at(0).is_tensor()) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for tensors, at ", + node_def.name()); + } + // Get input dims. + nvinfer1::Dims dims = inputs.at(0).GetTrtDims(); + std::vector input_dims(dims.d, dims.d + dims.nbDims); + if (inputs.at(0).is_tensor()) { + // Temporarily add batch dimension so that indexes line up properly. + input_dims.insert(input_dims.begin(), inputs.at(0).batch_size()); + } + if (input_dims.size() > 4) { + return tensorflow::errors::Unimplemented( + "StridedSlice is not implemented for tensors with rank > 4, at ", + node_def.name()); + } + TFAttrs attrs(node_def); + // Get begin and end bounds per axis. + std::vector begin, end; + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(), + "begin", node_def.name(), begin)); + TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(), + "end", node_def.name(), end)); + int begin_mask = attrs.get("begin_mask"); + for (int i = 0; i < begin.size(); i++) { + if ((1 << i) & begin_mask) { + begin[i] = 0; + } + } + int end_mask = attrs.get("end_mask"); + for (int i = 0; i < end.size(); i++) { + if ((1 << i) & end_mask) { + end[i] = input_dims[i]; + } + } + // Get strides per axis (must all be 1). + TRT_ShapedWeights stride_weights = inputs.at(3).weights(); + const int* stride_weights_ptr = + static_cast(const_cast(stride_weights.GetValues())); + std::vector strides(stride_weights_ptr, + stride_weights_ptr + stride_weights.count()); + for (int x : strides) { + if (x != 1) { + return tensorflow::errors::Unimplemented( + "StridedSlice is only implemented for stride of 1, at ", + node_def.name()); + } + } + // Unsupported options. + for (string attr : {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) { + int ellipsis_mask = attrs.get(attr); + if (ellipsis_mask != 0) { + return tensorflow::errors::Unimplemented( + attr, " is not implemented for StridedSlice, at ", + node_def.name()); + } + } + + nvinfer1::ITensor* tensor = const_cast( + inputs.at(0).tensor()); + // Reshape if necessary to 4-D. + const bool need_reshape = (input_dims.size() != 4); + int reshape_dims_added = 0; + nvinfer1::Dims reshape_dims; + if (need_reshape) { + // Add new dims after batch dim until tensor is 4D. + while (input_dims.size() < 4) { + input_dims.insert(input_dims.begin()+1, 1); + begin.insert(begin.begin()+1, 0); + end.insert(end.begin()+1, 1); + reshape_dims_added++; + } + reshape_dims = VectorToTrtDims(input_dims, /*ignore_first_dim=*/true); + } + // Find dimensions which need to be sliced. + std::vector pad_dims; + for (int i = 0; i < input_dims.size(); i++) { + if (begin[i] != 0 || (end[i] - input_dims[i]) != 0) { + if (i == 0) { + return tensorflow::errors::Unimplemented( + "StridedSlice can't modify batch dim, at ", node_def.name()); + } + else if ((end[i] - begin[i]) < 0) { + LOG(INFO) << begin[i] << ", " << end[i]; + return tensorflow::errors::InvalidArgument( + "New size of sliced dimension is negative, at ", node_def.name()); + } + pad_dims.push_back(i); + } + } + if (pad_dims.size() == 0) { + // No dimensions are changed. We could create a padding layer anyway with + // values of 0. + if (params->validation_only) return Status::OK(); + params->outputs->push_back(inputs.at(0)); + return tensorflow::Status::OK(); + } else if (pad_dims.size() == 1) { + // Only one dim is modified but we have to have 2, mark a second dim which + // will have padding of 0. + if (pad_dims[0] == 1 || pad_dims[0] == 3) { + pad_dims.push_back(2); + } else if (pad_dims[0] == 2) { + pad_dims.push_back(3); + } + } else if (pad_dims.size() > 2) { + return tensorflow::errors::Unimplemented( + "StridedSlice can only modify 2 dimensions, at ", + node_def.name()); + } + std::sort(pad_dims.begin(), pad_dims.end()); + // Convert to pre/post padding values. + nvinfer1::DimsHW pre_padding, post_padding; + for (int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + pre_padding.d[i] = -begin[axis]; + post_padding.d[i] = end[axis] - input_dims[axis]; + } + + // IPaddingLayer will always apply the padding to dims 2,3 (input format is + // NCHW). + const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3); + std::vector transpose_order(input_dims.size()); + std::vector inv_transpose_order(input_dims.size()); + if (need_transpose) { + if (pad_dims[0] == 1 && pad_dims[1] == 3) { + transpose_order = {0, 2, 1, 3}; + inv_transpose_order = {0, 2, 1, 3}; + } else if (pad_dims[0] == 1 && pad_dims[1] == 2) { + transpose_order = {0, 3, 1, 2}; + inv_transpose_order = {0, 2, 3, 1}; + } + } + if (params->validation_only) return Status::OK(); + + // Start conversion. + if (need_reshape) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + inputs.at(0), reshape_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + + // Add padding layer + nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding( + *const_cast(tensor), pre_padding, post_padding); + TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name()); + tensor = layer->getOutput(0); + + // Restore transpose + if (need_transpose) { + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->TransposeTensor( + tensor, inv_transpose_order, &output_tensor)); + tensor = const_cast(output_tensor); + } + // Restore reshape + if (need_reshape) { + // Calculate output dimensions + for(int i = 0; i < pad_dims.size(); i++) { + const int axis = pad_dims[i]; + input_dims[axis] = end[axis] - begin[axis]; + } + // Remove added 1 dimensions + for (int i = 0; i < reshape_dims_added; i++) { + int value = input_dims[1]; + if (value != 1) { + return tensorflow::errors::Internal( + "StridedSlice error when reshaping, at ", + node_def.name()); + } + input_dims.erase(input_dims.begin()+1); + } + + nvinfer1::Dims new_dims = VectorToTrtDims(input_dims, + /*ignore_first_dim=*/true); + const nvinfer1::ITensor* output_tensor = nullptr; + TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape( + TRT_TensorOrWeights(tensor), new_dims, &output_tensor)); + tensor = const_cast(output_tensor); + } + + params->outputs->push_back( + TRT_TensorOrWeights(const_cast(tensor))); + return tensorflow::Status::OK(); +} + tensorflow::Status ConvertConv2D(OpConverterParams* params) { return ConvertConv2DHelper(params, ConvolutionType::DEFAULT); } @@ -3335,14 +3579,15 @@ static void RegisterValidatableOpConverters( (*registration)["Const"] = ConvertConst; (*registration)["Conv2D"] = ConvertConv2D; (*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise; - (*registration)["Transpose"] = ConvertTranspose; - (*registration)["Reshape"] = ConvertReshape; + (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["MatMul"] = ConvertMatMul; (*registration)["Pad"] = ConvertPad; (*registration)["Relu6"] = ConvertRelu6; + (*registration)["Reshape"] = ConvertReshape; (*registration)["Square"] = ConvertSquare; - (*registration)["ExpandDims"] = ConvertExpandDims; (*registration)["Squeeze"] = ConvertSqueeze; + (*registration)["StridedSlice"] = ConvertStridedSlice; + (*registration)["Transpose"] = ConvertTranspose; for (auto quantization_op_type : {"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3", diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc index c37a43dd5de..07649f04b29 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes_test.cc @@ -2129,7 +2129,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) { auto expanddims = ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights); const NodeDef& node_def = expanddims.operation.node()->def(); - { // Input is weights, should fail. Reset(); @@ -2349,6 +2348,307 @@ TEST_F(OpConverterTest, ConvertSqueeze) { } } +TEST_F(OpConverterTest, ConvertStridedSlice) { + { + // Input list is empty, should fail. + NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects 4 inputs, at my_strided_slice"); + } + + // Get nodedef for StridedSlice layer. + auto get_strided_slice_nodedef = [](int begin_mask = 0, + int ellipsis_mask = 0, + int end_mask = 0, + int new_axis_mask = 0, + int shrink_axis_mask = 0) -> NodeDef { + Scope s = Scope::NewRootScope(); + auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); + auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32); + auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32); + auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32); + ops::StridedSlice::Attrs strided_slice_attrs; + strided_slice_attrs.begin_mask_ = begin_mask; + strided_slice_attrs.ellipsis_mask_ = ellipsis_mask; + strided_slice_attrs.end_mask_ = end_mask; + strided_slice_attrs.new_axis_mask_ = new_axis_mask; + strided_slice_attrs.shrink_axis_mask_ = shrink_axis_mask; + auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"), + input, begin, end, strides, strided_slice_attrs); + return strided_slice.operation.node()->def(); + }; + + { + NodeDef node_def = get_strided_slice_nodedef(); + AddTestWeights("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice is only implemented for tensors, at my_strided_slice"); + } + { + // Begin, end, strides are tensors, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestTensor("begin", {4}); + AddTestTensor("end", {4}); + AddTestTensor("strides", {4}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "StridedSlice expects weights for begin, end, and strides, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*ellipsis_mask=*/2, /*end_mask=*/0, /*new_axis_mask=*/0, + /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "ellipsis_mask is not implemented for StridedSlice, at " + "my_strided_slice"); + } + { + // Non-zero ellipsis_mask, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*ellipsis_mask=*/0, /*end_mask=*/0, /*new_axis_mask=*/2, + /*shrink_axis_mask=*/0); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "new_axis_mask is not implemented for StridedSlice, at " + "my_strided_slice"); + } + { + // Non-zero shrink_axis_mask, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0, + /*ellipsis_mask=*/0, /*end_mask=*/0, /*new_axis_mask=*/0, + /*shrink_axis_mask=*/2); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "shrink_axis_mask is not implemented for StridedSlice, at " + "my_strided_slice"); + } + { + // Modify batch dim, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, + "StridedSlice can't modify batch dim, at my_strided_slice"); + } + { + // Stride is not 1, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 2, -1, 3}); + RunValidationAndConversion( + node_def, error::UNIMPLEMENTED, "StridedSlice is only implemented for " + "stride of 1, at my_strided_slice"); + } + { + // Begin out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {1, 2, 3, 4}); + AddTestWeights("end", {4}, {0, 1, 2, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "begin for StridedSlice is invalid, must be in the range " + "[-rank(input), rank(input)], at my_strided_slice"); + } + { + // End out of bounds, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 0, 0}); + AddTestWeights("end", {4}, {1, 2, 3, 4}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "end for StridedSlice is invalid, must be in the range " + "[-rank(input), rank(input)], at my_strided_slice"); + } + { + // Size of sliced dim is negative, should fail. + Reset(); + NodeDef node_def = get_strided_slice_nodedef(); + AddTestTensor("input", {1, 2, 3}); + AddTestWeights("begin", {4}, {0, 0, 2, 0}); + AddTestWeights("end", {4}, {1, 1, 0, 3}); + AddTestWeights("strides", {4}, {1, 1, 1, 1}); + RunValidationAndConversion( + node_def, error::INVALID_ARGUMENT, + "New size of sliced dimension is negative, at my_strided_slice"); + } + + struct TestParams { + TestParams(const std::vector& input_dims, + const std::vector& expected_output_dims, + const std::vector& begin, + const std::vector& end, + const std::vector& begin_mask, + const std::vector& end_mask, + const std::vector& expected_output) + : input_dims(input_dims), + expected_output_dims(expected_output_dims), + begin(begin), + end(end), + expected_output(expected_output) { + // Masks are provided in terms of vectors for readability. Convert them to + // binary here. + this->begin_mask = 0; + for (int i = 0; i < begin_mask.size(); i++) { + if (begin_mask[i]) this->begin_mask |= (1 << i); + } + this->end_mask = 0; + for (int i = 0; i < end_mask.size(); i++) { + if (end_mask[i]) this->end_mask |= (1 << i); + } + } + + std::vector input_dims; + std::vector expected_output_dims; + std::vector begin; + std::vector end; + int begin_mask; + int end_mask; + std::vector expected_output; + }; + + // Ok. + const int kStridedSliceOKCases = 18; + TestParams ok_params[kStridedSliceOKCases] = { + // 2D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1}, + /*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2}, + /*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0}, + /*expected_output=*/{5, 6}}, + // 2D Crop, with reshape. + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0}, + /*expected_output=*/{1, 2}}, + TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2}, + /*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1}, + /*expected_output=*/{5, 6}}, + // 1D Crop. + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0}, + /*expected_output=*/{1, 2, 4, 5}}, + TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3}, + /*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with transpose. + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1}, + /*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0}, + /*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1}, + /*expected_output=*/{4, 5, 6}}, + // 1D Crop, with reshape. + TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3}, + /*begin=*/{0, 0}, /*end=*/{0, 3}, + /*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3}, + /*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0}, + /*expected_output=*/{3, 4, 5}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{3, 4, 5}}, + // Negative axis. + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1}, + /*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3}}, + TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1}, + /*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0}, + /*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1}, + /*expected_output=*/{1, 2, 3, 4, 5}}, + }; + + for (int i = 0; i < kStridedSliceOKCases; i++) { + Reset(); + NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, 0, + ok_params[i].end_mask); + AddTestTensor("input", ok_params[i].input_dims); + AddTestWeights("begin", {ok_params[i].begin.size()}, + ok_params[i].begin); + AddTestWeights("end", {ok_params[i].end.size()}, ok_params[i].end); + std::vector strides(ok_params[i].input_dims.size(), 1); + AddTestWeights("strides", {strides.size()}, strides); + RunValidationAndConversion(node_def); + + TRT_TensorOrWeights output; + TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output)); + std::vector output_data(ok_params[i].expected_output.size()); + BuildAndRun({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice", + &output_data); + EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output)); + } +} + } // namespace convert } // namespace tensorrt } // namespace tensorflow -- GitLab