From 11a838dc8e05971b12baf5225a5b06023d617979 Mon Sep 17 00:00:00 2001
From: liuqi
Date: Thu, 7 Dec 2017 18:52:54 +0800
Subject: [PATCH] Add tuning code for opencl kernel.

---
 mace/kernels/opencl/addn.cc                   | 50 ++++++++++--
 mace/kernels/opencl/batch_norm_opencl.cc      |  7 +-
 mace/kernels/opencl/concat.cc                 | 63 ++++++++++++---
 mace/kernels/opencl/conv_2d_opencl_1x1.cc     |  7 +-
 mace/kernels/opencl/conv_2d_opencl_3x3.cc     |  7 +-
 mace/kernels/opencl/conv_2d_opencl_general.cc |  7 +-
 mace/kernels/opencl/pooling_opencl.cc         | 73 ++++++++++++-----
 mace/kernels/opencl/relu_opencl.cc            |  7 +-
 mace/kernels/opencl/resize_bilinear_opencl.cc | 58 +++++++++++---
 mace/python/tools/tf_converter_lib.py         | 80 ++++++++++++-------
 tools/validate.py                             |  8 +-
 tools/validate_gcn.sh                         | 15 ++--
 12 files changed, 288 insertions(+), 94 deletions(-)

diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc
index 31cd1910..83e6b65b 100644
--- a/mace/kernels/opencl/addn.cc
+++ b/mace/kernels/opencl/addn.cc
@@ -6,6 +6,7 @@
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/kernels/opencl/helper.h"
 #include "mace/utils/utils.h"
+#include "mace/utils/tuner.h"
 
 namespace mace {
 namespace kernels {
@@ -33,8 +34,6 @@ static void AddN(const std::vector<const Tensor *> &input_tensors,
   built_options.emplace("-DINPUT_NUM=" + ToString(input_tensors.size()));
   auto addn_kernel = runtime->BuildKernel("addn", "addn", built_options);
 
-  const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(addn_kernel);
-
   uint32_t idx = 0;
   for (auto input : input_tensors) {
     addn_kernel.setArg(idx++,
@@ -42,12 +41,47 @@ static void AddN(const std::vector<const Tensor *> &input_tensors,
   }
   addn_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
 
-  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
-      addn_kernel, cl::NullRange,
-      cl::NDRange(width_pixels, batch_height_pixels),
-      cl::NDRange(64, 16),  // TODO fix this
-      nullptr, OpenCLRuntime::Get()->GetDefaultEvent());
-  MACE_CHECK(error == CL_SUCCESS) << "error code: " << error;
+  const uint32_t gws[2] = {
+      static_cast<uint32_t>(width_pixels),
+      static_cast<uint32_t>(batch_height_pixels)
+  };
+  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(addn_kernel);
+  std::vector<uint32_t> lws = {64, 16};
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    uint32_t local_ws[2];
+    local_ws[0] = std::min<uint32_t>(width_pixels, kwg_size);
+    local_ws[1] = std::min<uint32_t>(batch_height_pixels, kwg_size / local_ws[0]);
+    return {{local_ws[0], local_ws[1]},
+            {kwg_size / 16, 16},
+            {kwg_size / 32, 32},
+            {kwg_size / 64, 64},
+            {kwg_size / 128, 128},
+            {kwg_size / 256, 256},
+            {kwg_size, 1},
+            {1, kwg_size}
+    };
+  };
+  auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
+    cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+        addn_kernel, cl::NullRange,
+        cl::NDRange(gws[0], gws[1]),
+        cl::NDRange(params[0], params[1]),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
+
+    MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+    return error;
+  };
+  std::stringstream ss;
+  ss << "addn_opencl_kernel_"
+     << output->dim(0) << "_"
+     << output->dim(1) << "_"
+     << output->dim(2) << "_"
+     << output->dim(3);
+  Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
+                                                     lws,
+                                                     params_generator,
+                                                     func);
+
 }
 
 template <typename T>
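Note: the addn change above establishes the pattern every kernel below repeats: build the global work size (gws) from the output shape, keep the previously hard-coded local size as the default lws, enumerate candidate local sizes in params_generator, wrap the launch in a functor that returns the CL error code, and hand everything to the tuner under a key that embeds the four output dimensions, so each distinct output shape is tuned separately. A minimal sketch of the contract TuneOrRun is expected to fulfill, assuming a simplified in-memory cache (this is NOT the real mace/utils/tuner.h, which presumably also persists tuned parameters and may time via CL events):

    #include <chrono>
    #include <cstdint>
    #include <functional>
    #include <limits>
    #include <map>
    #include <string>
    #include <vector>

    using Params = std::vector<uint32_t>;

    // Sketch: time each candidate once, cache the fastest under the
    // shape-specific key, and reuse it on later calls. A non-zero return
    // from func marks a failed configuration (CL_SUCCESS == 0).
    int TuneOrRunSketch(const std::string &key,
                        const Params &default_params,
                        const std::function<std::vector<Params>()> &generator,
                        const std::function<int(const Params &)> &func) {
      static std::map<std::string, Params> cache;  // in-memory only here
      auto it = cache.find(key);
      if (it != cache.end()) return func(it->second);  // already tuned: run
      Params best = default_params;
      double best_us = std::numeric_limits<double>::max();
      for (const Params &candidate : generator()) {
        auto t0 = std::chrono::steady_clock::now();
        if (func(candidate) != 0) continue;  // skip failing configurations
        std::chrono::duration<double, std::micro> cost =
            std::chrono::steady_clock::now() - t0;
        if (cost.count() < best_us) {
          best_us = cost.count();
          best = candidate;
        }
      }
      cache[key] = best;  // remember the winner for this output shape
      return func(best);
    }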
diff --git a/mace/kernels/opencl/batch_norm_opencl.cc b/mace/kernels/opencl/batch_norm_opencl.cc
index a5362262..de6571ea 100644
--- a/mace/kernels/opencl/batch_norm_opencl.cc
+++ b/mace/kernels/opencl/batch_norm_opencl.cc
@@ -48,8 +48,13 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(
                            static_cast<uint32_t>(height * batch)};
   const std::vector<uint32_t> lws = {8, 16, 8};
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(bm_kernel);
-  auto params_generator = [&kwg_size]() -> std::vector<std::vector<uint32_t>> {
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(width, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
     return {{8, 128, 1}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
             {kwg_size / 16, 4, 4},
             {kwg_size / 32, 4, 8},
             {kwg_size / 32, 8, 4},

diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc
index f80f370d..706ed8f1 100644
--- a/mace/kernels/opencl/concat.cc
+++ b/mace/kernels/opencl/concat.cc
@@ -6,6 +6,7 @@
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/kernels/opencl/helper.h"
 #include "mace/utils/utils.h"
+#include "mace/utils/tuner.h"
 
 namespace mace {
 namespace kernels {
@@ -41,21 +42,57 @@ static void Concat2(const Tensor *input0,
   concat_kernel.setArg(idx++, static_cast<int32_t>(input0->dim(3)));
   concat_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
 
+  const uint32_t gws[3] = {
+      static_cast<uint32_t>(channel_blk),
+      static_cast<uint32_t>(width),
+      static_cast<uint32_t>(batch * height),
+  };
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(concat_kernel);
+  std::vector<uint32_t> lws = {8, 16, 8};
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blk, kwg_size);
+    local_ws[1] = std::min<uint32_t>(width, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
+    return {{4, 15, 8}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
+            {kwg_size / 16, 4, 4},
+            {kwg_size / 32, 4, 8},
+            {kwg_size / 32, 8, 4},
+            {kwg_size / 64, 8, 8},
+            {kwg_size / 64, 16, 4},
+            {kwg_size / 128, 8, 16},
+            {kwg_size / 128, 16, 8},
+            {kwg_size / 128, 32, 4},
+            {1, kwg_size / 32, 32},
+            {1, kwg_size / 64, 64},
+            {1, kwg_size / 128, 128},
+            {3, 15, 9},
+            {7, 15, 9},
+            {9, 7, 15},
+            {15, 7, 9},
+            {1, kwg_size, 1}};
+  };
+  auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
+    cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+        concat_kernel, cl::NullRange,
+        cl::NDRange(gws[0], gws[1], gws[2]),
+        cl::NDRange(params[0], params[1], params[2]),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
 
-  uint32_t lws[3] = {8, 16, 8};
-//  lws[0] = std::min(channel_blk, kwg_size);
-//  lws[1] = std::min(width, kwg_size / lws[0]);
-//  lws[2] = std::min(height * batch, kwg_size / (lws[0] * lws[1]));
-
-  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
-      concat_kernel, cl::NullRange,
-      cl::NDRange(static_cast<uint32_t>(channel_blk),
-                  static_cast<uint32_t>(width),
-                  static_cast<uint32_t>(height * batch)),
-      cl::NDRange(lws[0], lws[1], lws[2]),
-      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
-  MACE_CHECK(error == CL_SUCCESS);
+    MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+    return error;
+  };
+  std::stringstream ss;
+  ss << "concat_opencl_kernel_"
+     << output->dim(0) << "_"
+     << output->dim(1) << "_"
+     << output->dim(2) << "_"
+     << output->dim(3);
+  Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
+                                                     lws,
+                                                     params_generator,
+                                                     func);
 }
 
 template <typename T>

diff --git a/mace/kernels/opencl/conv_2d_opencl_1x1.cc b/mace/kernels/opencl/conv_2d_opencl_1x1.cc
index 1fe00494..9eaaa3b1 100644
--- a/mace/kernels/opencl/conv_2d_opencl_1x1.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_1x1.cc
@@ -68,8 +68,13 @@ void Conv1x1(const Tensor *input,
                            static_cast<uint32_t>(height * batch)};
   const std::vector<uint32_t> lws = {8, 15, 8};
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
-  auto params_generator = [&kwg_size]()->std::vector<std::vector<uint32_t>> {
+  auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(width_blocks, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
     return {{4, 15, 8}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
             {kwg_size/16, 4, 4},
             {kwg_size/32, 4, 8},
             {kwg_size/32, 8, 4},
diff --git a/mace/kernels/opencl/conv_2d_opencl_3x3.cc b/mace/kernels/opencl/conv_2d_opencl_3x3.cc
index 858fc5fc..0b77b6c2 100644
--- a/mace/kernels/opencl/conv_2d_opencl_3x3.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_3x3.cc
@@ -60,8 +60,13 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
                            static_cast<uint32_t>(height * batch)};
   const std::vector<uint32_t> lws = {4, 15, 8};
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
-  auto params_generator = [&kwg_size]() -> std::vector<std::vector<uint32_t>> {
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(width_blocks, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
     return {{4, 15, 8}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
             {kwg_size / 16, 4, 4},
             {kwg_size / 32, 4, 8},
             {kwg_size / 32, 8, 4},

diff --git a/mace/kernels/opencl/conv_2d_opencl_general.cc b/mace/kernels/opencl/conv_2d_opencl_general.cc
index 7a74f86b..dcfbdec8 100644
--- a/mace/kernels/opencl/conv_2d_opencl_general.cc
+++ b/mace/kernels/opencl/conv_2d_opencl_general.cc
@@ -62,8 +62,13 @@ void Conv2dOpencl(const Tensor *input, const Tensor *filter,
                            static_cast<uint32_t>(height * batch)};
   const std::vector<uint32_t> lws = {8, 16, 8};
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
-  auto params_generator = [&kwg_size]() -> std::vector<std::vector<uint32_t>> {
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(width_blocks, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
     return {{4, 15, 8}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
             {kwg_size / 16, 4, 4},
             {kwg_size / 32, 4, 8},
             {kwg_size / 32, 8, 4},
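Note: every 3-D params_generator in this patch opens with the same default local-size computation. Factored out for clarity below, it is a greedy clamp: dimension 0 takes as much as it can, dimensions 1 and 2 fit into what remains, so local[0] * local[1] * local[2] <= kwg_size always holds. This is a hypothetical helper for illustration (the patch deliberately inlines it per kernel), and it assumes every global dimension is at least 1:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Default local work size: clamp each dimension to the budget left
    // over from the previous ones so the work-group volume never exceeds
    // the kernel's maximum work-group size.
    std::vector<uint32_t> DefaultLocalWS(uint32_t gws0, uint32_t gws1,
                                         uint32_t gws2, uint32_t kwg_size) {
      std::vector<uint32_t> local(3, 0);
      local[0] = std::min(gws0, kwg_size);
      local[1] = std::min(gws1, kwg_size / local[0]);
      local[2] = std::min(gws2, kwg_size / (local[0] * local[1]));
      return local;
    }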
diff --git a/mace/kernels/opencl/pooling_opencl.cc b/mace/kernels/opencl/pooling_opencl.cc
index 349c6195..5a0fbadf 100644
--- a/mace/kernels/opencl/pooling_opencl.cc
+++ b/mace/kernels/opencl/pooling_opencl.cc
@@ -6,6 +6,7 @@
 #include "mace/core/runtime/opencl/cl2_header.h"
 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/kernels/opencl/helper.h"
+#include "mace/utils/tuner.h"
 
 namespace mace {
 namespace kernels {
@@ -23,11 +24,6 @@ static void Pooling(const Tensor *input,
   index_t channels = output->dim(3);
   index_t channel_blocks = (channels + 3) / 4;
 
-  const uint32_t gws[3] = {
-      static_cast<uint32_t>(channel_blocks),
-      static_cast<uint32_t>(out_width),
-      static_cast<uint32_t>(batch * out_height),
-  };
   auto runtime = OpenCLRuntime::Get();
   std::set<std::string> built_options;
@@ -44,13 +40,6 @@ static void Pooling(const Tensor *input,
   }
   auto pooling_kernel = runtime->BuildKernel("pooling", "pooling", built_options);
 
-  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(pooling_kernel);
-
-  uint32_t lws[3];
-  lws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
-  lws[1] = std::min<uint32_t>(out_width, kwg_size / lws[0]);
-  lws[2] = std::min<uint32_t>(out_height * batch, kwg_size / (lws[0] * lws[1]));
-
   uint32_t idx = 0;
   pooling_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(input->buffer())));
   pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(1)));
@@ -62,12 +51,60 @@ static void Pooling(const Tensor *input,
   pooling_kernel.setArg(idx++, pooling_size);
   pooling_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
 
-  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
-      pooling_kernel, cl::NullRange,
-      cl::NDRange(gws[0], gws[1], gws[2]),
-      cl::NDRange(lws[0], lws[1], lws[2]),
-      NULL, OpenCLRuntime::Get()->GetDefaultEvent());
-  MACE_CHECK(error == CL_SUCCESS) << error;
+  const uint32_t gws[3] = {
+      static_cast<uint32_t>(channel_blocks),
+      static_cast<uint32_t>(out_width),
+      static_cast<uint32_t>(batch * out_height),
+  };
+  const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(pooling_kernel);
+  std::vector<uint32_t> lws(3, 0);
+  lws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+  lws[1] = std::min<uint32_t>(out_width, kwg_size / lws[0]);
+  lws[2] = std::min<uint32_t>(out_height * batch, kwg_size / (lws[0] * lws[1]));
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(out_width, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(out_height * batch, kwg_size / (local_ws[0] * local_ws[1]));
+    return {{4, 15, 8}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
+            {kwg_size / 16, 4, 4},
+            {kwg_size / 32, 4, 8},
+            {kwg_size / 32, 8, 4},
+            {kwg_size / 64, 8, 8},
+            {kwg_size / 64, 16, 4},
+            {kwg_size / 128, 8, 16},
+            {kwg_size / 128, 16, 8},
+            {kwg_size / 128, 32, 4},
+            {1, kwg_size / 32, 32},
+            {1, kwg_size / 64, 64},
+            {1, kwg_size / 128, 128},
+            {3, 15, 9},
+            {7, 15, 9},
+            {9, 7, 15},
+            {15, 7, 9},
+            {1, kwg_size, 1}};
+  };
+  auto func = [&](const std::vector<uint32_t> &params) -> cl_int {
+    cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+        pooling_kernel, cl::NullRange,
+        cl::NDRange(gws[0], gws[1], gws[2]),
+        cl::NDRange(params[0], params[1], params[2]),
+        NULL, OpenCLRuntime::Get()->GetDefaultEvent());
+
+    MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
+    return error;
+  };
+  std::stringstream ss;
+  ss << "pooling_opencl_kernel_"
+     << output->dim(0) << "_"
+     << output->dim(1) << "_"
+     << output->dim(2) << "_"
+     << output->dim(3);
+  Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
+                                                     lws,
+                                                     params_generator,
+                                                     func);
 }
 
 template <typename T>

diff --git a/mace/kernels/opencl/relu_opencl.cc b/mace/kernels/opencl/relu_opencl.cc
index 28ff881b..483ec8d4 100644
--- a/mace/kernels/opencl/relu_opencl.cc
+++ b/mace/kernels/opencl/relu_opencl.cc
@@ -50,8 +50,13 @@ void ReluFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
                            static_cast<uint32_t>(height * batch)};
   const std::vector<uint32_t> lws = {8, 16, 8};
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(relu_kernel);
-  auto params_generator = [&kwg_size]() -> std::vector<std::vector<uint32_t>> {
+  auto params_generator = [&]() -> std::vector<std::vector<uint32_t>> {
+    std::vector<uint32_t> local_ws(3, 0);
+    local_ws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
+    local_ws[1] = std::min<uint32_t>(width, kwg_size / local_ws[0]);
+    local_ws[2] = std::min<uint32_t>(height * batch, kwg_size / (local_ws[0] * local_ws[1]));
     return {{4, 15, 8}, //SNPE size
+            {local_ws[0], local_ws[1], local_ws[2]},
             {kwg_size / 16, 4, 4},
             {kwg_size / 32, 4, 8},
             {kwg_size / 32, 8, 4},
"mace/kernels/opencl/helper.h" #include "mace/utils/utils.h" +#include "mace/utils/tuner.h" namespace mace { namespace kernels { @@ -44,8 +45,6 @@ void ResizeBilinearFunctor::operator()( built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); auto rb_kernel = runtime->BuildKernel("resize_bilinear", "resize_bilinear_nocache", built_options); - const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel); - uint32_t idx = 0; rb_kernel.setArg(idx++, *(static_cast(input->buffer()))); rb_kernel.setArg(idx++, *(static_cast(output->buffer()))); @@ -55,17 +54,52 @@ void ResizeBilinearFunctor::operator()( rb_kernel.setArg(idx++, static_cast(in_width)); rb_kernel.setArg(idx++, static_cast(out_height)); - auto command_queue = runtime->command_queue(); + const uint32_t gws[3] = {static_cast(channel_blocks), + static_cast(out_width), + static_cast(out_height * batch)}; + const std::vector lws = {8, 16, 8}; + const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(rb_kernel); + auto params_generator = [&]() -> std::vector> { + std::vector local_ws(3, 0); + local_ws[0] = std::min(channel_blocks, kwg_size); + local_ws[1] = std::min(out_width, kwg_size / local_ws[0]); + local_ws[2] = std::min(out_height * batch, kwg_size / (local_ws[0] * local_ws[1])); + return {{4, 15, 8}, //SNPE size + {local_ws[0], local_ws[1], local_ws[2]}, + {kwg_size / 16, 4, 4}, + {kwg_size / 32, 4, 8}, + {kwg_size / 32, 8, 4}, + {kwg_size / 64, 8, 8}, + {kwg_size / 64, 16, 4}, + {kwg_size / 128, 8, 16}, + {kwg_size / 128, 16, 8}, + {kwg_size / 128, 32, 4}, + {1, kwg_size / 32, 32}, + {1, kwg_size / 64, 64}, + {1, kwg_size / 128, 128}, + {1, kwg_size, 1}}; + }; + auto func = [&](const std::vector ¶ms) -> cl_int { + cl_int error = runtime->command_queue().enqueueNDRangeKernel( + rb_kernel, cl::NullRange, + cl::NDRange(gws[0], gws[1], gws[2]), + cl::NDRange(params[0], params[1], params[2]), + NULL, OpenCLRuntime::Get()->GetDefaultEvent()); + + MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error; + return error; + }; + std::stringstream ss; + ss << "resize_bilinear_opencl_kernel_" + << output->dim(0) << "_" + << output->dim(1) << "_" + << output->dim(2) << "_" + << output->dim(3); + Tuner::Get()->template TuneOrRun(ss.str(), + lws, + params_generator, + func); - cl_int error = command_queue.enqueueNDRangeKernel( - rb_kernel, cl::NullRange, - cl::NDRange(static_cast(channel_blocks), - static_cast(out_width), - static_cast(out_height * batch)), - // TODO tuning - cl::NDRange(1, static_cast(out_width > kwg_size ? kwg_size : out_width), 1), - nullptr, OpenCLRuntime::Get()->GetDefaultEvent()); - MACE_CHECK(error == CL_SUCCESS, error); } template struct ResizeBilinearFunctor; diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index f603d3b5..08d9dcd7 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -1,7 +1,6 @@ from mace.proto import mace_pb2 import tensorflow as tf import numpy as np -from mace.python.tools.convert_util import tf_dtype_2_mace_dtype # TODO: support NCHW formt, now only support NHWC. 
diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py
index f603d3b5..08d9dcd7 100644
--- a/mace/python/tools/tf_converter_lib.py
+++ b/mace/python/tools/tf_converter_lib.py
@@ -1,7 +1,6 @@
 from mace.proto import mace_pb2
 import tensorflow as tf
 import numpy as np
-from mace.python.tools.convert_util import tf_dtype_2_mace_dtype
 
 # TODO: support NCHW format, now only support NHWC.
 padding_mode = {
@@ -111,18 +110,14 @@ def add_output_transform(name, net_def):
   epsilon_arg.name = 'buffer_type'
   epsilon_arg.i = buffer_type_map['IN_OUT']
 
-
-def convert_op_outputs(mace_op_def, tf_op):
-  mace_op_def.output.extend([output.name for output in tf_op.outputs])
-  mace_op_def.output_type.extend([tf_dtype_2_mace_dtype(output.dtype)
-                                  for output in tf_op.outputs])
+def add_output_shape(outputs, op):
   output_shapes = []
-  for output in tf_op.outputs:
-    output_shape = mace_pb2.OutputShape()
-    output_shape.dims.extend(output.shape.as_list())
-    output_shapes.append(output_shape)
-  mace_op_def.output_shape.extend(output_shapes)
-
+  for output in outputs:
+    if output.shape is not None and output.shape:
+      output_shape = mace_pb2.OutputShape()
+      output_shape.dims.extend(output.shape.as_list())
+      output_shapes.append(output_shape)
+  op.output_shape.extend(output_shapes)
 
 def convert_ops(unresolved_ops, dt, net_def, device):
   ops_count = len(unresolved_ops)
@@ -185,7 +180,8 @@ def convert_ops(unresolved_ops, dt, net_def, device):
         final_op = relu_op
         resolved_count = 4
 
-    convert_op_outputs(op_def, final_op)
+    op_def.output.extend([output.name for output in final_op.outputs])
+    add_output_shape(final_op.outputs, op_def)
 
   elif first_op.type == 'FusedBatchNorm':
     op_def.name = first_op.name
@@ -199,9 +195,7 @@ def convert_ops(unresolved_ops, dt, net_def, device):
 
     op_def.input.extend([input.name for input in first_op.inputs])
     op_def.output.extend([first_op.outputs[0].name])
-    output_shape = mace_pb2.OutputShape()
-    output_shape.dims.extend(first_op.outputs[0].shape.as_list())
-    op_def.output_shape.extend([output_shape])
+    add_output_shape(first_op.outputs, op_def)
 
     epsilon_arg = op_def.arg.add()
     epsilon_arg.name = 'epsilon'
@@ -217,31 +211,42 @@ def convert_ops(unresolved_ops, dt, net_def, device):
     mul_2_op = unresolved_ops[4]
     sub_op = unresolved_ops[5]
     add_1_op = unresolved_ops[6]
-    # print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type)
+    print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type)
     if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or \
         mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add':
       raise Exception('Invalid BatchNorm Op')
-    get_input_tensor(mul_1_op, 0)
     input_name = get_input_tensor(mul_1_op, 0).name
     gamma = get_input_tensor(mul_op, 1).name
     beta = get_input_tensor(sub_op, 0).name
    mean = get_input_tensor(mul_2_op, 0).name
     variance = get_input_tensor(add_op, 0).name
-    epsilon = get_input_tensor(add_op, 1).name
 
     op_def.name = first_op.name[:-4]  # remove /add
     op_def.type = 'BatchNorm'
-    op_def.input.extend([input_name, gamma, beta, mean, variance, epsilon])
-    convert_op_outputs(op_def, add_1_op)
+    if device == 'gpu':
+      op_def.input.extend([input_name])
+      for tensor_name in [gamma, beta, mean, variance]:
+        output_name = add_buffer_to_image(tensor_name, "ARGUMENT", dt, net_def)
+        op_def.input.extend([output_name])
+    else:
+      op_def.input.extend([input_name, gamma, beta, mean, variance])
+    op_def.output.extend([output.name for output in add_1_op.outputs])
+    add_output_shape(add_1_op.outputs, op_def)
+
+    epsilon_arg = op_def.arg.add()
+    epsilon_arg.name = 'epsilon'
+    epsilon_arg.f = get_input_tensor(add_op, 1).eval().astype(np.float)
+    data_format_arg = op_def.arg.add()
+    data_format_arg.name = 'data_format'
+    data_format_arg.s = 'NHWC'
 
     resolved_count = 7
   elif first_op.type == 'Relu6':
     op_def.name = first_op.name
     op_def.type = 'Relu'
     op_def.input.extend([input.name for input in first_op.inputs])
-    convert_op_outputs(op_def, first_op)
-
+    op_def.output.extend([output.name for output in first_op.outputs])
+    add_output_shape(first_op.outputs, op_def)
     max_limit_arg = op_def.arg.add()
     max_limit_arg.name = 'max_limit'
     max_limit_arg.f = 6
@@ -249,8 +254,8 @@ def convert_ops(unresolved_ops, dt, net_def, device):
     op_def.name = first_op.name
     op_def.type = 'Pooling'
     op_def.input.extend([input.name for input in first_op.inputs])
-    convert_op_outputs(op_def, first_op)
-
+    op_def.output.extend([output.name for output in first_op.outputs])
+    add_output_shape(first_op.outputs, op_def)
     pooling_type_arg = op_def.arg.add()
     pooling_type_arg.name = 'pooling_type'
     pooling_type_arg.i = pooling_type_mode[first_op.type]
@@ -270,31 +275,46 @@ def convert_ops(unresolved_ops, dt, net_def, device):
     op_def.name = first_op.name
     op_def.type = "AddN"
     op_def.input.extend([input.name for input in first_op.inputs])
-    convert_op_outputs(op_def, first_op)
+    op_def.output.extend([output.name for output in first_op.outputs])
+    add_output_shape(first_op.outputs, op_def)
   elif first_op.type == 'ConcatV2':
     op_def.name = first_op.name
     op_def.type = "Concat"
     op_def.input.extend([first_op.inputs[i].name for i in xrange(2)])
+    op_def.output.extend([output.name for output in first_op.outputs])
     axis_arg = op_def.arg.add()
     axis_arg.name = 'axis'
     axis_arg.i = get_input_tensor(first_op, 2).eval().astype(np.int32)
-    convert_op_outputs(op_def, first_op)
+    add_output_shape(first_op.outputs, op_def)
   elif first_op.type == 'ResizeBilinear':
     op_def.name = first_op.name
     op_def.type = "ResizeBilinear"
     op_def.input.extend([first_op.inputs[0].name])
+    op_def.output.extend([output.name for output in first_op.outputs])
     size_arg = op_def.arg.add()
     size_arg.name = 'size'
     size_arg.ints.extend(get_input_tensor(first_op, 1).eval().astype(np.int32).flat)
     size_arg = op_def.arg.add()
     size_arg.name = 'align_corners'
     size_arg.i = first_op.get_attr('align_corners')
-    convert_op_outputs(op_def, first_op)
-  elif first_op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND', 'BiasAdd']:
+    add_output_shape(first_op.outputs, op_def)
+  elif first_op.type == 'BiasAdd':
+    op_def.name = first_op.name
+    op_def.type = first_op.type
+    op_def.input.extend([first_op.inputs[0].name])
+    if device == 'gpu':
+      output_name = add_buffer_to_image(first_op.inputs[1].name, "ARGUMENT", dt, net_def)
+      op_def.input.extend([output_name])
+    else:
+      op_def.input.extend([first_op.inputs[1].name])
+    op_def.output.extend([output.name for output in first_op.outputs])
+    add_output_shape(first_op.outputs, op_def)
+  elif first_op.type in ['Relu', 'SpaceToBatchND', 'BatchToSpaceND']:
     op_def.name = first_op.name
     op_def.type = first_op.type
     op_def.input.extend([input.name for input in first_op.inputs])
-    convert_op_outputs(op_def, first_op)
+    op_def.output.extend([output.name for output in first_op.outputs])
+    add_output_shape(first_op.outputs, op_def)
   else:
     raise Exception('Unknown Op: %s, type: %s' % (first_op.name, first_op.type))
     pass
diff --git a/tools/validate.py b/tools/validate.py
index 9edbdd24..f322ed70 100644
--- a/tools/validate.py
+++ b/tools/validate.py
@@ -4,6 +4,7 @@ import os
 import os.path
 import tensorflow as tf
 import numpy as np
+from scipy import spatial
 from tensorflow import gfile
 
@@ -34,9 +35,12 @@ def load_data(file):
 def valid_output(out_shape, mace_out_file, tf_out_value):
   mace_out_value = load_data(mace_out_file)
   if mace_out_value.size != 0:
+    similarity = (1 - spatial.distance.cosine(tf_out_value.flat, mace_out_value))
+    print 'MACE VS TF similarity: ', similarity
+    if similarity > 0.999:
+      print '=======================Passed! Haha======================'
     mace_out_value = mace_out_value.reshape(out_shape)
     np.testing.assert_allclose(mace_out_value, tf_out_value, rtol=0.05)
-    print '=======================Passed! Haha======================'
   else:
     print '=======================Skip empty node==================='
@@ -62,7 +66,7 @@ def run_model(input_shape):
     input_value = input_value.reshape(input_shape)
     output_value = session.run(output_node, feed_dict={input_node: [input_value]})
-    # output_value.astype(np.float32).tofile( os.path.dirname(FLAGS.input_file) + '/tf_weight')
+    output_value.astype(np.float32).tofile( os.path.dirname(FLAGS.input_file) + '/tf_out')
     return output_value
 
 def main(unused_args):

diff --git a/tools/validate_gcn.sh b/tools/validate_gcn.sh
index b62cb784..8ac0a198 100644
--- a/tools/validate_gcn.sh
+++ b/tools/validate_gcn.sh
@@ -2,10 +2,10 @@
 # Must run at root dir of mace project.
 set +x
 Usage() {
-  echo 'Usage: bash tools/validate_gcn.sh tf_model_file'
+  echo 'Usage: bash tools/validate_gcn.sh tf_model_path image_size'
 }
 
-if [ $# != 1 ];then
+if [ $# != 2 ];then
   Usage
   exit -1
 fi
@@ -19,12 +19,13 @@
 OUTPUT_FILE_NAME='gcn.out'
 OUTPUT_LIST_FILE='gcn.list'
 PHONE_DATA_DIR="/data/local/tmp/${MACE_MODEL_NAME}"
 KERNEL_DIR="${PHONE_DATA_DIR}/cl/"
+IMAGE_SIZE=$2
 
 # Step 1: Generate input data
 echo "Step 1: Generate input data"
 python tools/validate.py --generate_data true --random_seed 1 \
     --input_file=${MODEL_DIR}/${INPUT_FILE_NAME} \
-    --input_shape=512,512,3
+    --input_shape="${IMAGE_SIZE},${IMAGE_SIZE},3"
 
 # Step 2: convert tf model to mace model
 echo "Step 2: convert tf model to mace model and optimize memory"
@@ -56,14 +57,15 @@ adb push bazel-bin/mace/examples/mace_run ${PHONE_DATA_DIR}
 
 num_threads=${1:-4}
 
-adb
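Note on the validation change: the similarity printed by tools/validate.py is the cosine similarity of the flattened TensorFlow and MACE outputs, since scipy's spatial.distance.cosine returns the cosine distance:

    similarity = 1 - d_cos(a, b) = (a . b) / (||a|| ||b||)

with 0.999 as the pass threshold. np.testing.assert_allclose still runs afterwards with an element-wise 5% relative tolerance, so the cosine check adds an early, scale-insensitive signal rather than replacing the strict comparison.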