From a66e0e352c1d8d12d034d7fb0aa7812008ba1941 Mon Sep 17 00:00:00 2001 From: liuqi Date: Wed, 6 Dec 2017 14:41:54 +0800 Subject: [PATCH] Finish gcn converter and validation. --- mace/examples/mace_run.cc | 28 +- mace/ops/fused_conv_2d_test.cc | 78 +++++ mace/python/tools/tf_converter.py | 7 +- mace/python/tools/tf_converter_lib.py | 460 ++++++++++++++------------ tools/validate.py | 21 +- tools/validate_gcn.sh | 53 +-- tools/validate_icnet.py | 124 ------- 7 files changed, 389 insertions(+), 382 deletions(-) delete mode 100644 tools/validate_icnet.py diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc index 21eade9c..8ca9765b 100644 --- a/mace/examples/mace_run.cc +++ b/mace/examples/mace_run.cc @@ -129,15 +129,21 @@ int main(int argc, char **argv) { // save output const Tensor *output = ws.GetTensor(output_node + ":0"); - Tensor::MappingGuard output_guard(output); - ofstream out_file(output_file, ios::binary); - out_file.write((const char *)(output->data()), - output->size() * sizeof(float)); - out_file.flush(); - out_file.close(); - VLOG(0) << "Output shape: [" - << output->dim(0) << ", " - << output->dim(1) << ", " - << output->dim(2) << ", " - << output->dim(3) << "]"; + std::remove(output_file.c_str()); + if (output != nullptr) { + Tensor::MappingGuard output_guard(output); + ofstream out_file(output_file, ios::binary); + out_file.write((const char *)(output->data()), + output->size() * sizeof(float)); + out_file.flush(); + out_file.close(); + stringstream ss; + ss << "Output shape: ["; + for (int i = 0; i < output->dim_size(); ++i) { + ss << output->dim(i) << ", "; + + } + ss << "]"; + VLOG(0) << ss.str(); + } } \ No newline at end of file diff --git a/mace/ops/fused_conv_2d_test.cc b/mace/ops/fused_conv_2d_test.cc index 896fbbc6..7ce58e6c 100644 --- a/mace/ops/fused_conv_2d_test.cc +++ b/mace/ops/fused_conv_2d_test.cc @@ -408,3 +408,81 @@ TEST_F(FusedConv2dOpTest, OPENCLHalfAlignedConvNxNS12) { TestHalfComplexConvNxNS12({32, 32, 32, 64}); } +template +static void TestGeneralConvNxNS12(const std::vector &image_shape, + const std::vector &filter_shape) { + testing::internal::LogToStderr(); + auto func = [&](int stride_h, int stride_w, Padding type) { + srand(time(NULL)); + + // generate random input + index_t batch = 1; + index_t height = image_shape[0]; + index_t width = image_shape[1]; + index_t input_channels = filter_shape[2]; + index_t output_channels = filter_shape[3]; + index_t kernel_h = filter_shape[0]; + index_t kernel_w = filter_shape[1]; + // Construct graph + OpsTestNet net; + OpDefBuilder("FusedConv2D", "FusedConv2dTest") + .Input("Input") + .Input("Filter") + .Input("Bias") + .Output("Output") + .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntArg("padding", type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + + // Add input data + net.AddRandomInput("Input", {batch, height, width, input_channels}); + net.AddRandomInput( + "Filter", {kernel_h, kernel_w, input_channels, output_channels}); + net.AddRandomInput("Bias", {output_channels}); + + // run on cpu + net.RunOp(); + // Check + Tensor expected; + expected.Copy(*net.GetOutput("Output")); + + // run on gpu + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::FILTER); + BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); + + OpDefBuilder("FusedConv2D", "FusedConv2dTest") + .Input("InputImage") + .Input("FilterImage") + .Input("BiasImage") + .Output("OutputImage") + .AddIntsArg("strides", {stride_h, stride_w}) + .AddIntArg("padding", type) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(net.NewOperatorDef()); + // Run on device + net.RunOp(D); + + ImageToBuffer(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT); + ExpectTensorNear(expected, *net.GetOutput("OPENCLOutput"), 0.001); + }; + + for (int stride : {1, 2}) { + func(stride, stride, VALID); + func(stride, stride, SAME); + } +} + +TEST_F(FusedConv2dOpTest, OPENCL7X7ConvNxNS12) { + TestGeneralConvNxNS12({32, 32}, + {7, 7, 3, 64}); +} + +TEST_F(FusedConv2dOpTest, OPENCL15X1ConvNxNS12) { + TestGeneralConvNxNS12({40, 40}, + {15, 1, 32, 64}); +} + diff --git a/mace/python/tools/tf_converter.py b/mace/python/tools/tf_converter.py index c1792f44..886999d3 100644 --- a/mace/python/tools/tf_converter.py +++ b/mace/python/tools/tf_converter.py @@ -24,7 +24,7 @@ def main(unused_args): input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.prequantize) else: output_graph_def = tf_converter_lib.convert_to_mace_pb( - input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.runtime) + input_graph_def, FLAGS.input_node, FLAGS.output_node, FLAGS.data_type, FLAGS.runtime) with gfile.GFile(FLAGS.output, "wb") as f: f.write(output_graph_def.SerializeToString()) @@ -67,6 +67,11 @@ def parse_args(): type=bool, default=False, help="e.g., False") + parser.add_argument( + "--data_type", + type=str, + default='DT_FLOAT', + help="e.g., DT_HALF/DT_FLOAT") return parser.parse_known_args() diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 37fd539b..80b5ee42 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -19,6 +19,11 @@ buffer_type_map = { 'ARGUMENT' : 2, } +data_type_map = { + 'DT_HALF' : mace_pb2.DT_HALF, + 'DT_FLOAT': mace_pb2.DT_FLOAT +} + def convert_tensor(op, tensor): tf_tensor = op.outputs[0].eval() tensor.name = op.outputs[0].name @@ -42,7 +47,7 @@ def get_input_tensor(op, index): input_tensor = get_input_tensor(input_tensor.op, 0) return input_tensor -def add_buffer_to_image(input_name, input_type, net_def): +def add_buffer_to_image(input_name, input_type, dt, net_def): output_name = input_name[:-2] + "_b2i" + input_name[-2:] op_def = net_def.op.add() op_def.name = output_name[:-2] @@ -50,15 +55,34 @@ def add_buffer_to_image(input_name, input_type, net_def): op_def.input.extend([input_name]) op_def.output.extend([output_name]) - epsilon_arg = op_def.arg.add() - epsilon_arg.name = 'buffer_type' - epsilon_arg.i = buffer_type_map[input_type] - epsilon_arg = op_def.arg.add() - epsilon_arg.name = 'mode' - epsilon_arg.i = 0 + arg = op_def.arg.add() + arg.name = 'buffer_type' + arg.i = buffer_type_map[input_type] + arg = op_def.arg.add() + arg.name = 'mode' + arg.i = 0 + arg = op_def.arg.add() + arg.name = 'T' + arg.i = dt + return output_name + +def add_image_to_buffer(input_name, input_type, dt, net_def): + output_name = input_name[:-2] + "_i2b" + input_name[-2:] + op_def = net_def.op.add() + op_def.name = output_name[:-2] + op_def.type = 'ImageToBuffer' + op_def.input.extend([input_name]) + op_def.output.extend([output_name]) + + arg = op_def.arg.add() + arg.name = 'buffer_type' + arg.i = buffer_type_map[input_type] + arg = op_def.arg.add() + arg.name = 'T' + arg.i = dt return output_name -def add_input_transform(name, net_def): +def add_input_transform(name, dt, net_def): new_input_name = "mace_input_node:0" op_def = net_def.op.add() op_def.name = name @@ -70,6 +94,10 @@ def add_input_transform(name, net_def): epsilon_arg.name = 'buffer_type' epsilon_arg.i = buffer_type_map['IN_OUT'] + arg = op_def.arg.add() + arg.name = 'T' + arg.i = dt + def add_output_transform(name, net_def): output_name = "mace_output_node:0" op_def = net_def.op.add() @@ -82,7 +110,7 @@ def add_output_transform(name, net_def): epsilon_arg.name = 'buffer_type' epsilon_arg.i = buffer_type_map['IN_OUT'] -def convert_ops(unresolved_ops, net_def, device): +def convert_ops(unresolved_ops, dt, net_def, device): ops_count = len(unresolved_ops) resolved_count = 1 @@ -93,225 +121,223 @@ def convert_ops(unresolved_ops, net_def, device): elif first_op.type == 'Const': tensor = net_def.tensors.add() convert_tensor(first_op, tensor) - elif first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative': + else: op_def = net_def.op.add() - op_def.name = first_op.name - if first_op.type == 'DepthwiseConv2dNative': - op_def.type = 'DepthwiseConv2d' - else: - op_def.type = first_op.type - if device == 'gpu': - op_def.input.extend([first_op.inputs[0].name]) - output_name = add_buffer_to_image(first_op.inputs[1].name, "FILTER", net_def) - op_def.input.extend([output_name]) - else: - op_def.input.extend([input.name for input in first_op.inputs]) - - padding_arg = op_def.arg.add() - padding_arg.name = 'padding' - padding_arg.i = padding_mode[first_op.get_attr('padding')] - strides_arg = op_def.arg.add() - strides_arg.name = 'strides' - strides_arg.ints.extend(first_op.get_attr('strides')[1:3]) - data_format_arg = op_def.arg.add() - data_format_arg.name = 'data_format' - data_format_arg.s = 'NHWC' - final_op = first_op - - if ops_count >= 3 and unresolved_ops[1].type == 'Const' and unresolved_ops[2].type == 'BiasAdd' : - bias_tensor = unresolved_ops[1] - tensor = net_def.tensors.add() - convert_tensor(bias_tensor, tensor) - - bias_add_op = unresolved_ops[2] + arg = op_def.arg.add() + arg.name = 'T' + arg.i = dt + + if first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative': + op_def.name = first_op.name + if first_op.type == 'DepthwiseConv2dNative': + op_def.type = 'DepthwiseConv2d' + else: + op_def.type = first_op.type if device == 'gpu': - output_name = add_buffer_to_image(bias_add_op.inputs[1].name, "ARGUMENT", net_def) + op_def.input.extend([first_op.inputs[0].name]) + output_name = add_buffer_to_image(first_op.inputs[1].name, "FILTER", dt, net_def) op_def.input.extend([output_name]) else: - op_def.input.extend([bias_add_op.inputs[1].name]) - final_op = bias_add_op - resolved_count = 3 - - if ops_count >= 4 and unresolved_ops[3].type == 'Relu': - relu_op = unresolved_ops[3]; - op_def.type = "FusedConv2D" - final_op = relu_op - resolved_count = 4 - - op_def.output.extend([output.name for output in final_op.outputs]) - output_shapes = [] - for output in final_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) + op_def.input.extend([input.name for input in first_op.inputs]) + + padding_arg = op_def.arg.add() + padding_arg.name = 'padding' + padding_arg.i = padding_mode[first_op.get_attr('padding')] + strides_arg = op_def.arg.add() + strides_arg.name = 'strides' + strides_arg.ints.extend(first_op.get_attr('strides')[1:3]) + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = 'NHWC' + final_op = first_op + + if ops_count >= 3 and unresolved_ops[1].type == 'Const' and unresolved_ops[2].type == 'BiasAdd' : + bias_tensor = unresolved_ops[1] + tensor = net_def.tensors.add() + convert_tensor(bias_tensor, tensor) + + bias_add_op = unresolved_ops[2] + if device == 'gpu': + output_name = add_buffer_to_image(bias_add_op.inputs[1].name, "ARGUMENT", dt, net_def) + op_def.input.extend([output_name]) + else: + op_def.input.extend([bias_add_op.inputs[1].name]) + final_op = bias_add_op + resolved_count = 3 + + if ops_count >= 4 and unresolved_ops[3].type == 'Relu': + relu_op = unresolved_ops[3]; + op_def.type = "FusedConv2D" + final_op = relu_op + resolved_count = 4 + + op_def.output.extend([output.name for output in final_op.outputs]) + output_shapes = [] + for output in final_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + + elif first_op.type == 'FusedBatchNorm': + op_def.name = first_op.name + op_def.type = 'BatchNorm' + if device == 'gpu': + op_def.input.extend([first_op.inputs[0].name]) + for i in range(1, len(first_op.inputs)): + output_name = add_buffer_to_image(first_op.inputs[i].name, "ARGUMENT", dt, net_def) + op_def.input.extend([output_name]) + else: + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([first_op.outputs[0].name]) - elif first_op.type == 'FusedBatchNorm': - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = 'BatchNorm' - if device == 'gpu': + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(first_op.outputs[0].shape.as_list()) + op_def.output_shape.extend([output_shape]) + + epsilon_arg = op_def.arg.add() + epsilon_arg.name = 'epsilon' + epsilon_arg.f = first_op.get_attr('epsilon') + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = 'NHWC' + elif first_op.type == 'Add' and first_op.name.endswith( + 'batchnorm/add') and ops_count > 7: + add_op = first_op + mul_op = unresolved_ops[2] + mul_1_op = unresolved_ops[3] + mul_2_op = unresolved_ops[4] + sub_op = unresolved_ops[5] + add_1_op = unresolved_ops[6] + # print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type) + if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or \ + mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add': + raise Exception('Invalid BatchNorm Op') + + get_input_tensor(mul_1_op, 0) + input_name = get_input_tensor(mul_1_op, 0).name + gamma = get_input_tensor(mul_op, 1).name + beta = get_input_tensor(sub_op, 0).name + mean = get_input_tensor(mul_2_op, 0).name + variance = get_input_tensor(add_op, 0).name + epsilon = get_input_tensor(add_op, 1).name + + op_def.name = first_op.name[:-4] # remove /add + op_def.type = 'BatchNorm' + op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon]) + op_def.output.extend([output.name for output in add_1_op.outputs]) + output_shapes = [] + for output in add_1_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + + resolved_count = 7 + elif first_op.type == 'Relu6': + op_def.name = first_op.name + op_def.type = 'Relu' + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + max_limit_arg = op_def.arg.add() + max_limit_arg.name = 'max_limit' + max_limit_arg.f = 6 + elif first_op.type == 'AvgPool' or first_op.type == 'MaxPool': + op_def.name = first_op.name + op_def.type = 'Pooling' + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + pooling_type_arg = op_def.arg.add() + pooling_type_arg.name = 'pooling_type' + pooling_type_arg.i = pooling_type_mode[first_op.type] + padding_arg = op_def.arg.add() + padding_arg.name = 'padding' + padding_arg.i = padding_mode[first_op.get_attr('padding')] + strides_arg = op_def.arg.add() + strides_arg.name = 'strides' + strides_arg.ints.extend(first_op.get_attr('strides')[1:3]) + kernels_arg = op_def.arg.add() + kernels_arg.name = 'kernels' + kernels_arg.ints.extend(first_op.get_attr('ksize')[1:3]) + data_format_arg = op_def.arg.add() + data_format_arg.name = 'data_format' + data_format_arg.s = 'NHWC' + elif first_op.type == 'Add': + op_def.name = first_op.name + op_def.type = "AddN" + op_def.input.extend([input.name for input in first_op.inputs]) + op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + elif first_op.type == 'ConcatV2': + op_def.name = first_op.name + op_def.type = "Concat" + op_def.input.extend([first_op.inputs[i].name for i in xrange(2)]) + op_def.output.extend([output.name for output in first_op.outputs]) + axis_arg = op_def.arg.add() + axis_arg.name = 'axis' + axis_arg.i = get_input_tensor(first_op, 2).eval().astype(np.int32) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + elif first_op.type == 'ResizeBilinear': + op_def.name = first_op.name + op_def.type = "ResizeBilinear" op_def.input.extend([first_op.inputs[0].name]) - for i in range(1, len(first_op.inputs)): - output_name = add_buffer_to_image(first_op.inputs[i].name, "ARGUMENT", net_def) - op_def.input.extend([output_name]) - else: + op_def.output.extend([output.name for output in first_op.outputs]) + size_arg = op_def.arg.add() + size_arg.name = 'size' + size_arg.ints.extend(get_input_tensor(first_op, 1).eval().astype(np.int32).flat) + size_arg = op_def.arg.add() + size_arg.name = 'align_corners' + size_arg.i = first_op.get_attr('align_corners') + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + elif first_op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND', 'BiasAdd']: + op_def.name = first_op.name + op_def.type = first_op.type op_def.input.extend([input.name for input in first_op.inputs]) - op_def.output.extend([first_op.outputs[0].name]) - - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(first_op.outputs[0].shape.as_list()) - op_def.output_shape.extend([output_shape]) - - epsilon_arg = op_def.arg.add() - epsilon_arg.name = 'epsilon' - epsilon_arg.f = first_op.get_attr('epsilon') - data_format_arg = op_def.arg.add() - data_format_arg.name = 'data_format' - data_format_arg.s = 'NHWC' - elif first_op.type == 'Add' and first_op.name.endswith( - 'batchnorm/add') and ops_count > 7: - add_op = first_op - mul_op = unresolved_ops[2] - mul_1_op = unresolved_ops[3] - mul_2_op = unresolved_ops[4] - sub_op = unresolved_ops[5] - add_1_op = unresolved_ops[6] - # print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type) - if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or \ - mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add': - raise Exception('Invalid BatchNorm Op') - - get_input_tensor(mul_1_op, 0) - input_name = get_input_tensor(mul_1_op, 0).name - gamma = get_input_tensor(mul_op, 1).name - beta = get_input_tensor(sub_op, 0).name - mean = get_input_tensor(mul_2_op, 0).name - variance = get_input_tensor(add_op, 0).name - epsilon = get_input_tensor(add_op, 1).name - - op_def = net_def.op.add() - op_def.name = first_op.name[:-4] # remove /add - op_def.type = 'BatchNorm' - op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon]) - op_def.output.extend([output.name for output in add_1_op.outputs]) - output_shapes = [] - for output in add_1_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - - resolved_count = 7 - elif first_op.type == 'Relu6': - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = 'Relu' - op_def.input.extend([input.name for input in first_op.inputs]) - op_def.output.extend([output.name for output in first_op.outputs]) - output_shapes = [] - for output in first_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - max_limit_arg = op_def.arg.add() - max_limit_arg.name = 'max_limit' - max_limit_arg.f = 6 - elif first_op.type == 'AvgPool' or first_op.type == 'MaxPool': - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = 'Pooling' - op_def.input.extend([input.name for input in first_op.inputs]) - op_def.output.extend([output.name for output in first_op.outputs]) - output_shapes = [] - for output in first_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - pooling_type_arg = op_def.arg.add() - pooling_type_arg.name = 'pooling_type' - pooling_type_arg.i = pooling_type_mode[first_op.type] - padding_arg = op_def.arg.add() - padding_arg.name = 'padding' - padding_arg.i = padding_mode[first_op.get_attr('padding')] - strides_arg = op_def.arg.add() - strides_arg.name = 'strides' - strides_arg.ints.extend(first_op.get_attr('strides')[1:3]) - kernels_arg = op_def.arg.add() - kernels_arg.name = 'kernels' - kernels_arg.ints.extend(first_op.get_attr('ksize')[1:3]) - data_format_arg = op_def.arg.add() - data_format_arg.name = 'data_format' - data_format_arg.s = 'NHWC' - elif first_op.type == 'Add': - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = "AddN" - op_def.input.extend([input.name for input in first_op.inputs]) - op_def.output.extend([output.name for output in first_op.outputs]) - output_shapes = [] - for output in first_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - elif first_op.type == 'ConcatV2': - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = "Concat" - op_def.input.extend([first_op.inputs[i].name for i in xrange(2)]) - op_def.output.extend([output.name for output in first_op.outputs]) - axis_arg = op_def.arg.add() - axis_arg.name = 'axis' - axis_arg.i = get_input_tensor(first_op, 2).eval().astype(np.int32) - output_shapes = [] - for output in first_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - elif first_op.type == 'ResizeBilinear': - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = "ResizeBilinear" - op_def.input.extend([first_op.inputs[0].name]) - op_def.output.extend([output.name for output in first_op.outputs]) - size_arg = op_def.arg.add() - size_arg.name = 'size' - size_arg.ints.extend(get_input_tensor(first_op, 1).eval().astype(np.int32).flat) - size_arg = op_def.arg.add() - size_arg.name = 'align_corners' - size_arg.i = first_op.get_attr('align_corners') - output_shapes = [] - for output in first_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - elif first_op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND', 'BiasAdd']: - op_def = net_def.op.add() - op_def.name = first_op.name - op_def.type = first_op.type - op_def.input.extend([input.name for input in first_op.inputs]) - op_def.output.extend([output.name for output in first_op.outputs]) - output_shapes = [] - for output in first_op.outputs: - output_shape = mace_pb2.OutputShape() - output_shape.dims.extend(output.shape.as_list()) - output_shapes.append(output_shape) - op_def.output_shape.extend(output_shapes) - else: - raise Exception('Unknown Op: %s, type: %s' % (first_op.name, first_op.type)) - pass + op_def.output.extend([output.name for output in first_op.outputs]) + output_shapes = [] + for output in first_op.outputs: + output_shape = mace_pb2.OutputShape() + output_shape.dims.extend(output.shape.as_list()) + output_shapes.append(output_shape) + op_def.output_shape.extend(output_shapes) + else: + raise Exception('Unknown Op: %s, type: %s' % (first_op.name, first_op.type)) + pass for i in range(resolved_count): del unresolved_ops[0] -def convert_to_mace_pb(input_graph_def, input_node, output_node, device): +def convert_to_mace_pb(input_graph_def, input_node, output_node, data_type, device): net_def = mace_pb2.NetDef() + dt = data_type_map[data_type] with tf.Session() as session: with session.graph.as_default() as graph: @@ -319,9 +345,9 @@ def convert_to_mace_pb(input_graph_def, input_node, output_node, device): ops = graph.get_operations() unresolved_ops = ops if device == 'gpu': - add_input_transform(input_node, net_def) + add_input_transform(input_node, dt, net_def) while len(unresolved_ops) > 0: - convert_ops(unresolved_ops, net_def, device) + convert_ops(unresolved_ops, dt, net_def, device) if device == 'gpu': add_output_transform(output_node, net_def) diff --git a/tools/validate.py b/tools/validate.py index f70c59f3..42a856b0 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -1,5 +1,7 @@ import argparse import sys +import os +import os.path import tensorflow as tf import numpy as np @@ -25,13 +27,23 @@ def generate_data(shape): print "Generate input file done." def load_data(file): - return np.fromfile(file=file, dtype=np.float32) + if os.path.isfile(file): + return np.fromfile(file=file, dtype=np.float32) + else: + return np.empty([0]) def valid_output(out_shape, mace_out_file, tf_out_value): mace_out_value = load_data(mace_out_file) - mace_out_value = mace_out_value.reshape(out_shape) - res = np.allclose(tf_out_value, mace_out_value, rtol=0, atol=1e-5) - print 'Passed! Haha' if res else 'Failed! Oops' + if mace_out_value.size != 0: + mace_out_value = mace_out_value.reshape(out_shape) + np.testing.assert_allclose(tf_out_value, mace_out_value, rtol=0, atol=1e-3) + res = np.allclose(tf_out_value, mace_out_value, rtol=0, atol=1e-3) + if res: + print '=======================Passed! Haha======================' + else: + print '=======================Failed! Oops======================' + else: + print '=======================Skip empty node===================' def run_model(input_shape): @@ -55,6 +67,7 @@ def run_model(input_shape): input_value = input_value.reshape(input_shape) output_value = session.run(output_node, feed_dict={input_node: [input_value]}) + # output_value.astype(np.float32).tofile( os.path.dirname(FLAGS.input_file) + '/tf_weight') return output_value def main(unused_args): diff --git a/tools/validate_gcn.sh b/tools/validate_gcn.sh index a1174466..f4dfc6eb 100644 --- a/tools/validate_gcn.sh +++ b/tools/validate_gcn.sh @@ -1,6 +1,5 @@ #!/bin/bash # Must run at root dir of mace project. -set -e Usage() { echo 'Usage: bash tools/validate_gcn.sh tf_model_file' @@ -16,23 +15,26 @@ MODEL_DIR=$(dirname ${TF_MODEL_FILE_PATH}) MACE_MODEL_NAME='mace_model.pb' INPUT_FILE_NAME='model_input' OUTPUT_FILE_NAME='gcn.out' +OUTPUT_LIST_FILE='gcn.list' PHONE_DATA_DIR="/data/local/tmp/${MACE_MODEL_NAME}" KERNEL_DIR="${PHONE_DATA_DIR}/cl/" -# Step 1: convert tf model to mace model -echo "Step 1: convert tf model to mace model" +# Step 1: Generate input data +echo "Step 1: Generate input data" +python tools/validate.py --generate_data true --random_seed 1 \ + --input_file=${MODEL_DIR}/${INPUT_FILE_NAME} \ + --input_shape=512,512,3 + +# Step 2: convert tf model to mace model +echo "Step 2: convert tf model to mace model" bazel build //mace/python/tools:tf_converter bazel-bin/mace/python/tools/tf_converter --input=${TF_MODEL_FILE_PATH} \ --output=${MODEL_DIR}/${MACE_MODEL_NAME} \ --input_node=input \ --output_node=GCN/br_result_2/fcn_br \ + --data_type=DT_FLOAT \ --runtime=gpu -# Step 2: Generate input data -echo "Step 2: Generate input data" -python tools/validate.py --generate_data true --random_seed 1 \ - --input_file=${MODEL_DIR}/${INPUT_FILE_NAME} \ - --input_shape=512,512,3 # Step 3: Run model on the phone echo "Step 3: Run model on the phone" @@ -50,28 +52,29 @@ adb push bazel-bin/mace/examples/mace_run ${PHONE_DATA_DIR} num_threads=${1:-1} -adb shell MACE_RUN_PARAMETER_PATH=${PHONE_DATA_DIR}/mace_run.config \ - MACE_KERNEL_PATH=$KERNEL_DIR \ - OMP_NUM_THREADS=$num_threads \ - ${PHONE_DATA_DIR}/mace_run \ - --model=${PHONE_DATA_DIR}/${MACE_MODEL_NAME} \ - --input=mace_input_node \ - --output=mace_output_node \ - --input_shape=1,512,512,3\ - --input_file=${PHONE_DATA_DIR}/${MACE_INPUT_FILE_NAME} \ - --output_file=${PHONE_DATA_DIR}/${OUTPUT_FILE_NAME} \ - --device=OPENCL +adb