Commit f05649a4 authored by liuqi

Replace pb with code.

Parent 23238388
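In effect, the runtime stops deserializing a protobuf model file and instead calls a function emitted by a new source-code generator. A minimal sketch of the difference, assuming the generated model source is linked into the binary (compare the mace_run.cc hunk further down):

// Sketch only: CreateNet() is the entry point emitted by the generator.
#include "mace/core/mace.h"

namespace mace { extern NetDef CreateNet(); }

int main() {
  // Before: NetDef net_def; net_def.ParseFromIstream(&file_stream);
  // After: the model is baked into the binary; no protobuf at runtime.
  mace::NetDef net_def = mace::CreateNet();
  return net_def.op_size() > 0 ? 0 : 1;
}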
......@@ -62,7 +62,6 @@ cc_library(
]),
deps = [
":logging",
"//mace/proto:cc_proto",
"//mace/proto:stats_proto",
"//mace/utils",
":opencl_runtime",
......
......@@ -9,7 +9,7 @@
#include <malloc.h>
#include "mace/core/common.h"
#include "mace/core/registry.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
#include "mace/core/types.h"
namespace mace {
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/mace.h"
#include "mace/core/logging.h"
namespace mace {
TensorProto::TensorProto(const std::string &name,
                         unsigned char *data,
                         const std::vector<int64_t> &dims,
                         const DataType data_type,
                         uint32_t node_id)
    : name_(name),
      data_(data),
      data_size_(1),
      dims_(dims),
      data_type_(data_type),
      node_id_(node_id) {
  // data_size_ is the element count: the product of all dims.
  // (It was previously left uninitialized, but Serializer::Deserialize
  // reads it as the copy length.)
  for (const int64_t d : dims_) data_size_ *= d;
}
TensorProto::TensorProto(const std::string &name,
                         unsigned char *data,
                         const std::vector<int64_t> &dims,
                         const int data_type,
                         uint32_t node_id)
    : TensorProto(name, data, dims, static_cast<DataType>(data_type), node_id) {}
const std::string &TensorProto::name() const {
return name_;
}
unsigned char *TensorProto::data() const {
return data_;
}
int TensorProto::data_size() const {
return data_size_;
}
const std::vector<int64_t> &TensorProto::dims() const {
return dims_;
}
DataType TensorProto::data_type() const {
return data_type_;
}
uint32_t TensorProto::node_id() const {
return node_id_;
}
Argument::Argument() : f_(0.f), i_(0), has_bits_(0) {}
void Argument::CopyFrom(const Argument &from) {
  this->name_ = from.name();
  this->f_ = from.f();
  this->i_ = from.i();
  this->s_ = from.s();
  const auto &floats = from.floats();
  this->floats_.resize(floats.size());
  std::copy(floats.begin(), floats.end(), this->floats_.begin());
  const auto &ints = from.ints();
  this->ints_.resize(ints.size());
  std::copy(ints.begin(), ints.end(), this->ints_.begin());
  const auto &strings = from.strings();
  this->strings_.resize(strings.size());
  std::copy(strings.begin(), strings.end(), this->strings_.begin());
  this->has_bits_ = from.has_bits_;
}
const std::string &Argument::name() const {
return name_;
}
void Argument::set_name(const std::string &value) {
name_ = value;
}
bool Argument::has_f() const {
return (has_bits_ & 0x00000001u) != 0;
}
void Argument::set_has_f() {
has_bits_ |= 0x00000001u;
}
float Argument::f() const {
return f_;
}
void Argument::set_f(float value) {
set_has_f();
f_ = value;
}
bool Argument::has_i() const {
return (has_bits_ & 0x00000002u) != 0;
}
void Argument::set_has_i() {
has_bits_ |= 0x00000002u;
}
int64_t Argument::i() const {
return i_;
}
void Argument::set_i(int64_t value) {
set_has_i();
i_ = value;
}
bool Argument::has_s() const {
return (has_bits_ & 0x00000004u) != 0;
}
void Argument::set_has_s() {
has_bits_ |= 0x00000004u;
}
std::string Argument::s() const {
return s_;
}
void Argument::set_s(const std::string &value) {
set_has_s();
s_ = value;
}
const std::vector<float> &Argument::floats() const {
return floats_;
}
void Argument::add_floats(float value) {
floats_.push_back(value);
}
void Argument::set_floats(const std::vector<float> &value) {
  // Plain assignment; copying into a reserved-but-empty vector is undefined.
  floats_ = value;
}
const std::vector<int64_t> &Argument::ints() const {
return ints_;
}
void Argument::add_ints(int64_t value) {
ints_.push_back(value);
}
void Argument::set_ints(const std::vector<int64_t> &value) {
  ints_ = value;
}
const std::vector<std::string> &Argument::strings() const {
return strings_;
}
void Argument::add_strings(const ::std::string &value) {
strings_.push_back(value);
}
void Argument::set_strings(const std::vector<std::string> &value) {
  strings_ = value;
}
void OperatorDef::CopyFrom(const OperatorDef &from) {
name_ = from.name();
type_ = from.type();
auto from_input = from.input();
input_.resize(from_input.size());
std::copy(from_input.begin(), from_input.end(), input_.begin());
auto from_output = from.output();
output_.resize(from_output.size());
std::copy(from_output.begin(), from_output.end(), output_.begin());
auto from_arg = from.arg();
arg_.resize(from_arg.size());
for (int i = 0; i < from_arg.size(); ++i) {
arg_[i].CopyFrom(from_arg[i]);
}
auto from_output_shape = from.output_shape();
output_shape_.resize(from_output_shape.size());
for (int i = 0; i < from_output_shape.size(); ++i) {
output_shape_[i].CopyFrom(from_output_shape[i]);
}
auto from_data_type = from.output_type();
output_type_.resize(from_data_type.size());
std::copy(from_data_type.begin(), from_data_type.end(), output_type_.begin());
mem_id_ = from.mem_id();
// nnlib
node_id_ = from.node_id();
op_id_ = from.op_id();
padding_ = from.padding();
auto from_node_input = from.node_input();
node_input_.resize(from_node_input.size());
for (int i = 0; i < from_node_input.size(); ++i) {
node_input_[i].CopyFrom(from_node_input[i]);
}
auto from_out_max_byte_size = from.out_max_byte_size();
out_max_byte_size_.resize(from_out_max_byte_size.size());
std::copy(from_out_max_byte_size.begin(), from_out_max_byte_size.end(), out_max_byte_size_.begin());
has_bits_ = from.has_bits_;
}
const std::string &OperatorDef::name() const {
return name_;
}
void OperatorDef::set_name(const std::string &value) {
  set_has_name();
  name_ = value;
}
bool OperatorDef::has_name() const {
return (has_bits_ & 0x00000001u) != 0;
}
void OperatorDef::set_has_name() {
has_bits_ |= 0x00000001u;
}
const std::string &OperatorDef::type() const {
return type_;
}
void OperatorDef::set_type(const std::string &value) {
  set_has_type();
  type_ = value;
}
bool OperatorDef::has_type() const {
return (has_bits_ & 0x00000002u) != 0;
}
void OperatorDef::set_has_type() {
has_bits_ |= 0x00000002u;
}
int OperatorDef::mem_id() const {
return mem_id_;
}
void OperatorDef::set_mem_id(const int mem_id) {
set_has_mem_id();
mem_id_ = mem_id;
}
bool OperatorDef::has_mem_id() const {
return (has_bits_ & 0x00000004u) != 0;
}
void OperatorDef::set_has_mem_id() {
has_bits_ |= 0x00000004u;
}
uint32_t OperatorDef::node_id() const {
return node_id_;
}
uint32_t OperatorDef::op_id() const {
return op_id_;
}
uint32_t OperatorDef::padding() const {
return padding_;
}
const std::vector<NodeInput> &OperatorDef::node_input() const {
return node_input_;
}
const std::vector<int> &OperatorDef::out_max_byte_size() const {
return out_max_byte_size_;
}
const std::vector<std::string> &OperatorDef::input() const {
return input_;
}
const std::string &OperatorDef::input(int index) const {
  MACE_CHECK(0 <= index && index < static_cast<int>(input_.size()));
  return input_[index];
}
std::string *OperatorDef::add_input() {
input_.push_back("");
return &input_.back();
}
void OperatorDef::add_input(const ::std::string &value) {
input_.push_back(value);
}
void OperatorDef::add_input(::std::string &&value) {
  input_.push_back(std::move(value));
}
void OperatorDef::set_input(const std::vector<std::string> &value) {
  input_ = value;
}
const std::vector<std::string> &OperatorDef::output() const {
return output_;
}
const std::string &OperatorDef::output(int index) const {
  MACE_CHECK(0 <= index && index < static_cast<int>(output_.size()));
  return output_[index];
}
std::string *OperatorDef::add_output() {
output_.push_back("");
return &output_.back();
}
void OperatorDef::add_output(const ::std::string &value) {
output_.push_back(value);
}
void OperatorDef::add_output(::std::string &&value) {
  output_.push_back(std::move(value));
}
void OperatorDef::set_output(const std::vector<std::string> &value) {
  output_ = value;
}
const std::vector<Argument> &OperatorDef::arg() const {
return arg_;
}
Argument *OperatorDef::add_arg() {
arg_.emplace_back(Argument());
return &arg_.back();
}
const std::vector<OutputShape> &OperatorDef::output_shape() const {
return output_shape_;
}
void OperatorDef::set_output_shape(const std::vector<OutputShape> &value) {
  // resize (not reserve): the loop below indexes the elements.
  output_shape_.resize(value.size());
  for (size_t i = 0; i < value.size(); ++i) {
    output_shape_[i].CopyFrom(value[i]);
  }
}
const std::vector<DataType> &OperatorDef::output_type() const {
return output_type_;
}
void OperatorDef::set_output_type(const std::vector<DataType> &value) {
output_type_.resize(value.size());
std::copy(value.begin(), value.end(), output_type_.begin());
}
MemoryBlock::MemoryBlock(int mem_id, uint32_t x, uint32_t y) :
mem_id_(mem_id), x_(x), y_(y) {}
int MemoryBlock::mem_id() const {
return mem_id_;
}
uint32_t MemoryBlock::x() const {
return x_;
}
uint32_t MemoryBlock::y() const {
return y_;
}
NetDef::NetDef() : has_bits_(0) {}
const std::string &NetDef::name() const {
return name_;
}
void NetDef::set_name(const std::string &value) {
set_has_name();
name_ = value;
}
bool NetDef::has_name() const {
return (has_bits_ & 0x00000001u) != 0;
}
void NetDef::set_has_name() {
has_bits_ |= 0x00000001u;
}
const std::string &NetDef::version() const {
return version_;
}
void NetDef::set_version(const std::string &value) {
set_has_version();
version_ = value;
}
bool NetDef::has_version() const {
return (has_bits_ & 0x00000002u) != 0;
}
void NetDef::set_has_version() {
has_bits_ |= 0x00000002u;
}
const std::vector<OperatorDef> &NetDef::op() const {
return op_;
}
OperatorDef *NetDef::add_op() {
op_.emplace_back(OperatorDef());
return &op_.back();
}
std::vector<OperatorDef> &NetDef::mutable_op() {
return op_;
}
const std::vector<Argument> &NetDef::arg() const {
return arg_;
}
Argument *NetDef::add_arg() {
arg_.emplace_back(Argument());
return &arg_.back();
}
std::vector<Argument> &NetDef::mutable_arg() {
return arg_;
}
const std::vector<TensorProto> &NetDef::tensors() const {
return tensors_;
}
std::vector<TensorProto> &NetDef::mutable_tensors() {
return tensors_;
}
const MemoryArena &NetDef::mem_arena() const {
return mem_arena_;
}
MemoryArena &NetDef::mutable_mem_arena() {
set_has_mem_arena();
return mem_arena_;
}
bool NetDef::has_mem_arena() const {
return (has_bits_ & 0x00000004u) != 0;
}
void NetDef::set_has_mem_arena() {
has_bits_ |= 0x00000004u;
}
const std::vector<InputInfo> &NetDef::input_info() const {
return input_info_;
}
const std::vector<OutputInfo> &NetDef::output_info() const {
return output_info_;
}
int NetDef::op_size() const {
return op_.size();
}
const OperatorDef &NetDef::op(const int idx) const {
MACE_CHECK(0 <= idx && idx < op_size());
return op_[idx];
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_MACE_H_
#define MACE_CORE_MACE_H_
#include <cstdint>
#include <vector>
#include <string>
#include "mace/core/logging.h"
namespace mace {
enum NetMode {
INIT = 0,
NORMAL = 1
};
enum DeviceType {
CPU = 0,
NEON = 1,
OPENCL = 2
};
enum DataType {
DT_INVALID = 0,
DT_FLOAT = 1,
DT_DOUBLE = 2,
DT_INT32 = 3,
DT_UINT8 = 4,
DT_INT16 = 5,
DT_INT8 = 6,
DT_STRING = 7,
DT_INT64 = 8,
DT_UINT16 = 9,
DT_BOOL = 10,
DT_HALF = 19,
DT_UINT32 = 22
};
class TensorProto {
public:
TensorProto(const std::string &name,
unsigned char *data,
const std::vector<int64_t> &dims,
const DataType data_type = DT_FLOAT,
uint32_t node_id = 0);
TensorProto(const std::string &name,
unsigned char *data,
const std::vector<int64_t> &dims,
const int data_type,
uint32_t node_id = 0);
const std::string &name() const;
unsigned char *data() const;
int data_size() const;
const std::vector<int64_t> &dims() const;
DataType data_type() const;
uint32_t node_id() const;
private:
std::string name_;
unsigned char *data_;
int data_size_;
std::vector<int64_t> dims_;
DataType data_type_;
uint32_t node_id_;
};
class Argument {
 public:
  Argument();
  void CopyFrom(const Argument &from);

 public:
  const std::string &name() const;
  void set_name(const std::string &value);
  bool has_f() const;
  float f() const;
  void set_f(float value);
  bool has_i() const;
  int64_t i() const;
  void set_i(int64_t value);
  bool has_s() const;
  std::string s() const;
  void set_s(const std::string &value);
  const std::vector<float> &floats() const;
  void add_floats(float value);
  void set_floats(const std::vector<float> &value);
  const std::vector<int64_t> &ints() const;
  void add_ints(int64_t value);
  void set_ints(const std::vector<int64_t> &value);
  const std::vector<std::string> &strings() const;
  void add_strings(const ::std::string &value);
  void set_strings(const std::vector<std::string> &value);

 private:
  void set_has_f();
  void set_has_i();
  void set_has_s();

 private:
  std::string name_;
  float f_;
  int64_t i_;
  std::string s_;
  std::vector<float> floats_;
  std::vector<int64_t> ints_;
  std::vector<std::string> strings_;
  uint32_t has_bits_;
};
class NodeInput {
public:
void CopyFrom(const NodeInput &from) {
node_id_ = from.node_id();
output_port_ = from.output_port();
}
public:
int node_id() const {
return node_id_;
}
int output_port() const {
return output_port_;
}
 private:
  // In-class defaults so CopyFrom never reads uninitialized fields.
  int node_id_ = 0;
  int output_port_ = 0;
};
class OutputShape {
public:
void CopyFrom(const OutputShape &from) {
auto from_dims = from.dims();
dims_.resize(from_dims.size());
std::copy(from_dims.begin(), from_dims.end(), dims_.begin());
}
public:
const std::vector<int64_t> &dims() const {
return dims_;
}
private:
std::vector<int64_t> dims_;
};
class OperatorDef {
public:
void CopyFrom(const OperatorDef &from);
public:
const std::string &name() const;
void set_name(const std::string &value);
bool has_name() const;
const std::string &type() const;
void set_type(const std::string &value);
bool has_type() const;
int mem_id() const;
void set_mem_id(const int mem_id);
bool has_mem_id() const;
uint32_t node_id() const;
uint32_t op_id() const;
uint32_t padding() const;
const std::vector<NodeInput> &node_input() const;
const std::vector<int> &out_max_byte_size() const;
const std::vector<std::string> &input() const;
const std::string& input(int index) const;
std::string* add_input();
void add_input(const ::std::string& value);
void add_input(::std::string&& value);
void set_input(const std::vector<std::string> &value);
const std::vector<std::string> &output() const;
const std::string& output(int index) const;
std::string* add_output();
void add_output(const ::std::string& value);
void add_output(::std::string&& value);
void set_output(const std::vector<std::string> &value);
const std::vector<Argument> &arg() const;
Argument* add_arg();
const std::vector<OutputShape> &output_shape() const;
void set_output_shape(const std::vector<OutputShape> &value);
const std::vector<DataType> &output_type() const;
void set_output_type(const std::vector<DataType> &value);
private:
void set_has_name();
void set_has_type();
void set_has_mem_id();
 private:
  std::string name_;
  std::string type_;
  std::vector<std::string> input_;
  std::vector<std::string> output_;
  std::vector<Argument> arg_;
  std::vector<OutputShape> output_shape_;
  std::vector<DataType> output_type_;
  // OperatorDef has no constructor, so scalar fields get in-class defaults
  // to keep accessors and CopyFrom from reading uninitialized memory.
  int mem_id_ = 0;
  // nnlib
  uint32_t node_id_ = 0;
  uint32_t op_id_ = 0;
  uint32_t padding_ = 0;
  std::vector<NodeInput> node_input_;
  std::vector<int> out_max_byte_size_;
  uint32_t has_bits_ = 0;
};
class MemoryBlock {
public:
MemoryBlock(int mem_id, uint32_t x, uint32_t y);
public:
int mem_id() const;
uint32_t x() const;
uint32_t y() const;
private:
int mem_id_;
uint32_t x_;
uint32_t y_;
};
class MemoryArena {
public:
inline const std::vector<MemoryBlock> &mem_block() const {
return mem_block_;
}
inline std::vector<MemoryBlock> &mutable_mem_block() {
return mem_block_;
}
inline int mem_block_size() const {
return mem_block_.size();
}
private:
std::vector<MemoryBlock> mem_block_;
};
// for hexagon mace-nnlib
class InputInfo {
public:
const std::string &name() const {
return name_;
}
int32_t node_id() const {
return node_id_;
}
int32_t max_byte_size() const {
return max_byte_size_;
}
DataType data_type() const {
return data_type_;
}
const std::vector<int32_t> &dims() const {
return dims_;
}
private:
std::string name_;
int32_t node_id_;
int32_t max_byte_size_; // only support 32-bit len
DataType data_type_;
std::vector<int32_t> dims_;
};
class OutputInfo {
public:
const std::string &name() const {
return name_;
}
int32_t node_id() const {
return node_id_;
}
int32_t max_byte_size() const {
return max_byte_size_;
}
DataType data_type() const {
return data_type_;
}
const std::vector<int32_t> &dims() const {
return dims_;
}
private:
std::string name_;
int32_t node_id_;
int32_t max_byte_size_; // only support 32-bit len
DataType data_type_;
std::vector<int32_t> dims_;
};
class NetDef {
public:
NetDef();
int op_size() const;
const OperatorDef &op(const int idx) const;
public:
const std::string &name() const;
bool has_name() const;
void set_name(const std::string& value);
const std::string &version() const;
bool has_version() const;
void set_version(const std::string& value);
const std::vector<OperatorDef> &op() const;
OperatorDef* add_op();
std::vector<OperatorDef> &mutable_op();
const std::vector<Argument> &arg() const;
Argument *add_arg();
std::vector<Argument> &mutable_arg();
const std::vector<TensorProto> &tensors() const;
std::vector<TensorProto> &mutable_tensors();
const MemoryArena &mem_arena() const;
bool has_mem_arena() const;
MemoryArena &mutable_mem_arena();
const std::vector<InputInfo> &input_info() const;
const std::vector<OutputInfo> &output_info() const;
private:
void set_has_name();
void set_has_version();
void set_has_mem_arena();
private:
std::string name_;
std::string version_;
std::vector<OperatorDef> op_;
std::vector<Argument> arg_;
std::vector<TensorProto> tensors_;
// for mem optimization
MemoryArena mem_arena_;
// for hexagon mace-nnlib
std::vector<InputInfo> input_info_;
std::vector<OutputInfo> output_info_;
uint32_t has_bits_;
};
} // namespace mace
#endif // MACE_CORE_MACE_H_
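For orientation, a minimal sketch of building a NetDef against this hand-written API; the op, tensor, and argument names below are illustrative, not taken from the commit:

#include "mace/core/mace.h"

mace::NetDef MakeToyNet() {
  mace::NetDef net;
  net.set_name("toy_net");
  net.set_version("0.1");
  mace::OperatorDef *op = net.add_op();  // pointer into net's op vector
  op->set_name("relu1");
  op->set_type("Relu");
  op->add_input("input");
  op->add_output("output");
  mace::Argument *arg = op->add_arg();
  arg->set_name("T");
  arg->set_i(mace::DT_FLOAT);  // enum values travel through the int64 slot
  return net;
}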
......@@ -50,7 +50,7 @@ bool SimpleNet::Run(RunMetadata *run_metadata) {
}
}
if (!op->Run()) {
LOG(ERROR) << "Operator failed: " << ProtoDebugString(op->debug_def());
LOG(ERROR) << "Operator failed: " << op->debug_def().name();
return false;
}
......
......@@ -8,7 +8,7 @@
#include "mace/core/common.h"
#include "mace/core/operator.h"
#include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
#include "mace/proto/stats.pb.h"
namespace mace {
......
......@@ -10,7 +10,7 @@
#include "mace/core/registry.h"
#include "mace/core/tensor.h"
#include "mace/core/workspace.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
......
......@@ -4,163 +4,12 @@
#include "mace/core/proto_utils.h"
#include <fcntl.h>
#include <unistd.h>
#include <cerrno>
#include <fstream>
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#ifndef MACE_USE_LITE_PROTO
#include "google/protobuf/text_format.h"
#endif // !MACE_USE_LITE_PROTO
namespace mace {
bool ReadStringFromFile(const char *filename, string *str) {
std::ifstream ifs(filename, std::ios::in);
if (!ifs) {
VLOG(1) << "File cannot be opened: " << filename
<< " error: " << ifs.rdstate();
return false;
}
ifs.seekg(0, std::ios::end);
size_t n = ifs.tellg();
str->resize(n);
ifs.seekg(0);
ifs.read(&(*str)[0], n);
return true;
}
bool WriteStringToFile(const string &str, const char *filename) {
std::ofstream ofs(filename, std::ios::out | std::ios::trunc);
if (!ofs.is_open()) {
VLOG(1) << "File cannot be created: " << filename
<< " error: " << ofs.rdstate();
return false;
}
ofs << str;
return true;
}
// IO-specific proto functions: we will deal with the protocol buffer lite and
// full versions differently.
#ifdef MACE_USE_LITE_PROTO
// Lite runtime.
namespace {
class IfstreamInputStream : public ::google::protobuf::io::CopyingInputStream {
public:
explicit IfstreamInputStream(const string &filename)
: ifs_(filename.c_str(), std::ios::in | std::ios::binary) {}
~IfstreamInputStream() { ifs_.close(); }
int Read(void *buffer, int size) {
if (!ifs_) {
return -1;
}
ifs_.read(static_cast<char *>(buffer), size);
return ifs_.gcount();
}
private:
std::ifstream ifs_;
};
} // namespace
bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) {
::google::protobuf::io::CopyingInputStreamAdaptor stream(
new IfstreamInputStream(filename));
stream.SetOwnsCopyingStream(true);
// Total bytes hard limit / warning limit are set to 1GB and 512MB
// respectively.
::google::protobuf::io::CodedInputStream coded_stream(&stream);
coded_stream.SetTotalBytesLimit(1024LL << 20, 512LL << 20);
return proto->ParseFromCodedStream(&coded_stream);
}
void WriteProtoToBinaryFile(const MessageLite & /*proto*/,
const char * /*filename*/) {
LOG(FATAL) << "Not implemented yet.";
}
#else // MACE_USE_LITE_PROTO
// Full protocol buffer.
using ::google::protobuf::io::FileInputStream;
using ::google::protobuf::io::FileOutputStream;
using ::google::protobuf::io::ZeroCopyInputStream;
using ::google::protobuf::io::CodedInputStream;
using ::google::protobuf::io::ZeroCopyOutputStream;
using ::google::protobuf::io::CodedOutputStream;
bool ReadProtoFromTextFile(const char *filename, Message *proto) {
int fd = open(filename, O_RDONLY);
MACE_CHECK(fd != -1, "File not found: ", filename);
FileInputStream *input = new FileInputStream(fd);
bool success = google::protobuf::TextFormat::Parse(input, proto);
delete input;
close(fd);
return success;
}
void WriteProtoToTextFile(const Message &proto, const char *filename) {
int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
FileOutputStream *output = new FileOutputStream(fd);
MACE_CHECK(google::protobuf::TextFormat::Print(proto, output));
delete output;
close(fd);
}
bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto) {
#if defined(_MSC_VER) // for MSC compiler binary flag needs to be specified
int fd = open(filename, O_RDONLY | O_BINARY);
#else
int fd = open(filename, O_RDONLY);
#endif
MACE_CHECK(fd != -1, "File not found: ", filename);
std::unique_ptr<ZeroCopyInputStream> raw_input(new FileInputStream(fd));
std::unique_ptr<CodedInputStream> coded_input(
new CodedInputStream(raw_input.get()));
// A hack to manually allow using very large protocol buffers.
coded_input->SetTotalBytesLimit(1073741824, 536870912);
bool success = proto->ParseFromCodedStream(coded_input.get());
coded_input.reset();
raw_input.reset();
close(fd);
return success;
}
void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename) {
int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644);
MACE_CHECK(fd != -1, "File cannot be created: ", filename, " error number: ",
errno);
std::unique_ptr<ZeroCopyOutputStream> raw_output(new FileOutputStream(fd));
std::unique_ptr<CodedOutputStream> coded_output(
new CodedOutputStream(raw_output.get()));
MACE_CHECK(proto.SerializeToCodedStream(coded_output.get()));
coded_output.reset();
raw_output.reset();
close(fd);
}
#endif // MACE_USE_LITE_PROTO
ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
for (auto &arg : def.arg()) {
if (arg_map_.find(arg.name()) != arg_map_.end()) {
MACE_CHECK(
arg.SerializeAsString() == arg_map_[arg.name()].SerializeAsString(),
"Found argument of the same name '", arg.name(),
"' but with different contents: ", ProtoDebugString(def));
LOG(WARNING) << "Duplicated argument name found in operator def: "
<< ProtoDebugString(def)
<< ", arg: " << ProtoDebugString(arg);
LOG(WARNING) << "Duplicated argument name found in operator def.";
}
arg_map_[arg.name()] = arg;
......@@ -170,8 +19,7 @@ ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
ArgumentHelper::ArgumentHelper(const NetDef &netdef) {
for (auto &arg : netdef.arg()) {
MACE_CHECK(arg_map_.count(arg.name()) == 0,
"Duplicated argument name found in net def: ",
ProtoDebugString(netdef));
"Duplicated argument name found in net def.");
arg_map_[arg.name()] = arg;
}
}
......@@ -265,88 +113,4 @@ INSTANTIATE_GET_REPEATED_ARGUMENT(size_t, ints, true)
INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings, false)
#undef INSTANTIATE_GET_REPEATED_ARGUMENT
#define MACE_MAKE_SINGULAR_ARGUMENT(T, fieldname) \
template <> \
Argument MakeArgument(const string &name, const T &value) { \
Argument arg; \
arg.set_name(name); \
arg.set_##fieldname(value); \
return arg; \
}
MACE_MAKE_SINGULAR_ARGUMENT(bool, i)
MACE_MAKE_SINGULAR_ARGUMENT(float, f)
MACE_MAKE_SINGULAR_ARGUMENT(int, i)
MACE_MAKE_SINGULAR_ARGUMENT(int64_t, i)
MACE_MAKE_SINGULAR_ARGUMENT(string, s)
#undef MACE_MAKE_SINGULAR_ARGUMENT
template <>
Argument MakeArgument(const string &name, const MessageLite &value) {
Argument arg;
arg.set_name(name);
arg.set_s(value.SerializeAsString());
return arg;
}
#define MACE_MAKE_REPEATED_ARGUMENT(T, fieldname) \
template <> \
Argument MakeArgument(const string &name, const vector<T> &value) { \
Argument arg; \
arg.set_name(name); \
for (const auto &v : value) { \
arg.add_##fieldname(v); \
} \
return arg; \
}
MACE_MAKE_REPEATED_ARGUMENT(float, floats)
MACE_MAKE_REPEATED_ARGUMENT(int, ints)
MACE_MAKE_REPEATED_ARGUMENT(int64_t, ints)
MACE_MAKE_REPEATED_ARGUMENT(string, strings)
#undef MACE_MAKE_REPEATED_ARGUMENT
const Argument &GetArgument(const OperatorDef &def, const string &name) {
for (const Argument &arg : def.arg()) {
if (arg.name() == name) {
return arg;
}
}
MACE_CHECK(false, "Argument named ", name, "does not exist in operator ",
ProtoDebugString(def));
// should not reach here, just make compiler happy
return std::move(Argument());
}
bool GetFlagArgument(const OperatorDef &def,
const string &name,
bool def_value) {
for (const Argument &arg : def.arg()) {
if (arg.name() == name) {
MACE_CHECK(arg.has_i(), "Can't parse argument as bool: ",
ProtoDebugString(arg));
return arg.i();
}
}
return def_value;
}
Argument *GetMutableArgument(const string &name,
const bool create_if_missing,
OperatorDef *def) {
for (int i = 0; i < def->arg_size(); ++i) {
if (def->arg(i).name() == name) {
return def->mutable_arg(i);
}
}
// If no argument of the right name is found...
if (create_if_missing) {
Argument *arg = def->add_arg();
arg->set_name(name);
return arg;
} else {
return nullptr;
}
}
} // namespace mace
......@@ -7,137 +7,12 @@
#include <map>
#include "google/protobuf/message_lite.h"
#ifndef MACE_USE_LITE_PROTO
#include "google/protobuf/message.h"
#endif // !MACE_USE_LITE_PROTO
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
using std::string;
using ::google::protobuf::MessageLite;
// Common interfaces that reads file contents into a string.
bool ReadStringFromFile(const char *filename, string *str);
bool WriteStringToFile(const string &str, const char *filename);
// Common interfaces that are supported by both lite and full protobuf.
bool ReadProtoFromBinaryFile(const char *filename, MessageLite *proto);
inline bool ReadProtoFromBinaryFile(const string filename, MessageLite *proto) {
return ReadProtoFromBinaryFile(filename.c_str(), proto);
}
void WriteProtoToBinaryFile(const MessageLite &proto, const char *filename);
inline void WriteProtoToBinaryFile(const MessageLite &proto,
const string &filename) {
return WriteProtoToBinaryFile(proto, filename.c_str());
}
#ifdef MACE_USE_LITE_PROTO
inline string ProtoDebugString(const MessageLite &proto) {
return proto.SerializeAsString();
}
// Text format MessageLite wrappers: these functions do nothing but just
// allowing things to compile. It will produce a runtime error if you are using
// MessageLite but still want text support.
inline bool ReadProtoFromTextFile(const char * /*filename*/,
MessageLite * /*proto*/) {
LOG(FATAL) << "If you are running lite version, you should not be "
<< "calling any text-format protobuffers.";
return false; // Just to suppress compiler warning.
}
inline bool ReadProtoFromTextFile(const string filename, MessageLite *proto) {
return ReadProtoFromTextFile(filename.c_str(), proto);
}
inline void WriteProtoToTextFile(const MessageLite & /*proto*/,
const char * /*filename*/) {
LOG(FATAL) << "If you are running lite version, you should not be "
<< "calling any text-format protobuffers.";
}
inline void WriteProtoToTextFile(const MessageLite &proto,
const string &filename) {
return WriteProtoToTextFile(proto, filename.c_str());
}
inline bool ReadProtoFromFile(const char *filename, MessageLite *proto) {
return (ReadProtoFromBinaryFile(filename, proto) ||
ReadProtoFromTextFile(filename, proto));
}
inline bool ReadProtoFromFile(const string &filename, MessageLite *proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
#else // MACE_USE_LITE_PROTO
using ::google::protobuf::Message;
inline string ProtoDebugString(const Message &proto) {
return proto.ShortDebugString();
}
bool ReadProtoFromTextFile(const char *filename, Message *proto);
inline bool ReadProtoFromTextFile(const string filename, Message *proto) {
return ReadProtoFromTextFile(filename.c_str(), proto);
}
void WriteProtoToTextFile(const Message &proto, const char *filename);
inline void WriteProtoToTextFile(const Message &proto, const string &filename) {
return WriteProtoToTextFile(proto, filename.c_str());
}
// Read Proto from a file, letting the code figure out if it is text or binary.
inline bool ReadProtoFromFile(const char *filename, Message *proto) {
return (ReadProtoFromBinaryFile(filename, proto) ||
ReadProtoFromTextFile(filename, proto));
}
inline bool ReadProtoFromFile(const string &filename, Message *proto) {
return ReadProtoFromFile(filename.c_str(), proto);
}
#endif // MACE_USE_LITE_PROTO
template <class IterableInputs = std::initializer_list<string>,
class IterableOutputs = std::initializer_list<string>,
class IterableArgs = std::initializer_list<Argument>>
OperatorDef CreateOperatorDef(const string &type,
const string &name,
const IterableInputs &inputs,
const IterableOutputs &outputs,
const IterableArgs &args) {
OperatorDef def;
def.set_type(type);
def.set_name(name);
for (const string &in : inputs) {
def.add_input(in);
}
for (const string &out : outputs) {
def.add_output(out);
}
for (const Argument &arg : args) {
def.add_arg()->CopyFrom(arg);
}
return def;
}
// A simplified version compared to the full CreateOperator, if you do not need
// to specify args.
template <class IterableInputs = std::initializer_list<string>,
class IterableOutputs = std::initializer_list<string>>
inline OperatorDef CreateOperatorDef(const string &type,
const string &name,
const IterableInputs &inputs,
const IterableOutputs &outputs) {
return CreateOperatorDef(type, name, inputs, outputs,
std::vector<Argument>());
}
/**
* @brief A helper class to index into arguments.
......@@ -174,17 +49,6 @@ class ArgumentHelper {
return ArgumentHelper(def).GetRepeatedArgument<T>(name, default_value);
}
template <typename Def, typename MessageType>
static MessageType GetMessageArgument(const Def &def, const string &name) {
return ArgumentHelper(def).GetMessageArgument<MessageType>(name);
}
template <typename Def, typename MessageType>
static vector<MessageType> GetRepeatedMessageArgument(const Def &def,
const string &name) {
return ArgumentHelper(def).GetRepeatedMessageArgument<MessageType>(name);
}
explicit ArgumentHelper(const OperatorDef &def);
explicit ArgumentHelper(const NetDef &netdef);
bool HasArgument(const string &name) const;
......@@ -198,51 +62,10 @@ class ArgumentHelper {
const string &name,
const std::vector<T> &default_value = std::vector<T>()) const;
template <typename MessageType>
MessageType GetMessageArgument(const string &name) const {
MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
MessageType message;
if (arg_map_.at(name).has_s()) {
MACE_CHECK(message.ParseFromString(arg_map_.at(name).s()),
"Faild to parse content from the string");
} else {
VLOG(1) << "Return empty message for parameter " << name;
}
return message;
}
template <typename MessageType>
vector<MessageType> GetRepeatedMessageArgument(const string &name) const {
MACE_CHECK(arg_map_.count(name), "Cannot find parameter named " + name);
vector<MessageType> messages(arg_map_.at(name).strings_size());
for (int i = 0; i < messages.size(); ++i) {
MACE_CHECK(messages[i].ParseFromString(arg_map_.at(name).strings(i)),
"Faild to parse content from the string");
}
return messages;
}
private:
std::map<string, Argument> arg_map_;
};
const Argument &GetArgument(const OperatorDef &def, const string &name);
bool GetFlagArgument(const OperatorDef &def,
const string &name,
bool def_value = false);
Argument *GetMutableArgument(const string &name,
const bool create_if_missing,
OperatorDef *def);
template <typename T>
Argument MakeArgument(const string &name, const T &value);
template <typename T>
inline void AddArgument(const string &name, const T &value, OperatorDef *def) {
GetMutableArgument(name, true, def)->CopyFrom(MakeArgument(name, value));
}
} // namespace mace
#endif // MACE_CORE_PROTO_UTILS_H_
......@@ -24,46 +24,32 @@ unique_ptr<Tensor> Serializer::Deserialize(const TensorProto &proto,
switch (proto.data_type()) {
case DT_FLOAT:
tensor->Copy<float>(proto.float_data().data(), proto.float_data().size());
tensor->Copy<float>(reinterpret_cast<float*>(proto.data()), proto.data_size());
break;
case DT_DOUBLE:
tensor->Copy<double>(proto.double_data().data(),
proto.double_data().size());
tensor->Copy<double>(reinterpret_cast<double*>(proto.data()), proto.data_size());
break;
case DT_INT32:
tensor->template Copy<int32_t>(proto.int32_data().data(),
proto.int32_data().size());
tensor->Copy<int32_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
break;
case DT_INT64:
tensor->Copy<int64_t>(reinterpret_cast<int64_t*>(proto.data()), proto.data_size());
break;
case DT_UINT8:
tensor->CopyWithCast<int32_t, uint8_t>(proto.int32_data().data(),
proto.int32_data().size());
tensor->CopyWithCast<int32_t, uint8_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
break;
case DT_INT16:
tensor->CopyWithCast<int32_t, int16_t>(proto.int32_data().data(),
proto.int32_data().size());
tensor->CopyWithCast<int32_t, int16_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
break;
case DT_INT8:
tensor->CopyWithCast<int32_t, int8_t>(proto.int32_data().data(),
proto.int32_data().size());
break;
case DT_INT64:
tensor->Copy<int64_t>(proto.int64_data().data(),
proto.int64_data().size());
tensor->CopyWithCast<int32_t, int8_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
break;
case DT_UINT16:
tensor->CopyWithCast<int32_t, uint16_t>(proto.int32_data().data(),
proto.int32_data().size());
tensor->CopyWithCast<int32_t, uint16_t>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
break;
case DT_BOOL:
tensor->CopyWithCast<int32_t, bool>(proto.int32_data().data(),
proto.int32_data().size());
tensor->CopyWithCast<int32_t, bool>(reinterpret_cast<int32_t*>(proto.data()), proto.data_size());
break;
case DT_STRING: {
string *content = tensor->mutable_data<string>();
for (int i = 0; i < proto.string_data().size(); ++i) {
content[i] = proto.string_data(i);
}
} break;
default:
MACE_NOT_IMPLEMENTED;
break;
......
......@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
......
......@@ -9,7 +9,7 @@
#include "mace/core/common.h"
#include "mace/core/logging.h"
#include "mace/core/types.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
......
......@@ -6,7 +6,7 @@
#define MACE_CORE_TYPES_H_
#include "mace/core/common.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
#include "mace/core/half.h"
......
......@@ -69,7 +69,7 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) {
}
void Workspace::CreateImageOutputTensor(const NetDef &net_def) {
if (!net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) {
if (net_def.has_mem_arena() || net_def.mem_arena().mem_block_size() == 0) {
return;
}
std::map<std::string, std::shared_ptr<Tensor>> mem_tensor_map;
......
......@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
......
......@@ -20,6 +20,8 @@
using namespace std;
using namespace mace;
extern NetDef CreateNet();
void ParseShape(const string &str, vector<index_t> *shape) {
string tmp = str;
while (!tmp.empty()) {
......@@ -34,6 +36,18 @@ void ParseShape(const string &str, vector<index_t> *shape) {
}
}
DeviceType ParseDeviceType(const string &device_str) {
  if (device_str == "CPU") {
    return DeviceType::CPU;
  } else if (device_str == "NEON") {
    return DeviceType::NEON;
  } else if (device_str == "OPENCL") {
    return DeviceType::OPENCL;
  } else {
    return DeviceType::CPU;
  }
}
int main(int argc, char **argv) {
string model_file;
string input_node;
......@@ -76,13 +90,13 @@ int main(int argc, char **argv) {
ParseShape(input_shape, &shape);
// load model
ifstream file_stream(model_file, ios::in | ios::binary);
NetDef net_def;
net_def.ParseFromIstream(&file_stream);
file_stream.close();
// ifstream file_stream(model_file, ios::in | ios::binary);
// NetDef net_def;
// net_def.ParseFromIstream(&file_stream);
// file_stream.close();
NetDef net_def = CreateNet();
DeviceType device_type;
DeviceType_Parse(device, &device_type);
DeviceType device_type = ParseDeviceType(device);
VLOG(0) << device_type;
Workspace ws;
ws.LoadModelTensor(net_def, device_type);
......
......@@ -6,7 +6,7 @@
#define MACE_KERNELS_BATCH_NORM_H_
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
namespace kernels {
......
......@@ -6,7 +6,7 @@
#define MACE_KERNELS_BIAS_ADD_H_
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
namespace kernels {
......
......@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/core/types.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
#include "mace/core/tensor.h"
namespace mace {
......
......@@ -7,7 +7,7 @@
#include "mace/core/common.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
namespace kernels {
......
......@@ -6,7 +6,7 @@
#define MACE_KERNELS_CONV_2D_H_
#include "mace/core/tensor.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
namespace kernels {
......
......@@ -40,7 +40,6 @@ cc_library(
],
deps = [
"//mace/kernels",
"//mace/proto:cc_proto",
],
alwayslink = 1,
)
......
......@@ -7,7 +7,7 @@
#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
#include "mace/proto/mace.pb.h"
#include "mace/core/mace.h"
namespace mace {
template <DeviceType D, typename T>
......
......@@ -159,7 +159,6 @@ class OpsTestNet {
for (auto &op_def_ : op_defs_) {
net_def.add_op()->CopyFrom(op_def_);
}
VLOG(3) << net_def.DebugString();
net_ = CreateNet(net_def, &ws_, device);
device_ = device;
return net_->Run();
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <vector>
#include <string>
#include "mace/core/mace.h"
namespace mace {
{% for tensor in tensors %}
static unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = {
{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%}
};
{% endfor %}
static void CreateNetArg(NetDef &net_def) {
net_def.mutable_arg().reserve({{ net.arg|length }});
Argument *arg = nullptr;
{% for arg in net.arg %}
arg = net_def.add_arg();
arg->set_name({{ arg.name|tojson }});
{% if arg.has_f %}
arg->set_f({{ arg.f }});
{% endif %}
{% if arg.has_i %}
arg->set_i({{ arg.i }});
{% endif %}
{% if arg.has_s %}
arg->set_s({{ arg.s|tojson }});
{% endif %}
arg->set_floats({ {{ arg.floats|join(', ') }} });
arg->set_ints({ {{ arg.ints|join(', ') }} });
arg->set_strings({ {{ arg.strings|stringfy() }} });
{% endfor %}
}
static void UpdateOp(OperatorDef &op,
const std::string &name,
const std::string &type,
const int mem_id,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs,
const std::vector<OutputShape> &output_shapes,
const std::vector<DataType> &output_types) {
op.set_name(name);
op.set_type(type);
op.set_input(inputs);
op.set_output(outputs);
op.set_mem_id(mem_id);
op.set_output_shape(output_shapes);
op.set_output_type(output_types);
}
static void CreateOperators(std::vector<OperatorDef> &ops) {
ops.resize({{ net.op|length }});
Argument *arg = nullptr;
{% for i in range(net.op|length) %}
{% for arg in net.op[i].arg %}
arg = ops[{{i}}].add_arg();
arg->set_name({{ arg.name|tojson }});
{%- if arg.HasField('f') %}
arg->set_f({{ arg.f }});
{%- endif %}
{%- if arg.HasField('i') %}
arg->set_i({{ arg.i }});
{%- endif %}
{%- if arg.HasField('s') %}
arg->set_s({{ arg.s|tojson }});
{%- endif %}
arg->set_floats({ {{ arg.floats|join(', ') }} });
arg->set_ints({ {{ arg.ints|join(', ') }} });
arg->set_strings({ {{ arg.strings|stringfy() }} });
{% endfor %}
UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }},
{ {{ net.op[i].input|stringfy }} },
{ {{ net.op[i].output|stringfy }} },
{ {{ net.op[i].output_shape.dims|join(', ') }} },
{ {{ net.op[i].output_type|join(', ') }} });
{% endfor %}
}
static void CreateTensors(std::vector<TensorProto> &tensors) {
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
tensors.emplace_back(TensorProto(
{{ tensor.name|tojson }}, {{ "_" + tensor.name[:-2].replace("/", "_") }},
{ {{ tensor.dims|join(', ') }} }, {{ tensor.data_type }},
{{ tensor.node_id }}
));
{% endfor %}
}
static void CreateMemoryArena(MemoryArena &mem_arena) {
auto &mem_block = mem_arena.mutable_mem_block();  // reference, not a copy, so the arena is actually filled
mem_block.reserve({{ net.mem_arena.mem_block|length }});
{% for mem_blk in net.mem_arena.mem_block %}
mem_block.emplace_back(MemoryBlock({{ mem_blk.mem_id }},
{{mem_blk.x}},
{{mem_blk.y}}));
{% endfor %}
}
NetDef CreateNet() {
NetDef net_def;
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
CreateNetArg(net_def);
CreateOperators(net_def.mutable_op());
CreateTensors(net_def.mutable_tensors());
CreateMemoryArena(net_def.mutable_mem_arena());
return net_def;
}
} // namespace mace
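For illustration, rendered against a hypothetical net containing a single two-element float tensor named "w:0", the tensor-related parts of this template would expand to roughly the following (the real template emits the numeric enum value; DT_FLOAT is spelled out here for readability):

namespace mace {
static unsigned char _w[] = {
  0x00, 0x00, 0x80, 0x3F, 0x00, 0x00, 0x00, 0x40,  // 1.0f, 2.0f, little-endian
};
static void CreateTensors(std::vector<TensorProto> &tensors) {
  tensors.reserve(1);
  tensors.emplace_back(TensorProto(
      "w:0", _w, { 2 }, DT_FLOAT, 0));
}
}  // namespace mace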
......@@ -2,13 +2,45 @@ import argparse
import sys
import tensorflow as tf
from tensorflow import gfile
from mace.proto import mace_pb2
from mace.python.tools import tf_converter_lib
from mace.python.tools import tf_dsp_converter_lib
import struct
from jinja2 import Environment, FileSystemLoader
import os
# ./bazel-bin/mace/python/tools/tf_converter --input quantized_test.pb --output quantized_test_dsp.pb --runtime dsp --input_dim input_node,1,28,28,3
FLAGS = None
class TensorInfo:
  def __init__(self, t):
    self.name = t.name
    if t.data_type == mace_pb2.DT_FLOAT:
      self.data = bytearray(struct.pack('%sf' % len(t.float_data), *t.float_data))
    elif t.data_type == mace_pb2.DT_INT32:
      self.data = bytearray(struct.pack('%si' % len(t.int32_data), *t.int32_data))
    else:
      raise Exception('tensor data type not supported: ' + str(t.data_type))
def stringfy(value):
return ', '.join('"{0}"'.format(w) for w in value)
def convert_to_source(net_def):
# Capture our current directory
template_dir = os.path.dirname(FLAGS.template)
template_name = os.path.basename(FLAGS.template)
print template_dir
# Create the jinja2 environment.
# Notice the use of trim_blocks, which greatly helps control whitespace.
j2_env = Environment(loader=FileSystemLoader(template_dir),
trim_blocks=True)
j2_env.filters['stringfy'] = stringfy
tensors = [TensorInfo(t) for t in net_def.tensors]
return j2_env.get_template(template_name).render(
tensors = tensors,
net = net_def
)
def main(unused_args):
if not gfile.Exists(FLAGS.input):
print("Input graph file '" + FLAGS.input + "' does not exist!")
......@@ -19,6 +51,7 @@ def main(unused_args):
data = f.read()
input_graph_def.ParseFromString(data)
print 'done'
if FLAGS.runtime == 'dsp':
output_graph_def = tf_dsp_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize)
......@@ -26,11 +59,16 @@ def main(unused_args):
output_graph_def = tf_converter_lib.convert_to_mace_pb(
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
# output_graph_def.ClearField('tensors')
f.write(str(output_graph_def))
if FLAGS.output_type == 'source':
source = convert_to_source(output_graph_def)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(source)
else:
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
with gfile.GFile(FLAGS.output + '_txt', "wb") as f:
# output_graph_def.ClearField('tensors')
f.write(str(output_graph_def))
def parse_args():
......@@ -51,7 +89,7 @@ def parse_args():
"--runtime",
type=str,
default="cpu",
help="Runtime: cpu/gpu/dsp.")
help="Runtime: cpu/gpu/dsp")
parser.add_argument(
"--input_node",
type=str,
......@@ -72,6 +110,16 @@ def parse_args():
type=str,
default='DT_FLOAT',
help="e.g., DT_HALF/DT_FLOAT")
parser.add_argument(
"--output_type",
type=str,
default="source",
help="output type: source/pb")
parser.add_argument(
"--template",
type=str,
default="",
help="template path")
return parser.parse_known_args()
......