提交 f723cede 编写于 作者: D David Geldreich

add loading TensorFlow/Caffe net from memory buffer

add a corresponding test
上级 6e4f9433
......@@ -634,11 +634,33 @@ CV__DNN_EXPERIMENTAL_NS_BEGIN
*/
CV_EXPORTS_W Net readNetFromCaffe(const String &prototxt, const String &caffeModel = String());
/** @brief Reads a network model stored in Caffe model in memory.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferProto buffer containing the content of the .prototxt file
* @param lenProto length of bufferProto
* @param bufferModel buffer containing the content of the .caffemodel file
* @param lenModel length of bufferModel
*/
CV_EXPORTS Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel = NULL, size_t lenModel = 0);
/** @brief Reads a network model stored in Tensorflow model file.
* @details This is shortcut consisting from createTensorflowImporter and Net::populateNet calls.
*/
CV_EXPORTS_W Net readNetFromTensorflow(const String &model, const String &config = String());
/** @brief Reads a network model stored in Tensorflow model in memory.
* @details This is an overloaded member function, provided for convenience.
* It differs from the above function only in what argument(s) it accepts.
* @param bufferModel buffer containing the content of the pb file
* @param lenModel length of bufferModel
* @param bufferConfig buffer containing the content of the pbtxt file
* @param lenConfig length of bufferConfig
*/
CV_EXPORTS Net readNetFromTensorflow(const char *bufferModel, size_t lenModel,
const char *bufferConfig = NULL, size_t lenConfig = 0);
/** @brief Reads a network model stored in Torch model file.
* @details This is shortcut consisting from createTorchImporter and Net::populateNet calls.
*/
......
......@@ -94,6 +94,17 @@ public:
ReadNetParamsFromBinaryFileOrDie(caffeModel, &netBinary);
}
CaffeImporter(const char *dataProto, size_t lenProto,
const char *dataModel, size_t lenModel)
{
CV_TRACE_FUNCTION();
ReadNetParamsFromTextBufferOrDie(dataProto, lenProto, &net);
if (dataModel != NULL && lenModel > 0)
ReadNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBinary);
}
void addParam(const Message &msg, const FieldDescriptor *field, cv::dnn::LayerParams &params)
{
const Reflection *refl = msg.GetReflection();
......@@ -400,6 +411,15 @@ Net readNetFromCaffe(const String &prototxt, const String &caffeModel /*= String
return net;
}
Net readNetFromCaffe(const char *bufferProto, size_t lenProto,
const char *bufferModel, size_t lenModel)
{
CaffeImporter caffeImporter(bufferProto, lenProto, bufferModel, lenModel);
Net net;
caffeImporter.populateNet(net);
return net;
}
#endif //HAVE_PROTOBUF
CV__DNN_EXPERIMENTAL_NS_END
......
......@@ -1108,28 +1108,37 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) {
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
bool ReadProtoFromBinary(ZeroCopyInputStream* input, Message *proto) {
CodedInputStream coded_input(input);
coded_input.SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
return proto->ParseFromCodedStream(&coded_input);
}
bool ReadProtoFromTextFile(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
IstreamInputStream input(&fs);
bool success = google::protobuf::TextFormat::Parse(&input, proto);
fs.close();
return success;
return google::protobuf::TextFormat::Parse(&input, proto);
}
bool ReadProtoFromBinaryFile(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
IstreamInputStream raw_input(&fs);
return ReadProtoFromBinary(&raw_input, proto);
}
bool ReadProtoFromTextBuffer(const char* data, size_t len, Message* proto) {
ArrayInputStream input(data, len);
return google::protobuf::TextFormat::Parse(&input, proto);
}
bool success = proto->ParseFromCodedStream(coded_input);
delete coded_input;
delete raw_input;
fs.close();
return success;
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, Message* proto) {
ArrayInputStream raw_input(data, len);
return ReadProtoFromBinary(&raw_input, proto);
}
void ReadNetParamsFromTextFileOrDie(const char* param_file,
......@@ -1139,6 +1148,13 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
UpgradeNetAsNeeded(param_file, param);
}
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
NetParameter* param) {
CHECK(ReadProtoFromTextBuffer(data, len, param))
<< "Failed to parse NetParameter buffer";
UpgradeNetAsNeeded("memory buffer", param);
}
void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
NetParameter* param) {
CHECK(ReadProtoFromBinaryFile(param_file, param))
......@@ -1146,6 +1162,13 @@ void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
UpgradeNetAsNeeded(param_file, param);
}
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
NetParameter* param) {
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
<< "Failed to parse NetParameter buffer";
UpgradeNetAsNeeded("memory buffer", param);
}
}
}
#endif
......@@ -102,6 +102,18 @@ void ReadNetParamsFromTextFileOrDie(const char* param_file,
void ReadNetParamsFromBinaryFileOrDie(const char* param_file,
caffe::NetParameter* param);
// Read parameters from a memory buffer into a NetParammeter proto message.
void ReadNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
caffe::NetParameter* param);
void ReadNetParamsFromTextBufferOrDie(const char* data, size_t len,
caffe::NetParameter* param);
// Utility functions used internally by Caffe and TensorFlow loaders
bool ReadProtoFromTextFile(const char* filename, ::google::protobuf::Message* proto);
bool ReadProtoFromBinaryFile(const char* filename, ::google::protobuf::Message* proto);
bool ReadProtoFromTextBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
bool ReadProtoFromBinaryBuffer(const char* data, size_t len, ::google::protobuf::Message* proto);
}
}
#endif
......
......@@ -449,6 +449,9 @@ void ExcludeLayer(tensorflow::GraphDef& net, const int layer_index, const int in
class TFImporter : public Importer {
public:
TFImporter(const char *model, const char *config = NULL);
TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig = NULL, size_t lenConfig = 0);
void populateNet(Net dstNet);
~TFImporter() {}
......@@ -479,6 +482,15 @@ TFImporter::TFImporter(const char *model, const char *config)
ReadTFNetParamsFromTextFileOrDie(config, &netTxt);
}
TFImporter::TFImporter(const char *dataModel, size_t lenModel,
const char *dataConfig, size_t lenConfig)
{
if (dataModel != NULL && lenModel > 0)
ReadTFNetParamsFromBinaryBufferOrDie(dataModel, lenModel, &netBin);
if (dataConfig != NULL && lenConfig > 0)
ReadTFNetParamsFromTextBufferOrDie(dataConfig, lenConfig, &netTxt);
}
void TFImporter::kernelFromTensor(const tensorflow::TensorProto &tensor, Mat &dstBlob)
{
MatShape shape;
......@@ -1326,5 +1338,14 @@ Net readNetFromTensorflow(const String &model, const String &config)
return net;
}
Net readNetFromTensorflow(const char* bufferModel, size_t lenModel,
const char* bufferConfig, size_t lenConfig)
{
TFImporter importer(bufferModel, lenModel, bufferConfig, lenConfig);
Net net;
importer.populateNet(net);
return net;
}
CV__DNN_EXPERIMENTAL_NS_END
}} // namespace
......@@ -23,6 +23,7 @@ Implementation of various functions which are related to Tensorflow models readi
#include "graph.pb.h"
#include "tf_io.hpp"
#include "../caffe/caffe_io.hpp"
#include "../caffe/glog_emulator.hpp"
namespace cv {
......@@ -36,41 +37,28 @@ using namespace ::google::protobuf::io;
const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte.
// TODO: remove Caffe duplicate
bool ReadProtoFromBinaryFileTF(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in | std::ifstream::binary);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
ZeroCopyInputStream* raw_input = new IstreamInputStream(&fs);
CodedInputStream* coded_input = new CodedInputStream(raw_input);
coded_input->SetTotalBytesLimit(kProtoReadBytesLimit, 536870912);
bool success = proto->ParseFromCodedStream(coded_input);
delete coded_input;
delete raw_input;
fs.close();
return success;
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFile(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
}
bool ReadProtoFromTextFileTF(const char* filename, Message* proto) {
std::ifstream fs(filename, std::ifstream::in);
CHECK(fs.is_open()) << "Can't open \"" << filename << "\"";
IstreamInputStream input(&fs);
bool success = google::protobuf::TextFormat::Parse(&input, proto);
fs.close();
return success;
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryBuffer(data, len, param))
<< "Failed to parse GraphDef buffer";
}
void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromBinaryFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
CHECK(ReadProtoFromTextFile(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
}
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextFileTF(param_file, param))
<< "Failed to parse GraphDef file: " << param_file;
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param) {
CHECK(ReadProtoFromTextBuffer(data, len, param))
<< "Failed to parse GraphDef buffer";
}
}
......
......@@ -25,6 +25,13 @@ void ReadTFNetParamsFromBinaryFileOrDie(const char* param_file,
void ReadTFNetParamsFromTextFileOrDie(const char* param_file,
tensorflow::GraphDef* param);
// Read parameters from a memory buffer into a GraphDef proto message.
void ReadTFNetParamsFromBinaryBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);
void ReadTFNetParamsFromTextBufferOrDie(const char* data, size_t len,
tensorflow::GraphDef* param);
}
}
......
......@@ -55,6 +55,24 @@ static std::string _tf(TString filename)
return (getOpenCVExtraDir() + "/dnn/") + filename;
}
TEST(Test_Caffe, memory_read)
{
const string proto = findDataFile("dnn/bvlc_googlenet.prototxt", false);
const string model = findDataFile("dnn/bvlc_googlenet.caffemodel", false);
string dataProto;
ASSERT_TRUE(readFileInMemory(proto, dataProto));
string dataModel;
ASSERT_TRUE(readFileInMemory(model, dataModel));
Net net = readNetFromCaffe(dataProto.c_str(), dataProto.size());
ASSERT_FALSE(net.empty());
Net net2 = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
dataModel.c_str(), dataModel.size());
ASSERT_FALSE(net2.empty());
}
TEST(Test_Caffe, read_gtsrb)
{
Net net = readNetFromCaffe(_tf("gtsrb.prototxt"));
......@@ -67,13 +85,26 @@ TEST(Test_Caffe, read_googlenet)
ASSERT_FALSE(net.empty());
}
TEST(Reproducibility_AlexNet, Accuracy)
typedef testing::TestWithParam<tuple<bool> > Reproducibility_AlexNet;
TEST_P(Reproducibility_AlexNet, Accuracy)
{
bool readFromMemory = get<0>(GetParam());
Net net;
{
const string proto = findDataFile("dnn/bvlc_alexnet.prototxt", false);
const string model = findDataFile("dnn/bvlc_alexnet.caffemodel", false);
net = readNetFromCaffe(proto, model);
if (readFromMemory)
{
string dataProto;
ASSERT_TRUE(readFileInMemory(proto, dataProto));
string dataModel;
ASSERT_TRUE(readFileInMemory(model, dataModel));
net = readNetFromCaffe(dataProto.c_str(), dataProto.size(),
dataModel.c_str(), dataModel.size());
}
else
net = readNetFromCaffe(proto, model);
ASSERT_FALSE(net.empty());
}
......@@ -86,6 +117,8 @@ TEST(Reproducibility_AlexNet, Accuracy)
normAssert(ref, out);
}
INSTANTIATE_TEST_CASE_P(Test_Caffe, Reproducibility_AlexNet, testing::Values(true, false));
#if !defined(_WIN32) || defined(_WIN64)
TEST(Reproducibility_FCN, Accuracy)
{
......
......@@ -57,4 +57,23 @@ inline void normAssert(cv::InputArray ref, cv::InputArray test, const char *comm
EXPECT_LE(normInf, lInf) << comment;
}
inline bool readFileInMemory(const std::string& filename, std::string& content)
{
std::ios::openmode mode = std::ios::in | std::ios::binary;
std::ifstream ifs(filename.c_str(), mode);
if (!ifs.is_open())
return false;
content.clear();
ifs.seekg(0, std::ios::end);
content.reserve(ifs.tellg());
ifs.seekg(0, std::ios::beg);
content.assign((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());
return true;
}
#endif
......@@ -75,14 +75,32 @@ static std::string path(const std::string& file)
}
static void runTensorFlowNet(const std::string& prefix, bool hasText = false,
double l1 = 1e-5, double lInf = 1e-4)
double l1 = 1e-5, double lInf = 1e-4,
bool memoryLoad = false)
{
std::string netPath = path(prefix + "_net.pb");
std::string netConfig = (hasText ? path(prefix + "_net.pbtxt") : "");
std::string inpPath = path(prefix + "_in.npy");
std::string outPath = path(prefix + "_out.npy");
Net net = readNetFromTensorflow(netPath, netConfig);
Net net;
if (memoryLoad)
{
// Load files into a memory buffers
string dataModel;
ASSERT_TRUE(readFileInMemory(netPath, dataModel));
string dataConfig;
if (hasText)
ASSERT_TRUE(readFileInMemory(netConfig, dataConfig));
net = readNetFromTensorflow(dataModel.c_str(), dataModel.size(),
dataConfig.c_str(), dataConfig.size());
}
else
net = readNetFromTensorflow(netPath, netConfig);
ASSERT_FALSE(net.empty());
cv::Mat input = blobFromNPY(inpPath);
cv::Mat target = blobFromNPY(outPath);
......@@ -216,4 +234,15 @@ TEST(Test_TensorFlow, resize_nearest_neighbor)
runTensorFlowNet("resize_nearest_neighbor");
}
TEST(Test_TensorFlow, memory_read)
{
double l1 = 1e-5;
double lInf = 1e-4;
runTensorFlowNet("lstm", true, l1, lInf, true);
runTensorFlowNet("batch_norm", false, l1, lInf, true);
runTensorFlowNet("fused_batch_norm", false, l1, lInf, true);
runTensorFlowNet("batch_norm_text", true, l1, lInf, true);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册