提交 5dce86c6 编写于 作者: L liuqi

Fix atrous convolution transform bug.

上级 f003b9ae
......@@ -15,7 +15,7 @@
namespace mace {
namespace {{tag}} {
void Create{{tensor.name}}(std::vector<mace::ConstTensor> &tensors,
void CreateTensor{{tensor_info.id}}(std::vector<mace::ConstTensor> &tensors,
const unsigned char *model_data) {
MACE_LATENCY_LOGGER(2, "Create tensor {{ tensor.name }}");
tensors.emplace_back(mace::ConstTensor(
......@@ -189,7 +189,7 @@ namespace mace {
namespace {{tag}} {
{% for tensor in tensors %}
extern void Create{{ tensor.name }}(std::vector<mace::ConstTensor> &tensors,
extern void CreateTensor{{ tensor.id }}(std::vector<mace::ConstTensor> &tensors,
const unsigned char *model_data);
{% endfor %}
......@@ -269,9 +269,9 @@ void CreateTensors(std::vector<mace::ConstTensor> &tensors,
MACE_LATENCY_LOGGER(1, "Create tensors");
tensors.reserve({{ net.tensors|length }});
{% for tensor in net.tensors %}
{% for tensor in tensors %}
mace::{{tag}}::Create{{tensor.name}}(tensors, model_data);
mace::{{tag}}::CreateTensor{{tensor.id}}(tensors, model_data);
{% endfor %}
}
......
......@@ -73,7 +73,8 @@ def rename_tensor(net_def):
op.output[i] = tensor_map[op.output[i]]
class TensorInfo:
def __init__(self, t, runtime):
def __init__(self, id, t, runtime):
self.id = id
self.name = t.name
self.data_type = mace_pb2.DataType.Name(t.data_type)
if t.data_type == mace_pb2.DT_FLOAT:
......@@ -106,20 +107,20 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
j2_env = Environment(loader=FileSystemLoader(template_dir),
trim_blocks=True)
j2_env.filters['stringfy'] = stringfy
counter = 0
output_dir = os.path.dirname(output) + '/'
# generate tensor source files
model_data = []
offset = 0
counter = 0
for t in net_def.tensors:
tensor_info = TensorInfo(t, runtime)
tensor_info = TensorInfo(counter, t, runtime)
# align
if tensor_info.data_type != 'DT_UINT8' and offset % 4 != 0:
padding = 4 - offset % 4
model_data.extend(bytearray([0] * padding))
offset += padding
source = j2_env.get_template(template_name).render(
tensor_info = TensorInfo(t, runtime),
tensor_info = tensor_info,
tensor = t,
tag = model_tag,
mode = 0,
......@@ -164,7 +165,7 @@ def convert_to_source(net_def, mode_pb_checksum, template, obfuscate, model_tag,
counter += 1
# generate model source files
tensors = [TensorInfo(t, runtime) for t in net_def.tensors]
tensors = [TensorInfo(i, net_def.tensors[i], runtime) for i in range(len(net_def.tensors))]
source = j2_env.get_template(template_name).render(
tensors = tensors,
net = net_def,
......
......@@ -641,6 +641,7 @@ class TFConverter(object):
conv_op = self.tf_graph[op.name][0]
op_def.name = conv_op.name
op_def.type = conv_op.type
self.transpose_filter_tensor[get_input_tensor(conv_op, 1).name] = (0, 1, 3, 2)
if self.device == 'gpu':
op_def.input.extend([op.inputs[0].name])
output_name = self.add_buffer_to_image(get_input_tensor(conv_op, 1).name, "CONV2D_FILTER")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册