未验证 提交 b4bb98ea 编写于 作者: G Gruhuang 提交者: GitHub

Merge pull request #21268 from pccvlab:tf_Arg

add argmax and argmin parsing for tensorflow

* add argmax and argmin for tf

* remove whitespace

* remove whitespace

* remove static_cast
Signed-off-by: NCrayon-new <1349159541@qq.com>
上级 f7aa91e6
......@@ -599,6 +599,8 @@ private:
void parseActivation (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseExpandDims (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseSquare (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseArg (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
void parseCustomLayer (tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams);
};
......@@ -677,6 +679,7 @@ const TFImporter::DispatchMap TFImporter::buildDispatchMap()
dispatch["Elu"] = dispatch["Exp"] = dispatch["Identity"] = dispatch["Relu6"] = &TFImporter::parseActivation;
dispatch["ExpandDims"] = &TFImporter::parseExpandDims;
dispatch["Square"] = &TFImporter::parseSquare;
dispatch["ArgMax"] = dispatch["ArgMin"] = &TFImporter::parseArg;
return dispatch;
}
......@@ -2624,6 +2627,22 @@ void TFImporter::parseActivation(tensorflow::GraphDef& net, const tensorflow::No
connectToAllBlobs(layer_id, dstNet, parsePin(layer.input(0)), id, num_inputs);
}
void TFImporter::parseArg(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
const std::string& name = layer.name();
const std::string& type = layer.op();
Mat dimension = getTensorContent(getConstBlob(layer, value_id, 1));
CV_Assert(dimension.total() == 1 && dimension.type() == CV_32SC1);
layerParams.set("axis", *dimension.ptr<int>());
layerParams.set("op", type == "ArgMax" ? "max" : "min");
layerParams.set("keepdims", false); //tensorflow doesn't have this atrr, the output's dims minus one(default);
int id = dstNet.addLayer(name, "Arg", layerParams);
layer_id[name] = id;
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
}
void TFImporter::parseCustomLayer(tensorflow::GraphDef& net, const tensorflow::NodeDef& layer, LayerParams& layerParams)
{
// Importer does not know how to map this TensorFlow's operation onto OpenCV's layer.
......
......@@ -185,6 +185,14 @@ TEST_P(Test_TensorFlow_layers, reduce_sum_channel_keep_dims)
runTensorFlowNet("reduce_sum_channel", false, 0.0, 0.0, false, "_keep_dims");
}
TEST_P(Test_TensorFlow_layers, ArgLayer)
{
if (backend != DNN_BACKEND_OPENCV || target != DNN_TARGET_CPU)
throw SkipTestException("Only CPU is supported"); // FIXIT use tags
runTensorFlowNet("argmax");
runTensorFlowNet("argmin");
}
TEST_P(Test_TensorFlow_layers, conv_single_conv)
{
runTensorFlowNet("single_conv");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册