From bf87a43185ea11be65b974627273f2a191d66873 Mon Sep 17 00:00:00 2001
From: Dmitry Kurtaev
Date: Tue, 3 Apr 2018 18:28:05 +0300
Subject: [PATCH] Faster-RCNN object detection models from TensorFlow

---
 .../dnn/include/opencv2/dnn/all_layers.hpp      |   6 +
 modules/dnn/src/init.cpp                        |   1 +
 .../dnn/src/layers/crop_and_resize_layer.cpp    | 108 +++++++
 .../dnn/src/layers/detection_output_layer.cpp   |   5 +-
 modules/dnn/src/tensorflow/tf_importer.cpp      |  31 +-
 modules/dnn/test/test_tf_importer.cpp           |  16 +
 samples/dnn/README.md                           |   4 +-
 samples/dnn/tf_text_graph_faster_rcnn.py        | 291 ++++++++++++++++++
 8 files changed, 457 insertions(+), 5 deletions(-)
 create mode 100644 modules/dnn/src/layers/crop_and_resize_layer.cpp
 create mode 100644 samples/dnn/tf_text_graph_faster_rcnn.py
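Note (illustrative only, not part of the patch): the new CropAndResizeLayer below
bilinearly resamples each proposal box out of an NCHW feature blob. The boxes arrive
as DetectionOutput rows [batch_id, class_id, score, left, top, right, bottom] with
normalized coordinates, which is why the layer reads columns 3..6. A minimal NumPy
sketch of the same sampling, assuming output sizes greater than 1; the function name
and its arguments are hypothetical:

    import numpy as np

    def crop_and_resize(inp, boxes, out_h, out_w):
        # inp: (1, C, H, W) blob; boxes: (N, 4) normalized [left, top, right, bottom].
        _, num_channels, H, W = inp.shape
        out = np.zeros((len(boxes), num_channels, out_h, out_w), dtype=np.float32)
        for b, (left, top, right, bottom) in enumerate(boxes):
            # Fractional source coordinates for every output pixel inside the box.
            ys = top * (H - 1) + np.arange(out_h) * (bottom - top) * (H - 1) / (out_h - 1)
            xs = left * (W - 1) + np.arange(out_w) * (right - left) * (W - 1) / (out_w - 1)
            y0 = np.clip(ys.astype(int), 0, H - 1)
            x0 = np.clip(xs.astype(int), 0, W - 1)
            y1 = np.minimum(y0 + 1, H - 1)
            x1 = np.minimum(x0 + 1, W - 1)
            wy = (ys - y0)[:, np.newaxis]
            wx = (xs - x0)[np.newaxis, :]
            for c in range(num_channels):
                p = inp[0, c]
                # Standard bilinear blend of the four neighbours; algebraically the
                # same expression as the nested form used in the C++ layer.
                out[b, c] = (p[y0][:, x0] * (1 - wy) * (1 - wx) +
                             p[y0][:, x1] * (1 - wy) * wx +
                             p[y1][:, x0] * wy * (1 - wx) +
                             p[y1][:, x1] * wy * wx)
        return out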
diff --git a/modules/dnn/include/opencv2/dnn/all_layers.hpp b/modules/dnn/include/opencv2/dnn/all_layers.hpp
index f2124dd516..ffb09a2b95 100644
--- a/modules/dnn/include/opencv2/dnn/all_layers.hpp
+++ b/modules/dnn/include/opencv2/dnn/all_layers.hpp
@@ -581,6 +581,12 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
         static Ptr<ProposalLayer> create(const LayerParams& params);
     };
 
+    class CV_EXPORTS CropAndResizeLayer : public Layer
+    {
+    public:
+        static Ptr<CropAndResizeLayer> create(const LayerParams& params);
+    };
+
 //! @}
 //! @}
 CV__DNN_EXPERIMENTAL_NS_END
diff --git a/modules/dnn/src/init.cpp b/modules/dnn/src/init.cpp
index 28759daf2f..2bff16c4eb 100644
--- a/modules/dnn/src/init.cpp
+++ b/modules/dnn/src/init.cpp
@@ -84,6 +84,7 @@ void initializeLayerFactory()
     CV_DNN_REGISTER_LAYER_CLASS(Reshape,        ReshapeLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Flatten,        FlattenLayer);
     CV_DNN_REGISTER_LAYER_CLASS(ResizeNearestNeighbor, ResizeNearestNeighborLayer);
+    CV_DNN_REGISTER_LAYER_CLASS(CropAndResize,  CropAndResizeLayer);
 
     CV_DNN_REGISTER_LAYER_CLASS(Convolution,    ConvolutionLayer);
     CV_DNN_REGISTER_LAYER_CLASS(Deconvolution,  DeconvolutionLayer);
diff --git a/modules/dnn/src/layers/crop_and_resize_layer.cpp b/modules/dnn/src/layers/crop_and_resize_layer.cpp
new file mode 100644
index 0000000000..3f92a8488d
--- /dev/null
+++ b/modules/dnn/src/layers/crop_and_resize_layer.cpp
@@ -0,0 +1,108 @@
+#include "../precomp.hpp"
+#include "layers_common.hpp"
+
+namespace cv { namespace dnn {
+
+class CropAndResizeLayerImpl CV_FINAL : public CropAndResizeLayer
+{
+public:
+    CropAndResizeLayerImpl(const LayerParams& params)
+    {
+        CV_Assert(params.has("width"), params.has("height"));
+        outWidth = params.get<float>("width");
+        outHeight = params.get<float>("height");
+    }
+
+    bool getMemoryShapes(const std::vector<MatShape> &inputs,
+                         const int requiredOutputs,
+                         std::vector<MatShape> &outputs,
+                         std::vector<MatShape> &internals) const CV_OVERRIDE
+    {
+        CV_Assert(inputs.size() == 2, inputs[0].size() == 4);
+        if (inputs[0][0] != 1)
+            CV_Error(Error::StsNotImplemented, "");
+        outputs.resize(1, MatShape(4));
+        outputs[0][0] = inputs[1][2];  // Number of bounding boxes.
+        outputs[0][1] = inputs[0][1];  // Number of channels.
+        outputs[0][2] = outHeight;
+        outputs[0][3] = outWidth;
+        return false;
+    }
+
+    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        Layer::forward_fallback(inputs_arr, outputs_arr, internals_arr);
+    }
+
+    void forward(std::vector<Mat*> &inputs, std::vector<Mat> &outputs, std::vector<Mat> &internals) CV_OVERRIDE
+    {
+        CV_TRACE_FUNCTION();
+        CV_TRACE_ARG_VALUE(name, "name", name.c_str());
+
+        Mat& inp = *inputs[0];
+        Mat& out = outputs[0];
+        Mat boxes = inputs[1]->reshape(1, inputs[1]->total() / 7);
+        const int numChannels = inp.size[1];
+        const int inpHeight = inp.size[2];
+        const int inpWidth = inp.size[3];
+        const int inpSpatialSize = inpHeight * inpWidth;
+        const int outSpatialSize = outHeight * outWidth;
+        CV_Assert(inp.isContinuous(), out.isContinuous());
+
+        for (int b = 0; b < boxes.rows; ++b)
+        {
+            float* outDataBox = out.ptr<float>(b);
+            float left = boxes.at<float>(b, 3);
+            float top = boxes.at<float>(b, 4);
+            float right = boxes.at<float>(b, 5);
+            float bottom = boxes.at<float>(b, 6);
+            float boxWidth = right - left;
+            float boxHeight = bottom - top;
+
+            float heightScale = boxHeight * static_cast<float>(inpHeight - 1) / (outHeight - 1);
+            float widthScale = boxWidth * static_cast<float>(inpWidth - 1) / (outWidth - 1);
+            for (int y = 0; y < outHeight; ++y)
+            {
+                float input_y = top * (inpHeight - 1) + y * heightScale;
+                int y0 = static_cast<int>(input_y);
+                const float* inpData_row0 = (float*)inp.data + y0 * inpWidth;
+                const float* inpData_row1 = (y0 + 1 < inpHeight) ? (inpData_row0 + inpWidth) : inpData_row0;
+                for (int x = 0; x < outWidth; ++x)
+                {
+                    float input_x = left * (inpWidth - 1) + x * widthScale;
+                    int x0 = static_cast<int>(input_x);
+                    int x1 = std::min(x0 + 1, inpWidth - 1);
+
+                    float* outData = outDataBox + y * outWidth + x;
+                    const float* inpData_row0_c = inpData_row0;
+                    const float* inpData_row1_c = inpData_row1;
+                    for (int c = 0; c < numChannels; ++c)
+                    {
+                        *outData = inpData_row0_c[x0] +
+                            (input_y - y0) * (inpData_row1_c[x0] - inpData_row0_c[x0]) +
+                            (input_x - x0) * (inpData_row0_c[x1] - inpData_row0_c[x0] +
+                            (input_y - y0) * (inpData_row1_c[x1] - inpData_row0_c[x1] - inpData_row1_c[x0] + inpData_row0_c[x0]));
+
+                        inpData_row0_c += inpSpatialSize;
+                        inpData_row1_c += inpSpatialSize;
+                        outData += outSpatialSize;
+                    }
+                }
+            }
+        }
+    }
+
+private:
+    int outWidth, outHeight;
+};
+
+Ptr<CropAndResizeLayer> CropAndResizeLayer::create(const LayerParams& params)
+{
+    return Ptr<CropAndResizeLayer>(new CropAndResizeLayerImpl(params));
+}
+
+}  // namespace dnn
+}  // namespace cv
diff --git a/modules/dnn/src/layers/detection_output_layer.cpp b/modules/dnn/src/layers/detection_output_layer.cpp
index 44f7b32853..ee1ad95e61 100644
--- a/modules/dnn/src/layers/detection_output_layer.cpp
+++ b/modules/dnn/src/layers/detection_output_layer.cpp
@@ -208,8 +208,9 @@ public:
         CV_Assert(inputs[0][0] == inputs[1][0]);
 
         int numPriors = inputs[2][2] / 4;
-        CV_Assert((numPriors * _numLocClasses * 4) == inputs[0][1]);
-        CV_Assert(int(numPriors * _numClasses) == inputs[1][1]);
+        CV_Assert((numPriors * _numLocClasses * 4) == total(inputs[0], 1));
+        CV_Assert(int(numPriors * _numClasses) == total(inputs[1], 1));
+        CV_Assert(inputs[2][1] == 1 + (int)(!_varianceEncodedInTarget));
 
         // num() and channels() are 1.
         // Since the number of bboxes to be kept is unknown before nms, we manually
diff --git a/modules/dnn/src/tensorflow/tf_importer.cpp b/modules/dnn/src/tensorflow/tf_importer.cpp
index bca150e3b5..f19daf9cc6 100644
--- a/modules/dnn/src/tensorflow/tf_importer.cpp
+++ b/modules/dnn/src/tensorflow/tf_importer.cpp
@@ -1094,9 +1094,9 @@ void TFImporter::populateNet(Net dstNet)
             CV_Assert(!begins.empty(), !sizes.empty(), begins.type() == CV_32SC1,
                       sizes.type() == CV_32SC1);
 
-            if (begins.total() == 4)
+            if (begins.total() == 4 && data_layouts[name] == DATA_LAYOUT_NHWC)
             {
-                // Perhabs, we have an NHWC order. Swap it to NCHW.
+                // Swap NHWC parameters' order to NCHW.
                 std::swap(*begins.ptr<int32_t>(0, 2), *begins.ptr<int32_t>(0, 3));
                 std::swap(*begins.ptr<int32_t>(0, 1), *begins.ptr<int32_t>(0, 2));
                 std::swap(*sizes.ptr<int32_t>(0, 2), *sizes.ptr<int32_t>(0, 3));
@@ -1176,6 +1176,9 @@ void TFImporter::populateNet(Net dstNet)
                     layers_to_ignore.insert(next_layers[0].first);
                 }
 
+                if (hasLayerAttr(layer, "axis"))
+                    layerParams.set("axis", getLayerAttr(layer, "axis").i());
+
                 id = dstNet.addLayer(name, "Scale", layerParams);
             }
             layer_id[name] = id;
@@ -1547,6 +1550,10 @@ void TFImporter::populateNet(Net dstNet)
                 layerParams.set("confidence_threshold", getLayerAttr(layer, "confidence_threshold").f());
             if (hasLayerAttr(layer, "loc_pred_transposed"))
                 layerParams.set("loc_pred_transposed", getLayerAttr(layer, "loc_pred_transposed").b());
+            if (hasLayerAttr(layer, "clip"))
+                layerParams.set("clip", getLayerAttr(layer, "clip").b());
+            if (hasLayerAttr(layer, "variance_encoded_in_target"))
+                layerParams.set("variance_encoded_in_target", getLayerAttr(layer, "variance_encoded_in_target").b());
 
             int id = dstNet.addLayer(name, "DetectionOutput", layerParams);
             layer_id[name] = id;
@@ -1563,6 +1570,26 @@ void TFImporter::populateNet(Net dstNet)
             layer_id[name] = id;
             connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, layer.input_size());
         }
+        else if (type == "CropAndResize")
+        {
+            // op: "CropAndResize"
+            // input: "input"
+            // input: "boxes"
+            // input: "sizes"
+            CV_Assert(layer.input_size() == 3);
+
+            Mat cropSize = getTensorContent(getConstBlob(layer, value_id, 2));
+            CV_Assert(cropSize.type() == CV_32SC1, cropSize.total() == 2);
+
+            layerParams.set("height", cropSize.at<int>(0));
+            layerParams.set("width", cropSize.at<int>(1));
+
+            int id = dstNet.addLayer(name, "CropAndResize", layerParams);
+            layer_id[name] = id;
+
+            connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
+            connect(layer_id, dstNet, parsePin(layer.input(1)), id, 1);
+        }
         else if (type == "Mean")
         {
             Mat indices = getTensorContent(getConstBlob(layer, value_id, 1));
diff --git a/modules/dnn/test/test_tf_importer.cpp b/modules/dnn/test/test_tf_importer.cpp
index b090fd7a16..84205f72fb 100644
--- a/modules/dnn/test/test_tf_importer.cpp
+++ b/modules/dnn/test/test_tf_importer.cpp
@@ -270,6 +270,22 @@ TEST_P(Test_TensorFlow_nets, Inception_v2_SSD)
     normAssertDetections(ref, out, "", 0.5);
 }
 
+TEST_P(Test_TensorFlow_nets, Inception_v2_Faster_RCNN)
+{
+    std::string proto = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pbtxt", false);
+    std::string model = findDataFile("dnn/faster_rcnn_inception_v2_coco_2018_01_28.pb", false);
+
+    Net net = readNetFromTensorflow(model, proto);
+    Mat img = imread(findDataFile("dnn/dog416.png", false));
+    Mat blob = blobFromImage(img, 1.0f / 127.5, Size(800, 600), Scalar(127.5, 127.5, 127.5), true, false);
+
+    net.setInput(blob);
+    Mat out = net.forward();
+
+    Mat ref = blobFromNPY(findDataFile("dnn/tensorflow/faster_rcnn_inception_v2_coco_2018_01_28.detection_out.npy"));
+    normAssertDetections(ref, out, "", 0.3);
+}
+
 TEST_P(Test_TensorFlow_nets, opencv_face_detector_uint8)
 {
     std::string proto = findDataFile("dnn/opencv_face_detector.pbtxt", false);
diff --git a/samples/dnn/README.md b/samples/dnn/README.md
index c438bb0910..9072ddb2a8 100644
--- a/samples/dnn/README.md
+++ b/samples/dnn/README.md
@@ -11,8 +11,10 @@
 | [SSDs from TensorFlow](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
 | [YOLO](https://pjreddie.com/darknet/yolo/) | `0.00392 (1/255)` | `416x416` | `0 0 0` | RGB |
 | [VGG16-SSD](https://github.com/weiliu89/caffe/tree/ssd) | `1.0` | `300x300` | `104 117 123` | BGR |
-| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801, 115.9465, 122.7717` | BGR |
+| [Faster-RCNN](https://github.com/rbgirshick/py-faster-rcnn) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
 | [R-FCN](https://github.com/YuwenXiong/py-R-FCN) | `1.0` | `800x600` | `102.9801 115.9465 122.7717` | BGR |
+| [Faster-RCNN, ResNet backbone](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `1.0` | `300x300` | `103.939 116.779 123.68` | RGB |
+| [Faster-RCNN, InceptionV2 backbone](https://github.com/tensorflow/models/tree/master/research/object_detection/) | `0.00784 (2/255)` | `300x300` | `127.5 127.5 127.5` | RGB |
 
 #### Face detection
 [An origin model](https://github.com/opencv/opencv/tree/master/samples/dnn/face_detector)
diff --git a/samples/dnn/tf_text_graph_faster_rcnn.py b/samples/dnn/tf_text_graph_faster_rcnn.py
new file mode 100644
index 0000000000..7ad5de283a
--- /dev/null
+++ b/samples/dnn/tf_text_graph_faster_rcnn.py
@@ -0,0 +1,291 @@
+import argparse
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.core.framework.node_def_pb2 import NodeDef
+from tensorflow.tools.graph_transforms import TransformGraph
+from google.protobuf import text_format
+
+parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
+                                             'Faster-RCNN model from TensorFlow Object Detection API. '
+                                             'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
+parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
+parser.add_argument('--output', required=True, help='Path to output text graph.')
+parser.add_argument('--num_classes', default=90, type=int, help='Number of trained classes.')
+parser.add_argument('--scales', default=[0.25, 0.5, 1.0, 2.0], type=float, nargs='+',
+                    help='Hyper-parameter of grid_anchor_generator from a config file.')
+parser.add_argument('--aspect_ratios', default=[0.5, 1.0, 2.0], type=float, nargs='+',
+                    help='Hyper-parameter of grid_anchor_generator from a config file.')
+parser.add_argument('--features_stride', default=16, type=float,
+                    help='Hyper-parameter from a config file.')
+args = parser.parse_args()
+
+scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
+                'FirstStageBoxPredictor/BoxEncodingPredictor',
+                'FirstStageBoxPredictor/ClassPredictor',
+                'CropAndResize',
+                'MaxPool2D',
+                'SecondStageFeatureExtractor',
+                'SecondStageBoxPredictor',
+                'image_tensor')
+
+scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
+                  'FirstStageFeatureExtractor/Shape',
+                  'FirstStageFeatureExtractor/strided_slice',
+                  'FirstStageFeatureExtractor/GreaterEqual',
+                  'FirstStageFeatureExtractor/LogicalAnd')
+
+unusedAttrs = ['T', 'Tshape', 'N', 'Tidx', 'Tdim', 'use_cudnn_on_gpu',
+               'Index', 'Tperm', 'is_training', 'Tpaddings']
+
+# Read the graph.
+with tf.gfile.FastGFile(args.input, 'rb') as f:
+    graph_def = tf.GraphDef()
+    graph_def.ParseFromString(f.read())
+
+# Remove Identity nodes and redirect their consumers to the actual inputs.
+def removeIdentity():
+    identities = {}
+    for i in reversed(range(len(graph_def.node))):
+        node = graph_def.node[i]
+        if node.op == 'Identity':
+            identities[node.name] = node.input[0]
+            del graph_def.node[i]  # Delete by index so the iteration skips nothing.
+
+    for node in graph_def.node:
+        for i in range(len(node.input)):
+            if node.input[i] in identities:
+                node.input[i] = identities[node.input[i]]
+
+removeIdentity()
+
+removedNodes = []
+
+for i in reversed(range(len(graph_def.node))):
+    op = graph_def.node[i].op
+    name = graph_def.node[i].name
+
+    if op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep):
+        if op != 'Const':
+            removedNodes.append(name)
+
+        del graph_def.node[i]
+    else:
+        for attr in unusedAttrs:
+            if attr in graph_def.node[i].attr:
+                del graph_def.node[i].attr[attr]
+
+# Remove references to removed nodes except Const nodes.
+for node in graph_def.node:
+    for i in reversed(range(len(node.input))):
+        if node.input[i] in removedNodes:
+            del node.input[i]
+
+
+# Connect the input node to the first layer.
+assert(graph_def.node[0].op == 'Placeholder')
+graph_def.node[1].input.insert(0, graph_def.node[0].name)
+
+# Temporarily remove top nodes.
+topNodes = []
+while True:
+    node = graph_def.node.pop()
+    topNodes.append(node)
+    if node.op == 'CropAndResize':
+        break
+
+def tensorMsg(values):
+    if all([isinstance(v, float) for v in values]):
+        dtype = 'DT_FLOAT'
+        field = 'float_val'
+    elif all([isinstance(v, int) for v in values]):
+        dtype = 'DT_INT32'
+        field = 'int_val'
+    else:
+        raise Exception('Wrong values types')
+
+    msg = 'tensor { dtype: ' + dtype + ' tensor_shape { dim { size: %d } }' % len(values)
+    for value in values:
+        msg += '%s: %s ' % (field, str(value))
+    return msg + '}'
+
+def addSlice(inp, out, begins, sizes):
+    beginsNode = NodeDef()
+    beginsNode.name = out + '/begins'
+    beginsNode.op = 'Const'
+    text_format.Merge(tensorMsg(begins), beginsNode.attr["value"])
+    graph_def.node.extend([beginsNode])
+
+    sizesNode = NodeDef()
+    sizesNode.name = out + '/sizes'
+    sizesNode.op = 'Const'
+    text_format.Merge(tensorMsg(sizes), sizesNode.attr["value"])
+    graph_def.node.extend([sizesNode])
+
+    sliced = NodeDef()
+    sliced.name = out
+    sliced.op = 'Slice'
+    sliced.input.append(inp)
+    sliced.input.append(beginsNode.name)
+    sliced.input.append(sizesNode.name)
+    graph_def.node.extend([sliced])
+
+def addReshape(inp, out, shape):
+    shapeNode = NodeDef()
+    shapeNode.name = out + '/shape'
+    shapeNode.op = 'Const'
+    text_format.Merge(tensorMsg(shape), shapeNode.attr["value"])
+    graph_def.node.extend([shapeNode])
+
+    reshape = NodeDef()
+    reshape.name = out
+    reshape.op = 'Reshape'
+    reshape.input.append(inp)
+    reshape.input.append(shapeNode.name)
+    graph_def.node.extend([reshape])
+
+def addSoftMax(inp, out):
+    softmax = NodeDef()
+    softmax.name = out
+    softmax.op = 'Softmax'
+    text_format.Merge('i: -1', softmax.attr['axis'])
+    softmax.input.append(inp)
+    graph_def.node.extend([softmax])
+
+addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
+           'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2])
+
+addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
+           'FirstStageBoxPredictor/ClassPredictor/softmax')  # Compare with Reshape_4
+
+flatten = NodeDef()
+flatten.name = 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten'  # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
+flatten.op = 'Flatten'
+flatten.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
+graph_def.node.extend([flatten])
+
+proposals = NodeDef()
+proposals.name = 'proposals'  # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
+proposals.op = 'PriorBox'
+proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
+proposals.input.append(graph_def.node[0].name)  # image_tensor
+
+text_format.Merge('b: false', proposals.attr["flip"])
+text_format.Merge('b: true', proposals.attr["clip"])
+text_format.Merge('f: %f' % args.features_stride, proposals.attr["step"])
+text_format.Merge('f: 0.0', proposals.attr["offset"])
+text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), proposals.attr["variance"])
+
+widths = []
+heights = []
+for a in args.aspect_ratios:
+    for s in args.scales:
+        ar = np.sqrt(a)
+        heights.append((args.features_stride**2) * s / ar)
+        widths.append((args.features_stride**2) * s * ar)
+
+text_format.Merge(tensorMsg(widths), proposals.attr["width"])
+text_format.Merge(tensorMsg(heights), proposals.attr["height"])
+
+graph_def.node.extend([proposals])
+
+# Compare with Reshape_5
+detectionOut = NodeDef()
+detectionOut.name = 'detection_out'
+detectionOut.op = 'DetectionOutput'
+
+detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
+detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax')
+detectionOut.input.append('proposals')
+
+text_format.Merge('i: 2', detectionOut.attr['num_classes'])
+text_format.Merge('b: true', detectionOut.attr['share_location'])
+text_format.Merge('i: 0', detectionOut.attr['background_label_id'])
+text_format.Merge('f: 0.7', detectionOut.attr['nms_threshold'])
+text_format.Merge('i: 6000', detectionOut.attr['top_k'])
+text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
+text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
+text_format.Merge('b: true', detectionOut.attr['clip'])
+text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
+
+graph_def.node.extend([detectionOut])
+
+# Restore the temporarily removed top nodes.
+for node in reversed(topNodes):
+    graph_def.node.extend([node])
+
+addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax')
+
+addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
+         'SecondStageBoxPredictor/Reshape_1/slice',
+         [0, 0, 1], [-1, -1, -1])
+
+addReshape('SecondStageBoxPredictor/Reshape_1/slice',
+           'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1])
+
+# Replace the Flatten subgraph with a single node.
+for i in reversed(range(len(graph_def.node))):
+    if graph_def.node[i].op == 'CropAndResize':
+        graph_def.node[i].input.insert(1, 'detection_out')
+
+    if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
+        shapeNode = NodeDef()
+        shapeNode.name = 'SecondStageBoxPredictor/Reshape/shape2'
+        shapeNode.op = 'Const'
+        text_format.Merge(tensorMsg([1, -1, 4]), shapeNode.attr["value"])
+        graph_def.node.extend([shapeNode])
+
+        graph_def.node[i].input.pop()
+        graph_def.node[i].input.append(shapeNode.name)
+
+    if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
+                                  'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
+                                  'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape']:
+        del graph_def.node[i]
+
+for node in graph_def.node:
+    if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape':
+        node.op = 'Flatten'
+        node.input.pop()
+        break
+
+################################################################################
+### Postprocessing
+################################################################################
+addSlice('detection_out', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4])
+
+variance = NodeDef()
+variance.name = 'proposals/variance'
+variance.op = 'Const'
+text_format.Merge(tensorMsg([0.1, 0.1, 0.2, 0.2]), variance.attr["value"])
+graph_def.node.extend([variance])
+
+varianceEncoder = NodeDef()
+varianceEncoder.name = 'variance_encoded'
+varianceEncoder.op = 'Mul'
+varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
+varianceEncoder.input.append(variance.name)
+text_format.Merge('i: 2', varianceEncoder.attr["axis"])
+graph_def.node.extend([varianceEncoder])
+
+addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1])
+
+detectionOut = NodeDef()
+detectionOut.name = 'detection_out_final'
+detectionOut.op = 'DetectionOutput'
+
+detectionOut.input.append('variance_encoded')
+detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
+detectionOut.input.append('detection_out/slice/reshape')
+
+text_format.Merge('i: %d' % args.num_classes, detectionOut.attr['num_classes'])
+text_format.Merge('b: false', detectionOut.attr['share_location'])
+text_format.Merge('i: %d' % (args.num_classes + 1), detectionOut.attr['background_label_id'])
+text_format.Merge('f: 0.6', detectionOut.attr['nms_threshold'])
+text_format.Merge('s: "CENTER_SIZE"', detectionOut.attr['code_type'])
+text_format.Merge('i: 100', detectionOut.attr['keep_top_k'])
+text_format.Merge('b: true', detectionOut.attr['loc_pred_transposed'])
+text_format.Merge('b: true', detectionOut.attr['clip'])
+text_format.Merge('b: true', detectionOut.attr['variance_encoded_in_target'])
+graph_def.node.extend([detectionOut])
+
+# Save as text.
+tf.train.write_graph(graph_def, "", args.output, as_text=True)
--
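Usage note (illustrative only, not part of the patch): once OpenCV is built with this
patch, a converted model can be run from Python roughly as below, mirroring the
preprocessing of the C++ test above (scale 1/127.5, 800x600 input, mean 127.5, RGB).
The model file names follow the test; 'example.jpg' is a placeholder:

    # First generate the text graph next to the frozen model:
    #   python tf_text_graph_faster_rcnn.py \
    #       --input faster_rcnn_inception_v2_coco_2018_01_28.pb \
    #       --output faster_rcnn_inception_v2_coco_2018_01_28.pbtxt
    import cv2 as cv

    net = cv.dnn.readNetFromTensorflow('faster_rcnn_inception_v2_coco_2018_01_28.pb',
                                       'faster_rcnn_inception_v2_coco_2018_01_28.pbtxt')
    img = cv.imread('example.jpg')
    rows, cols = img.shape[:2]
    blob = cv.dnn.blobFromImage(img, 1.0 / 127.5, (800, 600), (127.5, 127.5, 127.5),
                                swapRB=True, crop=False)
    net.setInput(blob)
    out = net.forward()  # Shape (1, 1, N, 7): [batch_id, class_id, score, l, t, r, b].
    for detection in out[0, 0]:
        score = float(detection[2])
        if score > 0.5:
            # Box coordinates are normalized; scale them to the image size.
            left, top = int(detection[3] * cols), int(detection[4] * rows)
            right, bottom = int(detection[5] * cols), int(detection[6] * rows)
            cv.rectangle(img, (left, top), (right, bottom), (0, 255, 0), 2)
    cv.imwrite('detections.png', img)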