提交 1bcae5d8 编写于 作者: T Trevor Morris

StridedSlice op + some unit tests

Fix typo

Refactor. Add Ok unit tests

Improve unit tests, comments.
上级 03e72140
......@@ -89,51 +89,52 @@ Status TrtCandidateSelector::IsTensorRTCandidate(const tensorflow::Node* node) {
// TODO(laigd): move this set to TrtNodeValidator where it should belong.
// LINT.IfChange
static const std::set<string> candidate_ops = {
"Identity",
"Snapshot",
"Const",
"Conv2D",
"MaxPool",
"BiasAdd",
"Relu",
"Sigmoid",
"Tanh",
"Abs",
"Add",
"Mul",
"Sub",
"Rsqrt",
"Pad",
"Mean",
"AvgPool",
"BatchMatMul",
"BiasAdd",
"ConcatV2",
"Const",
"Conv2D",
"DepthwiseConv2dNative",
"FusedBatchNorm",
"FusedBatchNormV2",
"Div",
"RealDiv",
"Rsqrt",
"Reciprocal",
"Exp",
"ExpandDims",
"FusedBatchNorm",
"FusedBatchNormV2",
"Identity",
"Log",
"Sqrt",
"Abs",
"Neg",
"Transpose",
"Reshape",
"MatMul",
"BatchMatMul",
"Softmax",
"Minimum",
"Maximum",
"TopKV2",
"Sum",
"Prod",
"Max",
"MaxPool",
"Maximum",
"Mean",
"Min",
"Minimum",
"Mul",
"Neg",
"Pad",
"Prod",
"RealDiv",
"Reciprocal",
"Relu",
"Relu6",
"Reshape",
"Rsqrt",
"Rsqrt",
"Sigmoid",
"Snapshot",
"Softmax",
"Sqrt",
"Square",
"ExpandDims",
"Squeeze",
"StridedSlice",
"Sub",
"Sum",
"Tanh",
"TopKV2",
"Transpose",
};
bool is_supported_op_type =
(candidate_ops.count(node->type_string()) ||
......
......@@ -632,6 +632,11 @@ bool TFAttrs::get<bool>(const string& key) const {
return this->at(key)->b();
}
template <>
int TFAttrs::get<int>(const string& key) const {
return this->at(key)->i();
}
// TODO(jie): reorder4 & reorder2 should be merged?
// TODO(aaroey): fix the order of parameters.
template <typename T>
......@@ -2028,6 +2033,245 @@ tensorflow::Status ConvertSqueeze(OpConverterParams* params) {
return tensorflow::Status::OK();
}
tensorflow::Status GetStridedSliceBound(
const std::vector<int>& input_dims,
const TRT_ShapedWeights& bound_weights,
string bound_name,
string node_name,
std::vector<int>& output_bound) {
const int* weights_ptr =
static_cast<int*>(const_cast<void*>(bound_weights.GetValues()));
output_bound = std::vector<int>(weights_ptr,
weights_ptr + bound_weights.count());
if (output_bound.size() != input_dims.size()) {
return tensorflow::errors::InvalidArgument(
"StridedSlice \"", bound_name, "\" specified ",
std::to_string(output_bound.size()), " dimensions, but input rank is ",
std::to_string(input_dims.size()), ", at ", node_name);
}
for (int i = 0; i < output_bound.size(); i++) {
// Make sure bound is valid.
if ((output_bound[i] < -input_dims[i]) ||
(output_bound[i] > input_dims[i])) {
return tensorflow::errors::InvalidArgument(
bound_name, " for StridedSlice is invalid, must be in the range "
"[-rank(input), rank(input)], at ", node_name);
}
// Convert negative values to their positive equivalent.
if (output_bound[i] < 0) {
output_bound[i] += input_dims[i];
}
}
return tensorflow::Status::OK();
}
tensorflow::Status ConvertStridedSlice(OpConverterParams* params) {
const auto& inputs = params->inputs;
const auto& node_def = params->node_def;
if (inputs.size() != 4) {
return tensorflow::errors::InvalidArgument(
"StridedSlice expects 4 inputs, at ", node_def.name());
}
if (!inputs.at(1).is_weights() ||
!inputs.at(2).is_weights() ||
!inputs.at(3).is_weights()) {
return tensorflow::errors::InvalidArgument(
"StridedSlice expects weights for begin, end, and strides, at ",
node_def.name());
}
if (!inputs.at(0).is_tensor()) {
return tensorflow::errors::Unimplemented(
"StridedSlice is only implemented for tensors, at ",
node_def.name());
}
// Get input dims.
nvinfer1::Dims dims = inputs.at(0).GetTrtDims();
std::vector<int> input_dims(dims.d, dims.d + dims.nbDims);
if (inputs.at(0).is_tensor()) {
// Temporarily add batch dimension so that indexes line up properly.
input_dims.insert(input_dims.begin(), inputs.at(0).batch_size());
}
if (input_dims.size() > 4) {
return tensorflow::errors::Unimplemented(
"StridedSlice is not implemented for tensors with rank > 4, at ",
node_def.name());
}
TFAttrs attrs(node_def);
// Get begin and end bounds per axis.
std::vector<int> begin, end;
TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(1).weights(),
"begin", node_def.name(), begin));
TF_RETURN_IF_ERROR(GetStridedSliceBound(input_dims, inputs.at(2).weights(),
"end", node_def.name(), end));
int begin_mask = attrs.get<int>("begin_mask");
for (int i = 0; i < begin.size(); i++) {
if ((1 << i) & begin_mask) {
begin[i] = 0;
}
}
int end_mask = attrs.get<int>("end_mask");
for (int i = 0; i < end.size(); i++) {
if ((1 << i) & end_mask) {
end[i] = input_dims[i];
}
}
// Get strides per axis (must all be 1).
TRT_ShapedWeights stride_weights = inputs.at(3).weights();
const int* stride_weights_ptr =
static_cast<int*>(const_cast<void*>(stride_weights.GetValues()));
std::vector<int> strides(stride_weights_ptr,
stride_weights_ptr + stride_weights.count());
for (int x : strides) {
if (x != 1) {
return tensorflow::errors::Unimplemented(
"StridedSlice is only implemented for stride of 1, at ",
node_def.name());
}
}
// Unsupported options.
for (string attr : {"ellipsis_mask", "new_axis_mask", "shrink_axis_mask"}) {
int ellipsis_mask = attrs.get<int>(attr);
if (ellipsis_mask != 0) {
return tensorflow::errors::Unimplemented(
attr, " is not implemented for StridedSlice, at ",
node_def.name());
}
}
nvinfer1::ITensor* tensor = const_cast<nvinfer1::ITensor*>(
inputs.at(0).tensor());
// Reshape if necessary to 4-D.
const bool need_reshape = (input_dims.size() != 4);
int reshape_dims_added = 0;
nvinfer1::Dims reshape_dims;
if (need_reshape) {
// Add new dims after batch dim until tensor is 4D.
while (input_dims.size() < 4) {
input_dims.insert(input_dims.begin()+1, 1);
begin.insert(begin.begin()+1, 0);
end.insert(end.begin()+1, 1);
reshape_dims_added++;
}
reshape_dims = VectorToTrtDims(input_dims, /*ignore_first_dim=*/true);
}
// Find dimensions which need to be sliced.
std::vector<int> pad_dims;
for (int i = 0; i < input_dims.size(); i++) {
if (begin[i] != 0 || (end[i] - input_dims[i]) != 0) {
if (i == 0) {
return tensorflow::errors::Unimplemented(
"StridedSlice can't modify batch dim, at ", node_def.name());
}
else if ((end[i] - begin[i]) < 0) {
LOG(INFO) << begin[i] << ", " << end[i];
return tensorflow::errors::InvalidArgument(
"New size of sliced dimension is negative, at ", node_def.name());
}
pad_dims.push_back(i);
}
}
if (pad_dims.size() == 0) {
// No dimensions are changed. We could create a padding layer anyway with
// values of 0.
if (params->validation_only) return Status::OK();
params->outputs->push_back(inputs.at(0));
return tensorflow::Status::OK();
} else if (pad_dims.size() == 1) {
// Only one dim is modified but we have to have 2, mark a second dim which
// will have padding of 0.
if (pad_dims[0] == 1 || pad_dims[0] == 3) {
pad_dims.push_back(2);
} else if (pad_dims[0] == 2) {
pad_dims.push_back(3);
}
} else if (pad_dims.size() > 2) {
return tensorflow::errors::Unimplemented(
"StridedSlice can only modify 2 dimensions, at ",
node_def.name());
}
std::sort(pad_dims.begin(), pad_dims.end());
// Convert to pre/post padding values.
nvinfer1::DimsHW pre_padding, post_padding;
for (int i = 0; i < pad_dims.size(); i++) {
const int axis = pad_dims[i];
pre_padding.d[i] = -begin[axis];
post_padding.d[i] = end[axis] - input_dims[axis];
}
// IPaddingLayer will always apply the padding to dims 2,3 (input format is
// NCHW).
const bool need_transpose = !(pad_dims[0] == 2 && pad_dims[1] == 3);
std::vector<int> transpose_order(input_dims.size());
std::vector<int> inv_transpose_order(input_dims.size());
if (need_transpose) {
if (pad_dims[0] == 1 && pad_dims[1] == 3) {
transpose_order = {0, 2, 1, 3};
inv_transpose_order = {0, 2, 1, 3};
} else if (pad_dims[0] == 1 && pad_dims[1] == 2) {
transpose_order = {0, 3, 1, 2};
inv_transpose_order = {0, 2, 3, 1};
}
}
if (params->validation_only) return Status::OK();
// Start conversion.
if (need_reshape) {
const nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
inputs.at(0), reshape_dims, &output_tensor));
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
}
if (need_transpose) {
const nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
tensor, transpose_order, &output_tensor));
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
}
// Add padding layer
nvinfer1::IPaddingLayer* layer = params->converter->network()->addPadding(
*const_cast<nvinfer1::ITensor*>(tensor), pre_padding, post_padding);
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
tensor = layer->getOutput(0);
// Restore transpose
if (need_transpose) {
const nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->TransposeTensor(
tensor, inv_transpose_order, &output_tensor));
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
}
// Restore reshape
if (need_reshape) {
// Calculate output dimensions
for(int i = 0; i < pad_dims.size(); i++) {
const int axis = pad_dims[i];
input_dims[axis] = end[axis] - begin[axis];
}
// Remove added 1 dimensions
for (int i = 0; i < reshape_dims_added; i++) {
int value = input_dims[1];
if (value != 1) {
return tensorflow::errors::Internal(
"StridedSlice error when reshaping, at ",
node_def.name());
}
input_dims.erase(input_dims.begin()+1);
}
nvinfer1::Dims new_dims = VectorToTrtDims(input_dims,
/*ignore_first_dim=*/true);
const nvinfer1::ITensor* output_tensor = nullptr;
TF_RETURN_IF_ERROR(params->converter->PrepareTensorForShape(
TRT_TensorOrWeights(tensor), new_dims, &output_tensor));
tensor = const_cast<nvinfer1::ITensor*>(output_tensor);
}
params->outputs->push_back(
TRT_TensorOrWeights(const_cast<nvinfer1::ITensor*>(tensor)));
return tensorflow::Status::OK();
}
tensorflow::Status ConvertConv2D(OpConverterParams* params) {
return ConvertConv2DHelper(params, ConvolutionType::DEFAULT);
}
......@@ -3335,14 +3579,15 @@ static void RegisterValidatableOpConverters(
(*registration)["Const"] = ConvertConst;
(*registration)["Conv2D"] = ConvertConv2D;
(*registration)["DepthwiseConv2dNative"] = ConvertConv2DDepthwise;
(*registration)["Transpose"] = ConvertTranspose;
(*registration)["Reshape"] = ConvertReshape;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["MatMul"] = ConvertMatMul;
(*registration)["Pad"] = ConvertPad;
(*registration)["Relu6"] = ConvertRelu6;
(*registration)["Reshape"] = ConvertReshape;
(*registration)["Square"] = ConvertSquare;
(*registration)["ExpandDims"] = ConvertExpandDims;
(*registration)["Squeeze"] = ConvertSqueeze;
(*registration)["StridedSlice"] = ConvertStridedSlice;
(*registration)["Transpose"] = ConvertTranspose;
for (auto quantization_op_type :
{"QuantizeAndDequantizeV2", "QuantizeAndDequantizeV3",
......
......@@ -2129,7 +2129,6 @@ TEST_F(OpConverterTest, ConvertExpandDims) {
auto expanddims =
ops::ExpandDims(s.WithOpName("my_expanddims"), input, weights);
const NodeDef& node_def = expanddims.operation.node()->def();
{
// Input is weights, should fail.
Reset();
......@@ -2349,6 +2348,307 @@ TEST_F(OpConverterTest, ConvertSqueeze) {
}
}
TEST_F(OpConverterTest, ConvertStridedSlice) {
{
// Input list is empty, should fail.
NodeDef node_def = MakeNodeDef("my_strided_slice", "StridedSlice", {});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"StridedSlice expects 4 inputs, at my_strided_slice");
}
// Get nodedef for StridedSlice layer.
auto get_strided_slice_nodedef = [](int begin_mask = 0,
int ellipsis_mask = 0,
int end_mask = 0,
int new_axis_mask = 0,
int shrink_axis_mask = 0) -> NodeDef {
Scope s = Scope::NewRootScope();
auto input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT);
auto begin = ops::Placeholder(s.WithOpName("begin"), DT_INT32);
auto end = ops::Placeholder(s.WithOpName("end"), DT_INT32);
auto strides = ops::Placeholder(s.WithOpName("strides"), DT_INT32);
ops::StridedSlice::Attrs strided_slice_attrs;
strided_slice_attrs.begin_mask_ = begin_mask;
strided_slice_attrs.ellipsis_mask_ = ellipsis_mask;
strided_slice_attrs.end_mask_ = end_mask;
strided_slice_attrs.new_axis_mask_ = new_axis_mask;
strided_slice_attrs.shrink_axis_mask_ = shrink_axis_mask;
auto strided_slice = ops::StridedSlice(s.WithOpName("my_strided_slice"),
input, begin, end, strides, strided_slice_attrs);
return strided_slice.operation.node()->def();
};
{
NodeDef node_def = get_strided_slice_nodedef();
AddTestWeights<int32>("input", {1, 2, 3}, {1, 2, 3, 4, 5, 6});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"StridedSlice is only implemented for tensors, at my_strided_slice");
}
{
// Begin, end, strides are tensors, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestTensor("begin", {4});
AddTestTensor("end", {4});
AddTestTensor("strides", {4});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"StridedSlice expects weights for begin, end, and strides, at "
"my_strided_slice");
}
{
// Non-zero ellipsis_mask, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0,
/*ellipsis_mask=*/2, /*end_mask=*/0, /*new_axis_mask=*/0,
/*shrink_axis_mask=*/0);
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"ellipsis_mask is not implemented for StridedSlice, at "
"my_strided_slice");
}
{
// Non-zero ellipsis_mask, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0,
/*ellipsis_mask=*/0, /*end_mask=*/0, /*new_axis_mask=*/2,
/*shrink_axis_mask=*/0);
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"new_axis_mask is not implemented for StridedSlice, at "
"my_strided_slice");
}
{
// Non-zero shrink_axis_mask, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef(/*begin_mask=*/0,
/*ellipsis_mask=*/0, /*end_mask=*/0, /*new_axis_mask=*/0,
/*shrink_axis_mask=*/2);
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"shrink_axis_mask is not implemented for StridedSlice, at "
"my_strided_slice");
}
{
// Modify batch dim, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {0, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"StridedSlice can't modify batch dim, at my_strided_slice");
}
{
// Stride is not 1, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 2, -1, 3});
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED, "StridedSlice is only implemented for "
"stride of 1, at my_strided_slice");
}
{
// Begin out of bounds, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {1, 2, 3, 4});
AddTestWeights<int32>("end", {4}, {0, 1, 2, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"begin for StridedSlice is invalid, must be in the range "
"[-rank(input), rank(input)], at my_strided_slice");
}
{
// End out of bounds, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 0, 0});
AddTestWeights<int32>("end", {4}, {1, 2, 3, 4});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"end for StridedSlice is invalid, must be in the range "
"[-rank(input), rank(input)], at my_strided_slice");
}
{
// Size of sliced dim is negative, should fail.
Reset();
NodeDef node_def = get_strided_slice_nodedef();
AddTestTensor("input", {1, 2, 3});
AddTestWeights<int32>("begin", {4}, {0, 0, 2, 0});
AddTestWeights<int32>("end", {4}, {1, 1, 0, 3});
AddTestWeights<int32>("strides", {4}, {1, 1, 1, 1});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"New size of sliced dimension is negative, at my_strided_slice");
}
struct TestParams {
TestParams(const std::vector<int>& input_dims,
const std::vector<int>& expected_output_dims,
const std::vector<int>& begin,
const std::vector<int>& end,
const std::vector<int>& begin_mask,
const std::vector<int>& end_mask,
const std::vector<int>& expected_output)
: input_dims(input_dims),
expected_output_dims(expected_output_dims),
begin(begin),
end(end),
expected_output(expected_output) {
// Masks are provided in terms of vectors for readability. Convert them to
// binary here.
this->begin_mask = 0;
for (int i = 0; i < begin_mask.size(); i++) {
if (begin_mask[i]) this->begin_mask |= (1 << i);
}
this->end_mask = 0;
for (int i = 0; i < end_mask.size(); i++) {
if (end_mask[i]) this->end_mask |= (1 << i);
}
}
std::vector<int> input_dims;
std::vector<int> expected_output_dims;
std::vector<int> begin;
std::vector<int> end;
int begin_mask;
int end_mask;
std::vector<int> expected_output;
};
// Ok.
const int kStridedSliceOKCases = 18;
TestParams ok_params[kStridedSliceOKCases] = {
// 2D Crop.
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2},
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 1, 2},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0},
/*expected_output=*/{1, 2}},
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2},
/*begin=*/{0, 0, 1, 1}, /*end=*/{0, 0, 0, 0},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1},
/*expected_output=*/{5, 6}},
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 2},
/*begin=*/{0, 0, 1, 1}, /*end=*/{0, 1, 2, 3},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 0, 0},
/*expected_output=*/{5, 6}},
// 2D Crop, with transpose.
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1},
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 2, 1},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
/*expected_output=*/{1, 2}},
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 2, 1},
/*begin=*/{0, 1, 1, 0}, /*end=*/{0, 2, 3, 1},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
/*expected_output=*/{5, 6}},
TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2},
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 1, 2},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
/*expected_output=*/{1, 2}},
TestParams{/*input_dims=*/{2, 1, 3}, /*expected_output_dims=*/{1, 1, 2},
/*begin=*/{0, 1, 0, 1}, /*end=*/{0, 2, 1, 3},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 0, 0},
/*expected_output=*/{5, 6}},
// 2D Crop, with reshape.
TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2},
/*begin=*/{0, 0, 0}, /*end=*/{0, 1, 2},
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 0},
/*expected_output=*/{1, 2}},
TestParams{/*input_dims=*/{2, 3}, /*expected_output_dims=*/{1, 2},
/*begin=*/{0, 1, 1}, /*end=*/{0, 0, 0},
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 1},
/*expected_output=*/{5, 6}},
// 1D Crop.
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 2, 2},
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 0, 0, 2},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 0},
/*expected_output=*/{1, 2, 4, 5}},
TestParams{/*input_dims=*/{1, 2, 3}, /*expected_output_dims=*/{1, 1, 3},
/*begin=*/{0, 0, 1, 0}, /*end=*/{0, 0, 0, 0},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1},
/*expected_output=*/{4, 5, 6}},
// 1D Crop, with transpose.
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1},
/*begin=*/{0, 0, 0, 0}, /*end=*/{0, 1, 0, 0},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 0, 1, 1},
/*expected_output=*/{1, 2, 3}},
TestParams{/*input_dims=*/{2, 3, 1}, /*expected_output_dims=*/{1, 3, 1},
/*begin=*/{0, 1, 0, 0}, /*end=*/{0, 0, 0, 0},
/*begin_mask=*/{0, 0, 0, 0}, /*end_mask=*/{1, 1, 1, 1},
/*expected_output=*/{4, 5, 6}},
// 1D Crop, with reshape.
TestParams{/*input_dims=*/{6}, /*expected_output_dims=*/{3},
/*begin=*/{0, 0}, /*end=*/{0, 3},
/*begin_mask=*/{0, 0}, /*end_mask=*/{1, 0},
/*expected_output=*/{1, 2, 3}},
TestParams{/*input_dims=*/{1, 6}, /*expected_output_dims=*/{1, 3},
/*begin=*/{0, 0, 2}, /*end=*/{0, 0, 5},
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 1, 0},
/*expected_output=*/{3, 4, 5}},
TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1},
/*begin=*/{0, 2, 0}, /*end=*/{0, 5, 0},
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1},
/*expected_output=*/{3, 4, 5}},
// Negative axis.
TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{3, 1},
/*begin=*/{0, -6, 0}, /*end=*/{0, -3, 0},
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1},
/*expected_output=*/{1, 2, 3}},
TestParams{/*input_dims=*/{6, 1}, /*expected_output_dims=*/{5, 1},
/*begin=*/{0, 0, 0}, /*end=*/{0, -1, 0},
/*begin_mask=*/{0, 0, 0}, /*end_mask=*/{1, 0, 1},
/*expected_output=*/{1, 2, 3, 4, 5}},
};
for (int i = 0; i < kStridedSliceOKCases; i++) {
Reset();
NodeDef node_def = get_strided_slice_nodedef(ok_params[i].begin_mask, 0,
ok_params[i].end_mask);
AddTestTensor("input", ok_params[i].input_dims);
AddTestWeights<int32>("begin", {ok_params[i].begin.size()},
ok_params[i].begin);
AddTestWeights<int32>("end", {ok_params[i].end.size()}, ok_params[i].end);
std::vector<int> strides(ok_params[i].input_dims.size(), 1);
AddTestWeights<int32>("strides", {strides.size()}, strides);
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_strided_slice", &output));
std::vector<float> output_data(ok_params[i].expected_output.size());
BuildAndRun<float>({{"input", {1, 2, 3, 4, 5, 6}}}, "my_strided_slice",
&output_data);
EXPECT_THAT(output_data, ElementsAreArray(ok_params[i].expected_output));
}
}
} // namespace convert
} // namespace tensorrt
} // namespace tensorflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册