diff --git a/mace/kernels/opencl/buffer_to_image.cc b/mace/kernels/opencl/buffer_to_image.cc index 6d8f3ef1b5e7f0833b31af78a62af21bc188aa7b..999c631712daeb5d3de6a84c5ce6b6fb6ab07a5f 100644 --- a/mace/kernels/opencl/buffer_to_image.cc +++ b/mace/kernels/opencl/buffer_to_image.cc @@ -86,7 +86,12 @@ void BufferToImageFunctor::operator()( static_cast(buffer->buffer_offset() / GetEnumTypeSize(buffer->dtype()))); } - if (type == ARGUMENT) { + if (type == CONV2D_FILTER) { + b2f_kernel.setArg(idx++, static_cast(buffer->dim(0))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(1))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(2))); + b2f_kernel.setArg(idx++, static_cast(buffer->dim(3))); + } else if (type == ARGUMENT) { b2f_kernel.setArg(idx++, static_cast(buffer->dim(0))); } else if (type == WEIGHT_HEIGHT || type == WEIGHT_WIDTH) { b2f_kernel.setArg(idx++, static_cast(buffer->dim(0))); diff --git a/mace/kernels/opencl/cl/buffer_to_image.cl b/mace/kernels/opencl/cl/buffer_to_image.cl index 781d21e363a2a9320999ab4dcc933ffab5fcc0fa..ece729b0ccdc3383b452090bd286a7309d90bafd 100644 --- a/mace/kernels/opencl/cl/buffer_to_image.cl +++ b/mace/kernels/opencl/cl/buffer_to_image.cl @@ -2,22 +2,25 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, oc, ic */ __private const int input_offset, + __private const int filter_h, __private const int filter_w, __private const int out_channel, __private const int in_channel, __write_only image2d_t output) { int w = get_global_id(0); int h = get_global_id(1); - const int out_channel_idx = h * 4; - const int rounded_in_channel = ((in_channel + 3) / 4) * 4; - const int hw_idx = w / rounded_in_channel; - const int in_channel_idx = w % rounded_in_channel; + const int in_channel_idx = w; + const int hw_size = filter_w * filter_h; + const int out_channel_idx = h / hw_size * 4; + const int hw_idx = h % hw_size; const int h_idx = hw_idx / filter_w; const int w_idx = hw_idx % filter_w; - const int offset = input_offset + ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel - + in_channel_idx; + const int offset = input_offset + + ((h_idx * filter_w + w_idx) * out_channel + + out_channel_idx) * in_channel + + in_channel_idx; - VEC_DATA_TYPE(DATA_TYPE, 4) values = 0; + DATA_TYPE4 values = 0; if (out_channel_idx < out_channel) { const int size = out_channel - out_channel_idx; if (size < 4) { @@ -38,28 +41,30 @@ __kernel void filter_buffer_to_image(__global const DATA_TYPE *input, /* h, w, o } int2 coord = (int2)(w, h); - CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values); + WRITE_IMAGET(output, coord, values); } __kernel void filter_image_to_buffer(__global DATA_TYPE *output, /* h, w, oc, ic */ + __private const int filter_h, __private const int filter_w, __private const int out_channel, __private const int in_channel, __read_only image2d_t input) { int w = get_global_id(0); int h = get_global_id(1); - const int out_channel_idx = h * 4; - const int rounded_in_channel = ((in_channel + 3) / 4) * 4; - const int hw_idx = w / rounded_in_channel; - const int in_channel_idx = w % rounded_in_channel; + const int in_channel_idx = w; + const int hw_size = filter_w * filter_h; + const int out_channel_idx = h / hw_size * 4; + const int hw_idx = h % hw_size; const int h_idx = hw_idx / filter_w; const int w_idx = hw_idx % filter_w; - const int offset = ((h_idx * filter_w + w_idx) * out_channel + out_channel_idx) * in_channel - + in_channel_idx; + const int offset = ((h_idx * filter_w + w_idx) * out_channel + + out_channel_idx) * in_channel + + in_channel_idx; if (out_channel_idx < out_channel) { int2 coord = (int2)(w, h); - VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord); + DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord); const int size = (out_channel - out_channel_idx); if (size < 4) { switch (size) { @@ -145,7 +150,7 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ + channel_idx; const int size = channels - channel_idx; - VEC_DATA_TYPE(DATA_TYPE, 4) values = 0; + DATA_TYPE4 values = 0; if (size < 4) { switch(size) { case 3: @@ -159,7 +164,7 @@ __kernel void in_out_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ values = vload4(0, input + offset); } int2 coord = (int2)(w, h); - CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values); + WRITE_IMAGET(output, coord, values); } __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ @@ -177,7 +182,7 @@ __kernel void in_out_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ + channel_idx; int2 coord = (int2)(w, h); - VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord); + DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord); const int size = channels - channel_idx; if (size < 4) { switch (size) { @@ -204,7 +209,7 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ const int size = count - w * 4; - VEC_DATA_TYPE(DATA_TYPE, 4) values = 0; + DATA_TYPE4 values = 0; if (size < 4) { switch(size) { case 3: @@ -218,7 +223,7 @@ __kernel void arg_buffer_to_image(__global const DATA_TYPE *input, /* nhwc */ values = vload4(0, input + offset); } int2 coord = (int2)(w, h); - CMD_TYPE(write_image, CMD_DATA_TYPE)(output, coord, values); + WRITE_IMAGET(output, coord, values); } __kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ @@ -229,7 +234,7 @@ __kernel void arg_image_to_buffer(__global DATA_TYPE *output, /* nhwc */ const int offset = w * 4; int2 coord = (int2)(w, h); - VEC_DATA_TYPE(DATA_TYPE, 4) values = CMD_TYPE(read_image, CMD_DATA_TYPE)(input, SAMPLER, coord); + DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord); const int size = count - offset; if (size < 4) { switch (size) { diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index 8ed3073fd229a8f0c4935847cd56842df506efad..42d79807ccb402b70a1d5e24a209f490cbadb77e 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -1,7 +1,7 @@ #include __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ - __read_only image2d_t filter, /* cout%4 * cin * kh * kw, cout/4 */ + __read_only image2d_t filter, /* cout%4 * cin, kh * kw * cout/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif @@ -23,7 +23,6 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ const int out_w_blk = get_global_id(1); const int out_w_blks = get_global_size(1); const int out_hb = get_global_id(2); - const int rounded_in_ch = in_ch_blks << 2; #ifdef BIAS DATA_TYPE4 out0 = @@ -46,21 +45,21 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ const int height_idx = mad24((out_hb % out_height), stride, -padding_top); const int batch_idx = mul24((out_hb / out_height), in_height); - const int rounded_in_ch_x_filter_width = mul24(rounded_in_ch, filter_width); + const int filter_hw = mul24(filter_width, filter_height); DATA_TYPE4 in0, in1, in2, in3; DATA_TYPE4 weights0, weights1, weights2, weights3; for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { const int in_idx = mul24(in_ch_blk, in_width); - int filter_x_part0 = in_ch_blk << 2; + int filter_x_idx = in_ch_blk << 2; + int filter_y_idx = mul24(out_ch_blk, filter_hw); for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) { - // TODO(heliangliang) optimize out these muls int in_hb_value = height_idx + mul24(hb_idx, dilation_h); in_hb_value = select(in_hb_value + batch_idx, -1, (in_hb_value < 0 || in_hb_value >= in_height)); - int filter_x_part1 = 0; +#pragma unroll for (short width_idx = 0; width_idx < filter_width; ++width_idx) { int in_width_value; #define READ_INPUT(i) \ @@ -78,11 +77,10 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ #undef READ_INPUT // int filter_idx = (hb_idx * filter_width + width_idx) * rounded_in_ch + (in_ch_blk << 2); - int filter_idx = filter_x_part0 + filter_x_part1; - weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 0, out_ch_blk)); - weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 1, out_ch_blk)); - weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 2, out_ch_blk)); - weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 3, out_ch_blk)); + weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx)); + weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx)); + weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx)); + weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 3, filter_y_idx)); out0 = mad(in0.x, weights0, out0); out0 = mad(in0.y, weights1, out0); @@ -105,9 +103,8 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ out3 = mad(in3.z, weights2, out3); out3 = mad(in3.w, weights3, out3); - filter_x_part1 += rounded_in_ch; + filter_y_idx += 1; } - filter_x_part0 += rounded_in_ch_x_filter_width; } } diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index fad561aaca4aa8f6fe862f314177221214264053..7f7fd367d1aa4019a2a2009b4ed61ca179e23ac7 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -1,7 +1,7 @@ #include __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ - __read_only image2d_t filter, /* cout%4 * cin * kh * kw, cout/4 */ + __read_only image2d_t filter, /* cout%4 * cin , kh * kw * cout/4 */ #ifdef BIAS __read_only image2d_t bias, /* cout%4 * cout/4 */ #endif @@ -21,7 +21,6 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] const int out_w_blk = get_global_id(1); const int out_w_blks = get_global_size(1); const int out_hb = get_global_id(2); - const int rounded_in_ch = in_ch_blks << 2; #ifdef BIAS DATA_TYPE4 out0 = @@ -47,19 +46,18 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] const int height_idx = mad24((out_hb % out_height), stride, -padding_top); const int batch_idx = mul24((out_hb / out_height), in_height); - const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch; DATA_TYPE4 in0, in1, in2, in3, in4; DATA_TYPE4 weights0, weights1, weights2, weights3; for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { const int in_idx = mul24(in_ch_blk, in_width); - int filter_x_part0 = in_ch_blk << 2; + int filter_x_idx = in_ch_blk << 2; + int filter_y_idx = mul24(out_ch_blk, 9); int in_hb_idx = height_idx; for (short hb_idx = 0; hb_idx < 3; ++hb_idx) { int in_hb_value = select(in_hb_idx + batch_idx, -1, (in_hb_idx < 0 || in_hb_idx >= in_height)); - int filter_x_part1 = 0; int in_width_idx = 0; for (short width_idx = 0; width_idx < 3; ++width_idx) { int in_width_value; @@ -79,11 +77,10 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] #undef READ_INPUT // int filter_idx = (hb_idx * 3 + width_idx) * rounded_in_ch + (in_ch_blk << 2); - int filter_idx = filter_x_part0 + filter_x_part1; - weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 0, out_ch_blk)); - weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 1, out_ch_blk)); - weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 2, out_ch_blk)); - weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_idx + 3, out_ch_blk)); + weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx)); + weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx)); + weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx)); + weights3 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 3, filter_y_idx)); out0 = mad(in0.x, weights0, out0); out0 = mad(in0.y, weights1, out0); @@ -111,10 +108,9 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] out4 = mad(in4.z, weights2, out4); out4 = mad(in4.w, weights3, out4); - filter_x_part1 += rounded_in_ch; in_width_idx += dilation_w; + filter_y_idx += 1; } - filter_x_part0 += rounded_in_ch_x_3; in_hb_idx += dilation_h; } } diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index ee52625a6337bb9be5390b4392fd5b93e5a88214..e7dfb641cb5eec81ec6f83971645be8ea3dc33bb 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -23,13 +23,13 @@ void CalInOutputImageShape(const std::vector &shape, /* NHWC */ (*image_shape)[1] = shape[0] * shape[1]; } -// [RoundUp<4>(Ic) * H * W, (Oc + 3) / 4] +// [RoundUp<4>(Ic), H * W * (Oc + 3) / 4] void CalConv2dFilterImageShape(const std::vector &shape, /* HWOI */ std::vector *image_shape) { MACE_CHECK(shape.size() == 4); image_shape->resize(2); - (*image_shape)[0] = shape[0] * shape[1] * RoundUp(shape[3], 4); - (*image_shape)[1] = RoundUpDiv4(shape[2]); + (*image_shape)[0] = RoundUp(shape[3], 4); + (*image_shape)[1] = shape[0] * shape[1] * RoundUpDiv4(shape[2]); } // [H * W * M, (Ic + 3) / 4] diff --git a/mace/kernels/reorganize.h b/mace/kernels/reorganize.h index 68c772090d5db75c5cf609da23ea82f2ccc844eb..a64d55b97400188dd99ff4cccbec2b8e92287dc7 100644 --- a/mace/kernels/reorganize.h +++ b/mace/kernels/reorganize.h @@ -74,7 +74,6 @@ struct ReOrganizeFunctor { } } } - } }; diff --git a/mace/ops/conv_2d_benchmark.cc b/mace/ops/conv_2d_benchmark.cc index 90e4579eb9c53c4870a083f9871001420509318e..f06a7e127359e391a54b28bb4d35891416f32cbb 100644 --- a/mace/ops/conv_2d_benchmark.cc +++ b/mace/ops/conv_2d_benchmark.cc @@ -114,6 +114,7 @@ static void Conv2d(int iters, BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, float, OPENCL); \ BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, D, P, OC, half, OPENCL); + BM_CONV_2D(1, 256, 64, 64, 3, 3, 1, 1, VALID, 256); BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, 1, VALID, 1024); @@ -135,6 +136,8 @@ BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, 1, SAME, 128); BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, 1, SAME, 128); BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, 1, SAME, 128); +BM_CONV_2D(1, 1024, 16, 16, 15, 1, 1, 1, SAME, 2); + // Dilation BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 2, VALID, 32); BM_CONV_2D(1, 32, 256, 256, 3, 3, 1, 4, VALID, 32); diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py index 7c7cd9abd71cb8b4720f782ffc71835033c3e97c..f81cdc0f0dc7e678083da0da2e5eedd601b6bf64 100644 --- a/mace/python/tools/caffe_converter_lib.py +++ b/mace/python/tools/caffe_converter_lib.py @@ -72,9 +72,9 @@ class Shapes(object): output_shape = np.zeros_like(input_shape) output_shape[0] = input_shape[0] output_shape[1] = int(round_func((input_shape[1] + paddings[0] - filter_shape[0] - - (filter_shape[0] - 1) * (dilations[0] - 1)) / float(strides[0]))) + 1 + - (filter_shape[0] - 1) * (dilations[0] - 1)) / float(strides[0]))) + 1 output_shape[2] = int(round_func((input_shape[2] + paddings[1] - filter_shape[1] - - (filter_shape[1] - 1) * (dilations[1] - 1)) / float(strides[1]))) + 1 + - (filter_shape[1] - 1) * (dilations[1] - 1)) / float(strides[1]))) + 1 output_shape[3] = filter_shape[2] return output_shape @@ -333,8 +333,18 @@ class CaffeConverter(object): return pad, stride, kernel def convert_conv2d(self, op): - op_def = self.CommonConvert(op, 'Conv2D') param = op.layer.convolution_param + is_depthwise = False + if param.HasField('group'): + if param.group == op.data[0].shape[0] and op.data[0].shape[1] == 1: + is_depthwise = True + else: + raise Exception("Mace do not support group convolution yet") + + if is_depthwise: + op_def = self.CommonConvert(op, 'DepthwiseConv2d') + else: + op_def = self.CommonConvert(op, 'Conv2D') # Add filter weight_tensor_name = op.name + '_weight:0' @@ -342,7 +352,7 @@ class CaffeConverter(object): self.add_tensor(weight_tensor_name, weight_data) if self.device == 'gpu': - buffer_type = "CONV2D_FILTER" + buffer_type = "DW_CONV2D_FILTER" if is_depthwise else "CONV2D_FILTER" output_name = self.add_buffer_to_image(weight_tensor_name, buffer_type) op_def.input.extend([output_name]) else: @@ -373,15 +383,16 @@ class CaffeConverter(object): self.resolved_ops.add(op.name) output_shape = Shapes.conv_pool_shape(op.get_single_parent().output_shape_map[op.layer.bottom[0]], - weight_data.shape, - paddings, strides, dilations, - math.floor) + weight_data.shape, + paddings, strides, dilations, + math.floor) op.output_shape_map[op.layer.top[0]] = output_shape if len(self.ops_map[final_op.name].children) == 1 \ and self.ops_map[final_op.name].children[0].type in activation_name_map: activation_op = self.ops_map[final_op.name].children[0] - op_def.type = "FusedConv2D" + if not is_depthwise: + op_def.type = "FusedConv2D" fused_act_arg = op_def.arg.add() fused_act_arg.name = 'activation' fused_act_arg.s = activation_name_map[activation_op.type] @@ -412,7 +423,7 @@ class CaffeConverter(object): width = output_shape[0] * ((output_shape[1] + 1)/2) * ((output_shape[2]+1)/2) return self.winograd and self.device == 'gpu' and \ filter_shape[0] == 3 and (filter_shape[0] == filter_shape[1]) and \ - dilations[0] == 1 and (dilations[0] == dilations[1]) and\ + dilations[0] == 1 and (dilations[0] == dilations[1]) and \ (strides[0] == 1) and (strides[0] == strides[1]) and \ (16 * filter_shape[2] < OPENCL_IMAGE_MAX_SIZE) and \ (16 * filter_shape[3] < OPENCL_IMAGE_MAX_SIZE) and \ @@ -662,7 +673,7 @@ class CaffeConverter(object): filter_shape = [kernels[0], kernels[1], input_shape[3], input_shape[3]] output_shape = Shapes.conv_pool_shape(input_shape, filter_shape, - paddings, strides, [1, 1], math.ceil) + paddings, strides, [1, 1], math.ceil) op.output_shape_map[op.layer.top[0]] = output_shape op_def.output.extend([op.name + ':0']) @@ -764,7 +775,7 @@ class CaffeConverter(object): input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] num_outputs = len(op.layer.top) if (input_shape[3] % num_outputs) != 0 or \ - (self.device == 'gpu' and ((input_shape[3] / num_outputs) % 4 != 0)) : + (self.device == 'gpu' and ((input_shape[3] / num_outputs) % 4 != 0)) : raise Exception('Mace do not support slice with input shape ' + str(input_shape) + ' and number of output ' + str(num_outputs)) output_shape = Shapes.slice_shape(input_shape, num_outputs) @@ -789,7 +800,6 @@ class CaffeConverter(object): input_shape = op.parents[0].output_shape_map[op.layer.bottom[0]] output_shape = input_shape shape_param = np.asarray(op.layer.reshape_param.shape.dim)[[0, 3, 2, 1]] - print shape_param for i in range(len(shape_param)): if shape_param[i] != 0: output_shape[i] = shape_param[i] @@ -967,3 +977,4 @@ def convert_to_mace_pb(model_file, weight_file, input_node_str, input_shape_str, print "Memory optimization done." return net_def + diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index 7177a691dabae5bac8fe0fd884d05850d4bac586..1ad426a357f773782b6c87fbc5b935bffe3b45af 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -362,7 +362,8 @@ class TFConverter(object): if len(self.tf_graph.get(final_op.name, [])) == 1 \ and self.tf_graph[final_op.name][0].type in activation_name_map: activation_op = self.tf_graph[final_op.name][0] - op_def.type = "FusedConv2D" + if op_def.type == "Conv2D": + op_def.type = "FusedConv2D" fused_act_arg = op_def.arg.add() fused_act_arg.name = 'activation' fused_act_arg.s = activation_name_map[activation_op.type] diff --git a/tools/mace_tools.py b/tools/mace_tools.py index c9a22f6472e33f8b8245cee9da5796c32d5d5e1d..4f2b209a700439fffd6f466551c0dffceb555805 100644 --- a/tools/mace_tools.py +++ b/tools/mace_tools.py @@ -76,26 +76,28 @@ def generate_random_input(target_soc, model_output_dir, target_soc, model_output_dir, int(generate_data_or_not)) run_command(command) - input_name_list = [] input_file_list = [] - if isinstance(input_names, list): - input_name_list.extend(input_names) - else: - input_name_list.append(input_names) if isinstance(input_files, list): input_file_list.extend(input_files) else: input_file_list.append(input_files) - assert len(input_file_list) == len(input_name_list) - for i in range(len(input_file_list)): - if input_file_list[i] is not None: - dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i]) - if input_file_list[i].startswith("http://") or \ - input_file_list[i].startswith("https://"): - urllib.urlretrieve(input_file_list[i], dst_input_file) - else: - print 'Copy input data:', dst_input_file - shutil.copy(input_file_list[i], dst_input_file) + if len(input_file_list) != 0: + input_name_list = [] + if isinstance(input_names, list): + input_name_list.extend(input_names) + else: + input_name_list.append(input_names) + if len(input_file_list) != len(input_name_list): + raise Exception('If input_files set, the input files should match the input names.') + for i in range(len(input_file_list)): + if input_file_list[i] is not None: + dst_input_file = model_output_dir + '/' + input_file_name(input_name_list[i]) + if input_file_list[i].startswith("http://") or \ + input_file_list[i].startswith("https://"): + urllib.urlretrieve(input_file_list[i], dst_input_file) + else: + print 'Copy input data:', dst_input_file + shutil.copy(input_file_list[i], dst_input_file) def generate_model_code(): command = "bash tools/generate_model_code.sh"