提交 5dce86c6 编写于 作者: liuqi

Fix atrous convolution transform bug.

上级 f003b9ae
...@@ -15,8 +15,8 @@ ...@@ -15,8 +15,8 @@
namespace mace { namespace mace {
namespace {{tag}} { namespace {{tag}} {
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors, void CreateTensor{{tensor_info.id}}(std::vector<mace::ConstTensor> &tensors,
const unsigned char *model_data) { const unsigned char *model_data) {
MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}"); MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}");
tensors.emplace_back(mace::ConstTensor( tensors.emplace_back(mace::ConstTensor(
{{ tensor.name|tojson }}, model_data + {{ offset }}, {{ tensor.name|tojson }}, model_data + {{ offset }},
...@@ -189,8 +189,8 @@ namespace mace { ...@@ -189,8 +189,8 @@ namespace mace {
namespace {{tag}} { namespace {{tag}} {
{% for tensor in tensors %} {% for tensor in tensors %}
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors, extern void CreateTensor{{ tensor.id }}(std::vector<mace::ConstTensor> &tensors,
const unsigned char *model_data); const unsigned char *model_data);
{% endfor %} {% endfor %}
...@@ -269,9 +269,9 @@ void CreateTensors(std::vector<mace::ConstTensor> &tensors, ...@@ -269,9 +269,9 @@ void CreateTensors(std::vector<mace::ConstTensor> &tensors,
MACE_LATENCY_LOGGER(1, "Create tensors"); MACE_LATENCY_LOGGER(1, "Create tensors");
tensors.reserve({{ net.tensors|length }}); tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %} {% for tensor in tensors %}
mace::{{tag}}::Create{{tensor.name}}(tensors, model_data); mace::{{tag}}::CreateTensor{{tensor.id}}(tensors, model_data);
{% endfor %} {% endfor %}
} }
......
...@@ -73,7 +73,8 @@ def rename_tensor(net_def): ...@@ -73,7 +73,8 @@ def rename_tensor(net_def):
op.output[i] = tensor_map[op.output[i]] op.output[i] = tensor_map[op.output[i]]
class TensorInfo: class TensorInfo:
def __init__(self, t, runtime): def __init__(self, id, t, runtime):
self.id = id
self.name = t.name self.name = t.name
self.data_type = mace_pb2.DataType.Name(t.data_type) self.data_type = mace_pb2.DataType.Name(t.data_type)
if t.data_type == mace_pb2.DT_FLOAT: if t.data_type == mace_pb2.DT_FLOAT:
...@@ -106,20 +107,20 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -106,20 +107,20 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
j2_env = Environment(loader=FileSystemLoader(template_dir), j2_env = Environment(loader=FileSystemLoader(template_dir),
trim_blocks=True) trim_blocks=True)
j2_env.filters['stringfy'] = stringfy j2_env.filters['stringfy'] = stringfy
counter = 0
output_dir = os.path.dirname(output) + '/' output_dir = os.path.dirname(output) + '/'
# generate tensor source files # generate tensor source files
model_data = [] model_data = []
offset = 0 offset = 0
counter = 0
for t in net_def.tensors: for t in net_def.tensors:
tensor_info = TensorInfo(t, runtime) tensor_info = TensorInfo(counter, t, runtime)
# align # align
if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0: if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0:
padding = 4 - offset % 4 padding = 4 - offset % 4
model_data.extend(bytearray([0] * padding)) model_data.extend(bytearray([0] * padding))
offset += padding offset += padding
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensor_info = TensorInfo(t, runtime), tensor_info = tensor_info,
tensor = t, tensor = t,
tag = model_tag, tag = model_tag,
mode = 0, mode = 0,
...@@ -164,7 +165,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag, ...@@ -164,7 +165,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
counter += 1 counter += 1
# generate model source files # generate model source files
tensors = [TensorInfo(t, runtime) for t in net_def.tensors] tensors = [TensorInfo(i, net_def.tensors[i], runtime) for i in range(len(net_def.tensors))]
source = j2_env.get_template(template_name).render( source = j2_env.get_template(template_name).render(
tensors = tensors, tensors = tensors,
net = net_def, net = net_def,
......
...@@ -641,6 +641,7 @@ class TFConverter(object): ...@@ -641,6 +641,7 @@ class TFConverter(object):
conv_op = self.tf_graph[op.name][0] conv_op = self.tf_graph[op.name][0]
op_def.name = conv_op.name op_def.name = conv_op.name
op_def.type = conv_op.type op_def.type = conv_op.type
self.transpose_filter_tensor[get_input_tensor(conv_op, 1).name] = (0, 1, 3, 2)
if self.device == 'gpu': if self.device == 'gpu':
op_def.input.extend([op.inputs[0].name]) op_def.input.extend([op.inputs[0].name])
output_name = self.add_buffer_to_image(get_input_tensor(conv_op, 1).name, "CONV2D_FILTER") output_name = self.add_buffer_to_image(get_input_tensor(conv_op, 1).name, "CONV2D_FILTER")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册