From 0c3967d634e038e54e249af8d041a3f4eb1dc92b Mon Sep 17 00:00:00 2001
From: liuqi
Date: Fri, 16 Mar 2018 11:27:34 +0800
Subject: [PATCH] Optimize fully connected op and support Winograd for Caffe.

---
 mace/core/runtime/opencl/opencl_extension.h   |   2 +
 mace/core/runtime/opencl/opencl_runtime.cc    |   7 +
 mace/core/runtime/opencl/opencl_runtime.h     |   1 +
 mace/kernels/opencl/cl/fully_connected.cl     |  22 ++-
 mace/kernels/opencl/fully_connected_opencl.cc |  24 ++-
 mace/ops/fully_connected_test.cc              |  29 +--
 mace/ops/winograd_convolution_test.cc         | 102 ++++++++++
 mace/python/tools/caffe_converter_lib.py      | 178 ++++++++++++++++--
 tools/validate.py                             |  13 +-
 9 files changed, 331 insertions(+), 47 deletions(-)

diff --git a/mace/core/runtime/opencl/opencl_extension.h b/mace/core/runtime/opencl/opencl_extension.h
index a18f921e..6192eff3 100644
--- a/mace/core/runtime/opencl/opencl_extension.h
+++ b/mace/core/runtime/opencl/opencl_extension.h
@@ -25,4 +25,6 @@ typedef cl_uint cl_priority_hint;
 #define CL_PRIORITY_HINT_NORMAL_QCOM 0x40CB
 #define CL_PRIORITY_HINT_LOW_QCOM 0x40CC
 
+/* Accepted by clGetKernelWorkGroupInfo */
+#define CL_KERNEL_WAVE_SIZE_QCOM 0xAA02
 #endif  // MACE_CORE_RUNTIME_OPENCL_OPENCL_EXTENSION_H_
diff --git a/mace/core/runtime/opencl/opencl_runtime.cc b/mace/core/runtime/opencl/opencl_runtime.cc
index 793b53fe..5a4ca0ea 100644
--- a/mace/core/runtime/opencl/opencl_runtime.cc
+++ b/mace/core/runtime/opencl/opencl_runtime.cc
@@ -331,4 +331,11 @@ uint32_t OpenCLRuntime::GetKernelMaxWorkGroupSize(const cl::Kernel &kernel) {
   return static_cast<uint32_t>(size);
 }
 
+// TODO(liuqi): not compatible with Mali GPUs.
+uint32_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) {
+  unsigned long long size = 0;
+  kernel.getWorkGroupInfo(*device_, CL_KERNEL_WAVE_SIZE_QCOM, &size);
+  return static_cast<uint32_t>(size);
+}
+
 }  // namespace mace
diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h
index faa81838..58c6cdab 100644
--- a/mace/core/runtime/opencl/opencl_runtime.h
+++ b/mace/core/runtime/opencl/opencl_runtime.h
@@ -48,6 +48,7 @@ class OpenCLRuntime {
   void GetCallStats(const cl::Event &event, CallStats *stats);
   uint32_t GetDeviceMaxWorkGroupSize();
   uint32_t GetKernelMaxWorkGroupSize(const cl::Kernel &kernel);
+  uint32_t GetKernelWaveSize(const cl::Kernel &kernel);
   cl::Kernel BuildKernel(const std::string &program_name,
                         const std::string &kernel_name,
                         const std::set<std::string> &build_options);
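Note: GetKernelWaveSize reuses the clGetKernelWorkGroupInfo query path with the Qualcomm-specific CL_KERNEL_WAVE_SIZE_QCOM token defined above; as the in-code TODO says, Mali drivers do not implement this query. A minimal standalone sketch of the same query through the raw C API (the function name and fallback value are illustrative, not part of the patch):

    #include <CL/cl.h>

    #define CL_KERNEL_WAVE_SIZE_QCOM 0xAA02  /* Adreno-only query token */

    // Returns the wave (subgroup) size for `kernel`, or a caller-chosen
    // default when the driver rejects the Qualcomm-specific parameter.
    unsigned QueryWaveSize(cl_kernel kernel, cl_device_id device,
                           unsigned fallback) {
      unsigned long long wave_size = 0;  // same width the patch uses
      cl_int err = clGetKernelWorkGroupInfo(
          kernel, device, CL_KERNEL_WAVE_SIZE_QCOM,
          sizeof(wave_size), &wave_size, nullptr);
      return err == CL_SUCCESS ? static_cast<unsigned>(wave_size) : fallback;
    }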
diff --git a/mace/kernels/opencl/cl/fully_connected.cl b/mace/kernels/opencl/cl/fully_connected.cl
index 217224db..878c110d 100644
--- a/mace/kernels/opencl/cl/fully_connected.cl
+++ b/mace/kernels/opencl/cl/fully_connected.cl
@@ -66,12 +66,15 @@ __kernel void fully_connected_width(__read_only image2d_t input,
                                     __local float *intermediate_output,
                                     __private const int input_height,
                                     __private const int input_width,
-                                    __private const short in_chan_blks,
+                                    __private const int in_chan_blks,
+                                    __private const int out_blks,
                                     __private const float relux_max_limit) {
   const int inter_out_idx = get_global_id(0);
   const int width_blk_idx = get_global_id(1);
   const int width_blk_count = get_global_size(1);
-  const int out_blk_idx = get_global_id(2);
+  const int batch_out_blk_idx = get_global_id(2);
+  const int batch_idx = batch_out_blk_idx / out_blks;
+  const int out_blk_idx = batch_out_blk_idx % out_blks;
 
   const short in_outer_size = mul24(input_width, in_chan_blks);
   const short weight_y = mad24(out_blk_idx, 4, inter_out_idx);
@@ -80,16 +83,17 @@ __kernel void fully_connected_width(__read_only image2d_t input,
   DATA_TYPE4 in, w;
   DATA_TYPE sum = 0.0;
 
-  input_coord = (int2)(0, 0);
+  input_coord = (int2)(0, mul24(batch_idx, input_height));
 
-  for (short h_idx = 0; h_idx < input_height; ++h_idx) {
-    short weight_x_base = mul24(h_idx, in_outer_size);
-    for (short w_idx = (short)width_blk_idx; w_idx < input_width; w_idx += width_blk_count) {
-      short weight_x = mad24(w_idx, in_chan_blks, weight_x_base);
+  for (int h_idx = 0; h_idx < input_height; ++h_idx) {
+    int weight_x_base = mul24(h_idx, in_outer_size);
+    for (int w_idx = width_blk_idx; w_idx < input_width;
+         w_idx += width_blk_count) {
+      int weight_x = mad24(w_idx, in_chan_blks, weight_x_base);
       weight_coord = (int2)(weight_x, weight_y);
       input_coord.x = w_idx;
 #pragma unroll
-      for (short chan_idx = 0; chan_idx < in_chan_blks; ++chan_idx) {
+      for (int chan_idx = 0; chan_idx < in_chan_blks; ++chan_idx) {
         in = READ_IMAGET(input, SAMPLER, input_coord);
         w = READ_IMAGET(weight, SAMPLER, weight_coord);
 
@@ -125,6 +129,6 @@ __kernel void fully_connected_width(__read_only image2d_t input,
 #if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
     result = do_activation(result, relux_max_limit);
 #endif
-    WRITE_IMAGET(output, (int2)(out_blk_idx, 0), result);
+    WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
   }
 }
diff --git a/mace/kernels/opencl/fully_connected_opencl.cc b/mace/kernels/opencl/fully_connected_opencl.cc
index ca07b989..abcbfe52 100644
--- a/mace/kernels/opencl/fully_connected_opencl.cc
+++ b/mace/kernels/opencl/fully_connected_opencl.cc
@@ -35,16 +35,22 @@ void FCWXKernel(cl::Kernel *kernel,
     built_options.emplace("-DBIAS");
   }
   switch (activation) {
-    case NOOP:break;
-    case RELU:built_options.emplace("-DUSE_RELU");
+    case NOOP:
+      break;
+    case RELU:
+      built_options.emplace("-DUSE_RELU");
       break;
-    case RELUX:built_options.emplace("-DUSE_RELUX");
+    case RELUX:
+      built_options.emplace("-DUSE_RELUX");
       break;
-    case TANH:built_options.emplace("-DUSE_TANH");
+    case TANH:
+      built_options.emplace("-DUSE_TANH");
       break;
-    case SIGMOID:built_options.emplace("-DUSE_SIGMOID");
+    case SIGMOID:
+      built_options.emplace("-DUSE_SIGMOID");
       break;
-    default:LOG(FATAL) << "Unknown activation type: " << activation;
+    default:
+      LOG(FATAL) << "Unknown activation type: " << activation;
   }
 
   *kernel =
@@ -53,8 +59,9 @@
   const index_t batch = output->dim(0);
   const index_t output_size = output->dim(3);
   const index_t output_blocks = RoundUpDiv4(output_size);
+  const uint32_t wave_size = runtime->GetKernelWaveSize(*kernel);
 
-  gws = {4, 8, static_cast<uint32_t>(output_blocks)};
+  gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * output_blocks)};
 
   const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(*kernel);
   const uint32_t inter_local_blks = kwg_size / (gws[0] * gws[1]);
@@ -70,7 +77,8 @@
     kernel->setArg(idx++, (lws[0] * lws[1] * lws[2] * sizeof(float)), nullptr);
     kernel->setArg(idx++, static_cast<int>(input->dim(1)));
     kernel->setArg(idx++, static_cast<int>(input->dim(2)));
-    kernel->setArg(idx++, static_cast<short>(RoundUpDiv4(input->dim(3))));
+    kernel->setArg(idx++, static_cast<int>(RoundUpDiv4(input->dim(3))));
+    kernel->setArg(idx++, static_cast<int>(output_blocks));
     kernel->setArg(idx++, relux_max_limit);
   }
   cl::Event event;
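Note: the host-side change pairs with the kernel change above. gws[1] is now derived from the queried wave size, and gws[2] packs batch and output-channel blocks into one dimension, which the kernel unpacks with a div/mod pair, offsetting its input reads by batch_idx * input_height rows and writing its result at image coordinate (out_blk_idx, batch_idx). A small sketch of that index math (plain C++; names mirror the kernel, the numbers are only an example):

    #include <cassert>

    int main() {
      const int batch = 3;
      const int out_blks = 256;           // RoundUpDiv4(1024 output channels)
      const int gid2 = 2 * out_blks + 5;  // a work-item in gws[2] = batch * out_blks

      const int batch_idx = gid2 / out_blks;    // -> 2
      const int out_blk_idx = gid2 % out_blks;  // -> 5
      assert(batch_idx == 2 && out_blk_idx == 5);

      // This item reads input rows starting at y = batch_idx * input_height
      // and writes its 4-channel result to image coord (out_blk_idx, batch_idx).
      return 0;
    }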
diff --git a/mace/ops/fully_connected_test.cc b/mace/ops/fully_connected_test.cc
index cfbc6796..c50e6952 100644
--- a/mace/ops/fully_connected_test.cc
+++ b/mace/ops/fully_connected_test.cc
@@ -187,11 +187,11 @@ TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
 }
 
 template <typename T>
-void TestWeightWidthFormat(const index_t batch,
-                           const index_t height,
-                           const index_t width,
-                           const index_t channels,
-                           const index_t out_channel) {
+void TestWXFormat(const index_t batch,
+                  const index_t height,
+                  const index_t width,
+                  const index_t channels,
+                  const index_t out_channel) {
   srand(time(NULL));
 
   // Construct graph
@@ -246,14 +246,21 @@
 }
 
 TEST_F(FullyConnectedOpTest, OPENCLWidthFormatAligned) {
-  TestWeightWidthFormat<float>(1, 7, 7, 32, 16);
-  TestWeightWidthFormat<float>(1, 7, 7, 512, 128);
-  TestWeightWidthFormat<float>(1, 1, 1, 2048, 1024);
+  TestWXFormat<float>(1, 7, 7, 32, 16);
+  TestWXFormat<float>(1, 7, 7, 512, 128);
+  TestWXFormat<float>(1, 1, 1, 2048, 1024);
 }
+
+TEST_F(FullyConnectedOpTest, OPENCLWidthFormatMultiBatch) {
+  TestWXFormat<float>(11, 7, 7, 32, 16);
+  TestWXFormat<float>(5, 7, 7, 512, 128);
+  TestWXFormat<float>(3, 1, 1, 2048, 1024);
+}
+
 TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
-  TestWeightWidthFormat<half>(1, 2, 2, 512, 2);
-  TestWeightWidthFormat<half>(1, 11, 11, 32, 16);
-  TestWeightWidthFormat<half>(1, 16, 32, 32, 32);
+  TestWXFormat<half>(1, 2, 2, 512, 2);
+  TestWXFormat<half>(1, 11, 11, 32, 16);
+  TestWXFormat<half>(1, 16, 32, 32, 32);
 }
 
 }
diff --git a/mace/ops/winograd_convolution_test.cc b/mace/ops/winograd_convolution_test.cc
index 9fc5e40b..41e8478a 100644
--- a/mace/ops/winograd_convolution_test.cc
+++ b/mace/ops/winograd_convolution_test.cc
@@ -148,4 +148,106 @@ TEST_F(WinogradConvlutionTest, BatchConvolution) {
   WinogradConvolution<DeviceType::OPENCL, half>(5, 61, 67, 37, 31,
                                                 Padding::SAME);
 }
 
+template <DeviceType D, typename T>
+void WinogradConvolutionWithPad(const index_t batch,
+                                const index_t height,
+                                const index_t width,
+                                const index_t in_channels,
+                                const index_t out_channels,
+                                const int padding) {
+  srand(time(NULL));
+
+  // Construct graph
+  OpsTestNet net;
+  // Add input data
+  std::vector<float> filter_data;
+  std::vector<index_t> filter_shape = {3, 3, out_channels, in_channels};
+  GenerateRandomRealTypeData(filter_shape, filter_data);
+  net.AddRandomInput<D, float>("Input", {batch, height, width, in_channels});
+  net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
+  net.AddRandomInput<D, float>("Bias", {out_channels});
+
+  BufferToImage<D, T>(net, "Input", "InputImage",
+                      kernels::BufferType::IN_OUT_CHANNEL);
+  BufferToImage<D, T>(net, "Filter", "FilterImage",
+                      kernels::BufferType::CONV2D_FILTER);
+  BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
+  OpDefBuilder("Conv2D", "Conv2dTest")
+      .Input("InputImage")
+      .Input("FilterImage")
+      .Input("BiasImage")
+      .Output("OutputImage")
+      .AddIntsArg("strides", {1, 1})
+      .AddIntsArg("padding_values", {padding, padding})
+      .AddIntsArg("dilations", {1, 1})
+      .Finalize(net.NewOperatorDef());
+
+  net.RunOp(D);
+
+  // Transfer output
+  ImageToBuffer<D, float>(net, "OutputImage", "ConvOutput",
+                          kernels::BufferType::IN_OUT_CHANNEL);
+  Tensor expected;
+  expected.Copy(*net.GetOutput("ConvOutput"));
+  auto output_shape = expected.shape();
+
+  // Winograd convolution
+  // transform filter
+  std::vector<float> wino_filter_data;
+  TransposeFilter(filter_data, filter_shape, wino_filter_data);
+  net.AddInputFromArray<D, float>(
+      "WinoFilterData", {out_channels, in_channels, 3, 3}, wino_filter_data);
+  BufferToImage<D, T>(net, "WinoFilterData", "WinoFilter",
+                      kernels::BufferType::WINOGRAD_FILTER);
+
+  // transform input
+  OpDefBuilder("WinogradTransform", "WinogradTransformTest")
+      .Input("InputImage")
+      .Output("WinoInput")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .AddIntsArg("padding_values", {padding, padding})
+      .Finalize(net.NewOperatorDef());
+
+  // Run on opencl
+  net.RunOp(D);
+
+  // MatMul
+  OpDefBuilder("MatMul", "MatMulTest")
+      .Input("WinoFilter")
+      .Input("WinoInput")
+      .Output("WinoGemm")
+      .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
+      .Finalize(net.NewOperatorDef());
+  // Run on opencl
+  net.RunOp(D);
+
+  // Inverse transform
+  OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
+      .Input("WinoGemm")
+      .Input("BiasImage")
+      .AddIntArg("batch", batch)
+      .AddIntArg("height", output_shape[1])
+      .AddIntArg("width", output_shape[2])
+      .Output("WinoOutputImage")
+      .Finalize(net.NewOperatorDef());
+
+  // Run on opencl
+  net.RunOp(D);
+  net.Sync();
+
+  ImageToBuffer<D, float>(net, "WinoOutputImage", "WinoOutput",
+                          kernels::BufferType::IN_OUT_CHANNEL);
+  if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
+    ExpectTensorNear<float>(expected, *net.GetOutput("WinoOutput"), 1e-1);
+  } else {
+    ExpectTensorNear<float>(expected, *net.GetOutput("WinoOutput"), 1e-3);
+  }
+}
+
+TEST_F(WinogradConvlutionTest, UnAlignedConvolutionPad2) {
+  WinogradConvolutionWithPad<DeviceType::OPENCL, float>(1, 64, 64, 40, 19, 2);
+  WinogradConvolutionWithPad<DeviceType::OPENCL, half>(1, 32, 32, 96, 109, 2);
+}
+
 }
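Note: the test validates the decomposed Winograd path, F(2x2, 3x3), against direct Conv2D: WinogradTransform tiles the padded input into 4x4 patches (16 transformed elements each), MatMul multiplies by the transformed filter, and WinogradInverseTransform folds each product back into a 2x2 output tile. The shape bookkeeping, matching the converter code that follows (the concrete numbers are just an example):

    #include <cassert>

    int main() {
      const int batch = 1, out_h = 7, out_w = 7;
      // One 2x2 output tile per step; odd extents round up.
      const int tiles = batch * ((out_h + 1) / 2) * ((out_w + 1) / 2);
      assert(tiles == 16);
      // Transformed input:  [16, in_channels,  tiles, 1]
      // MatMul output:      [16, out_channels, tiles, 1]
      // Inverse transform restores [batch, out_h, out_w, out_channels].
      return 0;
    }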
diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py
index fbbe9f9f..c69297c0 100644
--- a/mace/python/tools/caffe_converter_lib.py
+++ b/mace/python/tools/caffe_converter_lib.py
@@ -19,6 +19,7 @@ buffer_type_map = {
   'WINOGRAD_FILTER' : 5,
   'DW_CONV2D_FILTER' : 6,
   'WEIGHT_HEIGHT' : 7,
+  'WEIGHT_WIDTH' : 8,
 }
 
 data_type_map = {
@@ -310,24 +311,25 @@ class CaffeConverter(object):
       pad = [param.pad * 2, param.pad * 2]
     kernel = [param.kernel_size, param.kernel_size]
 
-    strides_arg = op_def.arg.add()
-    strides_arg.name = 'strides'
     if param.HasField("stride_h") or param.HasField("stride_w"):
       stride = [param.stride_h, param.stride_w]
-    strides_arg.ints.extend(stride)
 
     # Pad
-    padding_arg = op_def.arg.add()
-    padding_arg.name = 'padding_values'
     if param.HasField("pad_h") or param.HasField("pad_w"):
       pad = [param.pad_h * 2, param.pad_w * 2]
-    padding_arg.ints.extend(pad)
-    # kernel
-    if op_def.type == 'Pooling':
-      kernel_arg = op_def.arg.add()
-      kernel_arg.name = 'kernels'
-      if param.HasField("kernel_h") or param.HasField("kernel_w"):
-        kernel = [param.kernel_h, param.kernel_w]
-      kernel_arg.ints.extend(kernel)
+
+    if op_def is not None:
+      strides_arg = op_def.arg.add()
+      strides_arg.name = 'strides'
+      strides_arg.ints.extend(stride)
+
+      padding_arg = op_def.arg.add()
+      padding_arg.name = 'padding_values'
+      padding_arg.ints.extend(pad)
+
+      if op_def.type == 'Pooling':
+        if param.HasField("kernel_h") or param.HasField("kernel_w"):
+          kernel = [param.kernel_h, param.kernel_w]
+
     return pad, stride, kernel
 
   def convert_conv2d(self, op):
@@ -391,6 +393,126 @@ class CaffeConverter(object):
     self.add_output_shape(op_def, output_shape)
     self.net_def.op.extend([op_def])
 
+  def check_winograd_conv(self, op):
+    param = op.layer.convolution_param
+    filter_shape = np.asarray(op.data[0].shape)
+    filter_shape = filter_shape[[2, 3, 0, 1]]
+    paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None)
+
+    dilations = [1, 1]
+    if len(param.dilation) > 0:
+      if len(param.dilation) == 1:
+        dilations = [param.dilation[0], param.dilation[0]]
+      elif len(param.dilation) == 2:
+        dilations = [param.dilation[0], param.dilation[1]]
+
+    output_shape = Shapes.conv_pool_shape(
+      op.get_single_parent().output_shape_map[op.layer.bottom[0]],
+      filter_shape, paddings, strides, dilations, math.floor)
+    width = output_shape[0] * ((output_shape[1] + 1) / 2) * ((output_shape[2] + 1) / 2)
+    return self.winograd and self.device == 'gpu' and \
+           filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \
+           dilations[0] == 1 and (dilations[0] == dilations[1]) and \
+           (strides[0] == 1) and (strides[0] == strides[1]) and \
+           (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \
+           (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \
+           (width < OPENCL_IMAGE_MAX_SIZE)
+
+  def convert_winograd_conv(self, op):
+    # Add filter
+    weight_tensor_name = op.name + '_weight:0'
+    self.add_tensor(weight_tensor_name, op.data[0])
+    print 'Winograd filter shape:', op.data[0].shape
+
+    buffer_type = "WINOGRAD_FILTER"
+    filter_name = self.add_buffer_to_image(weight_tensor_name, buffer_type)
+
+    param = op.layer.convolution_param
+    paddings, strides, _ = self.add_stride_pad_kernel_arg(param, None)
+
+    filter_shape = np.asarray(op.data[0].shape)
+    filter_shape = filter_shape[[2, 3, 0, 1]]
+
+    output_shape = Shapes.conv_pool_shape(
+      op.get_single_parent().output_shape_map[op.layer.bottom[0]],
+      filter_shape, paddings, strides, [1, 1], math.floor)
+
+    # Input transform
+    wt_op = mace_pb2.OperatorDef()
+    arg = wt_op.arg.add()
+    arg.name = 'T'
+    arg.i = self.dt
+    padding_arg = wt_op.arg.add()
+    padding_arg.name = 'padding_values'
+    padding_arg.ints.extend(paddings)
+    wt_op.name = op.name + '_input_transform'
+    wt_op.type = 'WinogradTransform'
+    wt_op.input.extend([name + ':0' for name in self.inputs_map[op.name]])
+    wt_output_name = wt_op.name + ":0"
+    wt_op.output.extend([wt_output_name])
+    wt_output_shape = mace_pb2.OutputShape()
+    wt_output_width = output_shape[0] * ((output_shape[1] + 1) / 2) * ((output_shape[2] + 1) / 2)
+    wt_output_shape.dims.extend([16, filter_shape[3], wt_output_width, 1])
+    wt_op.output_shape.extend([wt_output_shape])
+
+    # MatMul
+    matmul_op = mace_pb2.OperatorDef()
+    arg = matmul_op.arg.add()
+    arg.name = 'T'
+    arg.i = self.dt
+    matmul_op.name = op.name + '_matmul'
+    matmul_op.type = 'MatMul'
+    matmul_op.input.extend([filter_name, wt_output_name])
+    matmul_output_name = matmul_op.name + ":0"
+    matmul_op.output.extend([matmul_output_name])
+    matmul_output_shape = mace_pb2.OutputShape()
+    matmul_output_shape.dims.extend([16, filter_shape[2], wt_output_width, 1])
+    matmul_op.output_shape.extend([matmul_output_shape])
+
+    # Inverse transform
+    iwt_op = mace_pb2.OperatorDef()
+    arg = iwt_op.arg.add()
+    arg.name = 'T'
+    arg.i = self.dt
+    batch_arg = iwt_op.arg.add()
+    batch_arg.name = 'batch'
+    batch_arg.i = output_shape[0]
+    height_arg = iwt_op.arg.add()
+    height_arg.name = 'height'
+    height_arg.i = output_shape[1]
+    width_arg = iwt_op.arg.add()
+    width_arg.name = 'width'
+    width_arg.i = output_shape[2]
+    iwt_op.name = op.name + '_inverse_transform'
+    iwt_op.type = 'WinogradInverseTransform'
+    iwt_op.input.extend([matmul_output_name])
+
+    # Add Bias
+    if len(op.data) == 2:
+      bias_tensor_name = op.name + '_bias:0'
+      bias_data = op.data[1].reshape(-1)
+      self.add_tensor(bias_tensor_name, bias_data)
+      output_name = self.add_buffer_to_image(bias_tensor_name, "ARGUMENT")
+      iwt_op.input.extend([output_name])
+
+    final_op = op
+    final_op.output_shape_map[final_op.layer.top[0]] = output_shape
+    self.resolved_ops.add(op.name)
+
+    if len(self.ops_map[final_op.name].children) == 1 \
+        and self.ops_map[final_op.name].children[0].type in activation_name_map:
+      activation_op = self.ops_map[final_op.name].children[0]
+      fused_act_arg = iwt_op.arg.add()
+      fused_act_arg.name = 'activation'
+      fused_act_arg.s = activation_name_map[activation_op.type]
+      final_op = activation_op
+      final_op.output_shape_map[final_op.layer.top[0]] = output_shape
+      self.resolved_ops.add(activation_op.name)
+
+    iwt_op.output.extend([final_op.name + ':0'])
+    self.add_output_shape(iwt_op, output_shape)
+    self.net_def.op.extend([wt_op, matmul_op, iwt_op])
+
   def convert_batchnorm(self, op):
     if len(op.children) != 1 or op.children[0].type != 'Scale':
       raise Exception('Now only support BatchNorm+Scale')
@@ -468,10 +590,21 @@ class CaffeConverter(object):
     self.add_tensor(weight_tensor_name, weight_data)
     if self.device == 'gpu':
       if (weight_data.shape[0] + 3) / 4 > OPENCL_IMAGE_MAX_SIZE \
-          or weight_data.shape[1] > OPENCL_IMAGE_MAX_SIZE:
+          and (weight_data.shape[1] + 3) / 4 > OPENCL_IMAGE_MAX_SIZE:
+        raise Exception('Mace gpu do not support FC with weight shape: '
+                        +str(weight_data.shape))
+      if input_shape[3] % 4 == 0:
+        buffer_type = "WEIGHT_WIDTH"
+      else:
+        buffer_type = "WEIGHT_HEIGHT"
+      weight_type_arg = op_def.arg.add()
+      weight_type_arg.name = 'weight_type'
+      weight_type_arg.i = buffer_type_map[buffer_type]
+
+      if buffer_type == "WEIGHT_HEIGHT" and \
+          (weight_data.shape[0] + 3) / 4 > OPENCL_IMAGE_MAX_SIZE:
         raise Exception('Mace gpu do not support FC with weight shape: '
                         +str(weight_data.shape))
-      buffer_type = "WEIGHT_HEIGHT"
       output_name = self.add_buffer_to_image(weight_tensor_name, buffer_type)
       op_def.input.extend([output_name])
     else:
@@ -521,6 +654,13 @@ class CaffeConverter(object):
       pooling_type_arg.i = pooling_type_mode[pooling_type]
 
     input_shape = op.get_single_parent().output_shape_map[op.layer.bottom[0]]
+    if param.HasField('global_pooling') and param.global_pooling:
+      kernels = [input_shape[1], input_shape[2]]
+
+    kernel_arg = op_def.arg.add()
+    kernel_arg.name = 'kernels'
+    kernel_arg.ints.extend(kernels)
+
     filter_shape = [kernels[0], kernels[1], input_shape[3], input_shape[3]]
     output_shape = Shapes.conv_pool_shape(input_shape, filter_shape,
                                           paddings, strides, [1, 1], math.ceil)
@@ -684,7 +824,10 @@ class CaffeConverter(object):
       if op.type == 'Input':
         self.resolved_ops.add(op.name)
       elif op.type == 'Convolution':
-        self.convert_conv2d(op)
+        if self.check_winograd_conv(op):
+          self.convert_winograd_conv(op)
+        else:
+          self.convert_conv2d(op)
       elif op.type == 'BatchNorm':
         self.convert_batchnorm(op)
       elif op.type == 'InnerProduct':
@@ -719,7 +862,8 @@
       print 'Unresolve Op: %s with type %s' % (op.name, op.type)
 
 
-def convert_to_mace_pb(model_file, weight_file, input_node_str, input_shape_str, output_node_str, data_type, device, winograd):
+def convert_to_mace_pb(model_file, weight_file, input_node_str, input_shape_str,
+                       output_node_str, data_type, device, winograd):
   net_def = mace_pb2.NetDef()
 
   dt = data_type_map[data_type]
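Note: check_winograd_conv only rewrites a Convolution when the GPU Winograd path is both requested and representable: 3x3 filter, stride 1, dilation 1, and all three transformed tensors within the OpenCL image size limit (16 rows of in/out channels, plus the tile count as image width). A condensed form of the predicate (C++ for consistency with the other sketches; the scalar parameters stand in for the converter's per-axis locals):

    // filter is kh x kw with out_ch x in_ch channels; tile_width is
    // batch * ceil(out_h / 2) * ceil(out_w / 2) as computed above.
    bool CheckWinogradConv(bool winograd, bool gpu, int kh, int kw,
                           int stride, int dilation, int out_ch, int in_ch,
                           long tile_width, long image_max) {
      return winograd && gpu &&
             kh == 3 && kw == 3 &&
             stride == 1 && dilation == 1 &&
             16L * out_ch < image_max &&
             16L * in_ch < image_max &&
             tile_width < image_max;
    }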
diff --git a/tools/validate.py b/tools/validate.py
index 4978a995..ed1c916c 100644
--- a/tools/validate.py
+++ b/tools/validate.py
@@ -5,6 +5,7 @@ import os.path
 import numpy as np
 import re
 from scipy import spatial
+from scipy import stats
 
 # Validation Flow:
 # 1. Generate input data
@@ -30,7 +31,10 @@ def format_output_name(name):
 
 def compare_output(output_name, mace_out_value, out_value):
   if mace_out_value.size != 0:
-    similarity = (1 - spatial.distance.cosine(out_value.flat, mace_out_value.flat))
+    out_value = out_value.reshape(-1)
+    mace_out_value = mace_out_value.reshape(-1)
+    assert len(out_value) == len(mace_out_value)
+    similarity = (1 - spatial.distance.cosine(out_value, mace_out_value))
     print output_name, 'MACE VS', FLAGS.platform.upper(), 'similarity: ', similarity
     if (FLAGS.mace_runtime == "cpu" and similarity > 0.999) or \
        (FLAGS.mace_runtime == "gpu" and similarity > 0.995) or \
@@ -92,16 +96,21 @@ def validate_caffe_model(input_names, input_shapes, output_names, output_shapes):
   for i in range(len(input_names)):
     input_value = load_data(FLAGS.input_file + "_" + input_names[i])
     input_value = input_value.reshape(input_shapes[i]).transpose((0, 3, 1, 2))
-    net.blobs[input_names[i]].data[0] = input_value
+    input_blob_name = input_names[i]
+    if input_names[i] in net.top_names:
+      input_blob_name = net.top_names[input_names[i]][0]
+    net.blobs[input_blob_name].data[0] = input_value
 
   net.forward()
 
   for i in range(len(output_names)):
     value = net.blobs[net.top_names[output_names[i]][0]].data[0]
+    print net.top_names[output_names[i]][0]
     out_shape = output_shapes[i]
     out_shape[1], out_shape[2], out_shape[3] = out_shape[3], out_shape[1], out_shape[2]
     value = value.reshape(out_shape).transpose((0, 2, 3, 1))
     output_file_name = FLAGS.mace_out_file + "_" + format_output_name(output_names[i])
+    print 'output file name:', output_file_name
    mace_out_value = load_data(output_file_name)
     compare_output(output_names[i], mace_out_value, value)
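Note: compare_output now flattens both tensors and scores them with cosine similarity (1 - cosine distance), passing GPU results above 0.995 and CPU results above 0.999. The same metric in a standalone sketch:

    #include <cmath>
    #include <cstdio>
    #include <vector>

    double CosineSimilarity(const std::vector<double> &a,
                            const std::vector<double> &b) {
      double dot = 0, na = 0, nb = 0;
      for (size_t i = 0; i < a.size(); ++i) {
        dot += a[i] * b[i];
        na += a[i] * a[i];
        nb += b[i] * b[i];
      }
      return dot / (std::sqrt(na) * std::sqrt(nb));
    }

    int main() {
      std::vector<double> expected = {1.0, 2.0, 3.0};
      std::vector<double> actual = {1.01, 1.99, 3.02};
      // validate.py passes GPU results when this exceeds 0.995.
      std::printf("similarity: %f\n", CosineSimilarity(expected, actual));
      return 0;
    }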
--
GitLab