Commit d53aebce authored by liuqi

Finish the pb-to-source conversion.

Parent f05649a4
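
Note on the C++ changes below: the setters switch from reserve() + std::copy to resize() + std::copy, and the TensorProto constructors now compute data_size_ with std::accumulate over the dims. As a minimal, standalone sketch of both idioms (the shape and values here are hypothetical, not taken from any MACE model):

// Standalone illustration only; not part of the MACE sources.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> dims = {1, 224, 224, 3};  // hypothetical tensor shape

  // Element count as the product of the dimensions, the same pattern the
  // TensorProto constructors use after this change. Using an int64_t initial
  // value keeps the accumulator wide enough for large tensors.
  int64_t data_size = std::accumulate(dims.begin(), dims.end(),
                                      static_cast<int64_t>(1),
                                      std::multiplies<int64_t>());
  std::cout << "data_size = " << data_size << std::endl;  // 150528

  // Copying into a vector: reserve() only allocates capacity and leaves
  // size() == 0, so std::copy into begin() would write past the end
  // (undefined behaviour). resize() creates the elements first, which is
  // why the setters below switch from reserve() to resize().
  std::vector<float> src = {0.5f, 1.5f, 2.5f};
  std::vector<float> dst;
  dst.resize(src.size());
  std::copy(src.begin(), src.end(), dst.begin());
  std::cout << "copied " << dst.size() << " floats" << std::endl;
  return 0;
}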
......@@ -3,7 +3,7 @@
//
#include "mace/core/mace.h"
#include "mace/core/logging.h"
#include "mace/core/types.h"
namespace mace {
......@@ -14,9 +14,12 @@ TensorProto::TensorProto(const std::string &name,
uint32_t node_id) :
name_(name),
data_(data),
dims_(dims),
data_size_(0),
dims_(dims.begin(), dims.end()),
data_type_(data_type),
node_id_(node_id) {}
node_id_(node_id) {
data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies<int64_t>());
}
TensorProto::TensorProto(const std::string &name,
unsigned char *data,
......@@ -25,9 +28,12 @@ TensorProto::TensorProto(const std::string &name,
uint32_t node_id) :
name_(name),
data_(data),
dims_(dims),
data_size_(0),
dims_(dims.begin(), dims.end()),
data_type_(static_cast<DataType>(data_type)),
node_id_(node_id) {}
node_id_(node_id) {
data_size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies<int64_t>());
}
const std::string &TensorProto::name() const {
return name_;
......@@ -35,7 +41,7 @@ const std::string &TensorProto::name() const {
unsigned char *TensorProto::data() const {
return data_;
}
const int TensorProto::data_size() const {
const int64_t TensorProto::data_size() const {
return data_size_;
}
const std::vector<int64_t> &TensorProto::dims() const {
......@@ -119,7 +125,7 @@ void Argument::add_floats(float value) {
floats_.push_back(value);
}
void Argument::set_floats(const std::vector<float> &value) {
floats_.reserve(value.size());
floats_.resize(value.size());
std::copy(value.begin(), value.end(), floats_.begin());
}
const std::vector<int64_t> &Argument::ints() const {
......@@ -129,7 +135,7 @@ void Argument::add_ints(int64_t value) {
ints_.push_back(value);
}
void Argument::set_ints(const std::vector<int64_t> &value) {
ints_.reserve(value.size());
ints_.resize(value.size());
std::copy(value.begin(), value.end(), ints_.begin());
}
const std::vector<std::string> &Argument::strings() const {
......@@ -139,10 +145,25 @@ void Argument::add_strings(const ::std::string &value) {
strings_.push_back(value);
}
void Argument::set_strings(const std::vector<std::string> &value) {
strings_.reserve(value.size());
strings_.resize(value.size());
std::copy(value.begin(), value.end(), strings_.begin());
}
// OutputShape
OutputShape::OutputShape() {}
OutputShape::OutputShape(const std::vector<int64_t> &dims):
dims_(dims.begin(), dims.end()) {}
void OutputShape::CopyFrom(const OutputShape &from) {
auto from_dims = from.dims();
dims_.resize(from_dims.size());
std::copy(from_dims.begin(), from_dims.end(), dims_.begin());
}
const std::vector<int64_t> &OutputShape::dims() const {
return dims_;
}
// Operator Def
void OperatorDef::CopyFrom(const OperatorDef &from) {
name_ = from.name();
type_ = from.type();
......@@ -258,7 +279,7 @@ void OperatorDef::add_input(::std::string &&value) {
input_.push_back(value);
}
void OperatorDef::set_input(const std::vector<std::string> &value) {
input_.reserve(value.size());
input_.resize(value.size());
std::copy(value.begin(), value.end(), input_.begin());
}
const std::vector<std::string> &OperatorDef::output() const {
......@@ -279,7 +300,7 @@ void OperatorDef::add_output(::std::string &&value) {
output_.push_back(value);
}
void OperatorDef::set_output(const std::vector<std::string> &value) {
output_.reserve(value.size());
output_.resize(value.size());
std::copy(value.begin(), value.end(), output_.begin());
}
const std::vector<Argument> &OperatorDef::arg() const {
......@@ -292,11 +313,8 @@ Argument *OperatorDef::add_arg() {
const std::vector<OutputShape> &OperatorDef::output_shape() const {
return output_shape_;
}
void OperatorDef::set_output_shape(const std::vector<OutputShape> &value) {
output_shape_.reserve(value.size());
for (int i = 0; i < value.size(); ++i) {
output_shape_[i].CopyFrom(value[i]);
}
void OperatorDef::add_output_shape(const OutputShape &value) {
output_shape_.push_back(value);
}
const std::vector<DataType> &OperatorDef::output_type() const {
return output_type_;
......@@ -306,6 +324,7 @@ void OperatorDef::set_output_type(const std::vector<DataType> &value) {
std::copy(value.begin(), value.end(), output_type_.begin());
}
// MemoryBlock
MemoryBlock::MemoryBlock(int mem_id, uint32_t x, uint32_t y) :
mem_id_(mem_id), x_(x), y_(y) {}
......@@ -319,6 +338,7 @@ uint32_t MemoryBlock::y() const {
return y_;
}
// NetDef
NetDef::NetDef() : has_bits_(0) {}
const std::string &NetDef::name() const {
......
......@@ -53,7 +53,7 @@ class TensorProto {
const std::string &name() const;
unsigned char *data() const;
const int data_size() const;
const int64_t data_size() const;
const std::vector<int64_t> &dims() const;
DataType data_type() const;
uint32_t node_id() const;
......@@ -61,7 +61,7 @@ class TensorProto {
private:
std::string name_;
unsigned char *data_;
int data_size_;
int64_t data_size_;
std::vector<int64_t> dims_;
DataType data_type_;
uint32_t node_id_;
......@@ -129,15 +129,11 @@ class NodeInput {
class OutputShape {
public:
void CopyFrom(const OutputShape &from) {
auto from_dims = from.dims();
dims_.resize(from_dims.size());
std::copy(from_dims.begin(), from_dims.end(), dims_.begin());
}
OutputShape();
OutputShape(const std::vector<int64_t> &dims);
void CopyFrom(const OutputShape &from);
public:
const std::vector<int64_t> &dims() const {
return dims_;
}
const std::vector<int64_t> &dims() const;
private:
std::vector<int64_t> dims_;
};
......@@ -176,7 +172,7 @@ class OperatorDef {
const std::vector<Argument> &arg() const;
Argument* add_arg();
const std::vector<OutputShape> &output_shape() const;
void set_output_shape(const std::vector<OutputShape> &value);
void add_output_shape(const OutputShape &value);
const std::vector<DataType> &output_type() const;
void set_output_type(const std::vector<DataType> &value);
......
......@@ -20,8 +20,9 @@
using namespace std;
using namespace mace;
extern NetDef CreateNet() ;
namespace mace {
extern NetDef CreateNet();
}
void ParseShape(const string &str, vector<index_t> *shape) {
string tmp = str;
while (!tmp.empty()) {
......@@ -94,7 +95,7 @@ int main(int argc, char **argv) {
// NetDef net_def;
// net_def.ParseFromIstream(&file_stream);
// file_stream.close();
NetDef net_def = CreateNet();
NetDef net_def = mace::CreateNet();
DeviceType device_type = ParseDeviceType(device);
VLOG(0) << device_type;
......
......@@ -2,17 +2,26 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
{% if mode == 0 %}
namespace mace {
alignas(4) unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = {
{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%}
};
} // namespace mace
{% else %}
#include <vector>
#include <string>
#include "mace/core/mace.h"
namespace mace {
{% for tensor in tensors %}
static unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[] = {
{% for d in tensor.data %}{{"0x%02X, " % d }}{%endfor%}
};
extern unsigned char {{ "_" + tensor.name[:-2].replace("/", "_") }}[];
{% endfor %}
{% if net.arg|length != 0 %}
static void CreateNetArg(NetDef &net_def) {
net_def.mutable_arg().reserve({{ net.arg|length }});
Argument *arg = nullptr;
......@@ -21,25 +30,32 @@ static void CreateNetArg(NetDef &net_def) {
arg = net_def.add_arg();
arg->set_name({{ arg.name|tojson }});
{% if arg.has_f %}
{%- if arg.HasField('f') %}
arg->set_f({{ arg.f }});
{% endif %}
{% if arg.has_i %}
{%- if arg.HasField('i') %}
arg->set_i({{ arg.i }});
{% endif %}
{% if arg.has_s %}
{%- if arg.HasField('s') %}
arg->set_s({{ arg.s|tojson }});
{% endif %}
{% if arg.floats|length != 0 %}
arg->set_floats({ {{ arg.floats|join(', ') }} });
{% endif %}
{% if arg.ints|length != 0 %}
arg->set_ints({ {{ arg.ints|join(', ') }} });
{% endif %}
{% if arg.strings|length != 0 %}
arg->set_strings({ {{ arg.strings|stringfy() }} });
{% endif %}
{% endfor %}
}
{% endif %}
static void UpdateOp(OperatorDef &op,
const std::string &name,
......@@ -47,14 +63,12 @@ static void UpdateOp(OperatorDef &op,
const int mem_id,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs,
const std::vector<OutputShape> &output_shapes,
const std::vector<DataType> &output_types) {
op.set_name(name);
op.set_type(type);
op.set_input(inputs);
op.set_output(outputs);
op.set_mem_id(mem_id);
op.set_output_shape(output_shapes);
op.set_output_type(output_types);
}
......@@ -77,15 +91,24 @@ static void CreateOperators(std::vector<OperatorDef> &ops) {
arg->set_s({{ arg.s|tojson }});
{%- endif %}
{% if arg.floats|length != 0 %}
arg->set_floats({ {{ arg.floats|join(', ') }} });
{% endif %}
{% if arg.ints|length != 0 %}
arg->set_ints({ {{ arg.ints|join(', ') }} });
{% endif %}
{% if arg.strings|length != 0 %}
arg->set_strings({ {{ arg.strings|stringfy() }} });
{% endif %}
{% endfor %}
{% for shape in net.op[i].output_shape %}
ops[{{i}}].add_output_shape(OutputShape({ {{ shape.dims|join(', ') }} }));
{% endfor %}
UpdateOp(ops[{{i}}], {{ net.op[i].name|tojson }}, {{ net.op[i].type|tojson}}, {{ net.op[i].mem_id }},
{ {{ net.op[i].input|stringfy }} },
{ {{ net.op[i].output|stringfy }} },
{ {{ net.op[i].output_shape.dims|join(', ') }} },
{ {{ net.op[i].output_type|join(', ') }} });
{% endfor %}
......@@ -108,6 +131,7 @@ static void CreateTensors(std::vector<TensorProto> &tensors) {
}
{% if net.mem_arena.mem_block|length != 0 %}
static void CreateMemoryArena(MemoryArena &mem_arena) {
auto mem_block = mem_arena.mutable_mem_block();
mem_block.reserve({{ net.mem_arena.mem_block|length }});
......@@ -119,21 +143,27 @@ static void CreateMemoryArena(MemoryArena &mem_arena) {
{% endfor %}
}
{% endif %}
NetDef CreateNet() {
NetDef net_def;
net_def.set_name("{{ net.name}}");
net_def.set_version("{{ net.version }}");
{% if net.arg|length != 0 %}
CreateNetArg(net_def);
{% endif %}
CreateOperators(net_def.mutable_op());
CreateTensors(net_def.mutable_tensors());
{% if net.mem_arena.mem_block|length != 0 %}
CreateMemoryArena(net_def.mutable_mem_arena());
{% endif %}
return net_def;
}
} // namespace mace
{% endif %}
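
For orientation, the rendered output of this template has roughly the following shape: per-tensor byte arrays emitted into separate .cc files (mode == 0), extern declarations plus CreateTensors/CreateOperators helpers, and a single mace::CreateNet() entry point (mode == 1). The sketch below is a heavily simplified, self-contained approximation; the struct definitions, tensor name, and bytes are hypothetical stand-ins, while the real TensorProto/NetDef come from mace/core/mace.h.

// Hypothetical stand-ins only; the real generated file uses the mace API.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

namespace mace {

struct TensorProto {                    // stand-in type
  std::string name;
  const unsigned char *data;
  std::vector<int64_t> dims;
};

struct NetDef {                         // stand-in type
  std::string name;
  std::vector<TensorProto> tensors;
};

// In the real output this array is defined in its own generated 0.cc file
// and only declared extern here.
unsigned char _conv1_weights[] = {0x00, 0x3C, 0x00, 0x3C};  // made-up bytes

static void CreateTensors(std::vector<TensorProto> &tensors) {
  tensors.push_back({"conv1/weights:0", _conv1_weights, {1, 1, 1, 4}});
}

NetDef CreateNet() {                    // entry point linked into mace_run
  NetDef net_def;
  net_def.name = "example_net";
  CreateTensors(net_def.tensors);
  return net_def;
}

}  // namespace mace

int main() {
  // mace_run links against the generated file and simply calls CreateNet().
  mace::NetDef net_def = mace::CreateNet();
  std::cout << net_def.name << " has " << net_def.tensors.size()
            << " tensor(s)" << std::endl;
  return 0;
}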
......@@ -35,11 +35,26 @@ def convert_to_source(net_def):
j2_env = Environment(loader=FileSystemLoader(template_dir),
trim_blocks=True)
j2_env.filters['stringfy'] = stringfy
counter = 0
output_dir = os.path.dirname(FLAGS.output) + '/'
for t in net_def.tensors:
source = j2_env.get_template(template_name).render(
tensor = TensorInfo(t),
mode = 0,
)
with gfile.GFile(output_dir + str(counter) + '.cc', "wb") as f:
f.write(source)
counter += 1
tensors = [TensorInfo(t) for t in net_def.tensors]
return j2_env.get_template(template_name).render(
source = j2_env.get_template(template_name).render(
tensors = tensors,
net = net_def
net = net_def,
mode = 1
)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(source)
def main(unused_args):
if not gfile.Exists(FLAGS.input):
......@@ -60,9 +75,7 @@ def main(unused_args):
input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime)
if FLAGS.output_type == 'source':
source = convert_to_source(output_graph_def)
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(source)
convert_to_source(output_graph_def)
else:
with gfile.GFile(FLAGS.output, "wb") as f:
f.write(output_graph_def.SerializeToString())
......
......@@ -27,14 +27,14 @@ python tools/validate.py --generate_data true --random_seed 1 \
--input_shape="${IMAGE_SIZE},${IMAGE_SIZE},3"
# Step 2: convert tf model to mace model
echo "Step 2: convert tf model to mace model and optimize memory"
bazel build //mace/python/tools:tf_converter
bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
--output=${MODEL_DIR}/${MACE_MODEL_NAME} \
--input_node=input \
--output_node=GCN/br_result_2/fcn_br \
--data_type=DT_HALF \
--runtime=gpu
#echo "Step 2: convert tf model to mace model and optimize memory"
#bazel build //mace/python/tools:tf_converter
#bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \
# --output=${MODEL_DIR}/${MACE_MODEL_NAME} \
# --input_node=input \
# --output_node=GCN/br_result_2/fcn_br \
# --data_type=DT_HALF \
# --runtime=gpu
# Step 3: Run model on the phone
echo "Step 3: Run model on the phone"
......@@ -46,7 +46,7 @@ bazel build -c opt --strip always mace/examples:mace_run \
adb shell "mkdir -p ${PHONE_DATA_DIR}"
adb shell "mkdir -p ${KERNEL_DIR}"
adb push mace/kernels/opencl/cl/* ${KERNEL_DIR}
adb push ${MODEL_DIR}/${MACE_MODEL_NAME} ${PHONE_DATA_DIR}
#adb push ${MODEL_DIR}/${MACE_MODEL_NAME} ${PHONE_DATA_DIR}
adb push ${MODEL_DIR}/${INPUT_FILE_NAME} ${PHONE_DATA_DIR}
adb push bazel-bin/mace/examples/mace_run ${PHONE_DATA_DIR}
......