diff --git a/mace/kernels/opencl/cl/eltwise.cl b/mace/kernels/opencl/cl/eltwise.cl index e3cd7ecfaf4d155939da8bfb70d51d98758b1944..52ee65eb0f95fa144e82ea6fce7a4ee928615613 100644 --- a/mace/kernels/opencl/cl/eltwise.cl +++ b/mace/kernels/opencl/cl/eltwise.cl @@ -74,9 +74,9 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS out = diff * diff; #elif ELTWISE_TYPE == 9 #ifdef SWAPPED - out = pow(in0, in1); - #else out = pow(in1, in0); + #else + out = pow(in0, in1); #endif #endif diff --git a/mace/kernels/opencl/cl/reduce_mean.cl b/mace/kernels/opencl/cl/reduce_mean.cl index ceaac871699c5fe3714208140e7533cc6b52fbb2..5a23d1051930ee5a0b5c010938ab46b35eca5766 100644 --- a/mace/kernels/opencl/cl/reduce_mean.cl +++ b/mace/kernels/opencl/cl/reduce_mean.cl @@ -3,46 +3,45 @@ __kernel void reduce_mean(KERNEL_ERROR_PARAMS GLOBAL_WORK_GROUP_SIZE_DIM3 __read_only image2d_t input, - __local float4* group_sum, + __local DATA_TYPE4 *group_sum, __private const int group_size, __private const int partial_len, __private const int remain_index, __private const int batch, __private const int in_height, __private const int in_width, - __private const float in_height_r, - __private const float in_width_r, + __private const float image_size_reciprocal, + __private const float in_width_reciprocal, __private const int channel_blocks, + __private const float channel_blocks_reciprocal, __write_only image2d_t output) { const int i = get_local_id(0); const int j = get_local_id(1); const int k = get_global_id(2); #ifndef NON_UNIFORM_WORK_GROUP - if (i >= local_size_dim0 || j >= local_size_dim1 || k >= global_size_dim2) + if (k >= global_size_dim2) return; - const int dim0_size = local_size_dim0; -#else - const int dim0_size = get_local_size(0); #endif + const int dim0_size = get_local_size(0); DATA_TYPE4 tmp = (DATA_TYPE4){0, 0, 0, 0}; - const int index = j * dim0_size + i; - const int b = k / channel_blocks; - const int ch = k - b * channel_blocks; + const int index = mad24(j, dim0_size, i); + const int b = floor(k * channel_blocks_reciprocal); + const int ch = mad24(b, -channel_blocks, k); DATA_TYPE4 in; const int valid_part_len = select(partial_len, partial_len - 1, remain_index > 0 && index >= remain_index); - const int full_offset = index * partial_len; + const int full_offset = mul24(index, partial_len); const int base_offset = select(full_offset, full_offset - (index - remain_index), valid_part_len < partial_len); #pragma unroll for (int l = 0; l < valid_part_len; ++l) { int offset = base_offset + l; - int h_id = floor(offset * in_width_r); - int w_id = offset - h_id * in_width; + int h_id = floor(offset * in_width_reciprocal); + int w_id = mad24(h_id, -in_width, offset); int pos_x = mad24(ch, in_width, w_id); int pos_y = mad24(b, in_height, h_id); in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y)); @@ -60,7 +59,7 @@ __kernel void reduce_mean(KERNEL_ERROR_PARAMS for (int l = 0; l < group_size; ++l) { out = out + group_sum[l]; } - out = out * in_height_r * in_width_r; + out = out * image_size_reciprocal; WRITE_IMAGET(output, (int2)(ch, b), out); } } diff --git a/mace/kernels/opencl/reduce_mean_opencl.cc b/mace/kernels/opencl/reduce_mean_opencl.cc index a8737c7fa77b9f01a76efe371fb546830e3d8bd9..8d47e4df610cc86c9aa68aa6a2ef9a1b6a0ae7b2 100644 --- a/mace/kernels/opencl/reduce_mean_opencl.cc +++ b/mace/kernels/opencl/reduce_mean_opencl.cc @@ -17,7 +17,7 @@ MaceStatus ReduceMeanFunctor::operator()( Tensor *output, StatsFuture *future) { MACE_CHECK_NOTNULL(input); - MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims."); +// MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims."); MACE_CHECK(input->dim_size() == 4, "reduce mean gpu only support 4-dim input"); MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2, @@ -83,8 +83,9 @@ MaceStatus ReduceMeanFunctor::operator()( const int group_size = lws[0] * lws[1] * lws[2]; const int partial_len = (image_size + group_size - 1) / group_size; const int remain_index = image_size % group_size; - const float in_width_r = 1.f / in_width; - const float in_height_r = 1.f / in_height; + const float in_width_reciprocal = 1.f / in_width; + const float img_size_reciprocal = 1.f / (in_width * in_height); + const float channel_blk_reciprocal = 1.f / channel_blocks; if (!IsVecEqual(input_shape_, input->shape())) { uint32_t idx = 0; @@ -98,7 +99,7 @@ MaceStatus ReduceMeanFunctor::operator()( kernel_.setArg(idx++, gws[2]); } kernel_.setArg(idx++, *(input->opencl_image())); - kernel_.setArg(idx++, (group_size * 4 * sizeof(float)), + kernel_.setArg(idx++, (group_size * 4 * sizeof(T)), nullptr); kernel_.setArg(idx++, static_cast(group_size)); kernel_.setArg(idx++, static_cast(partial_len)); @@ -106,9 +107,10 @@ MaceStatus ReduceMeanFunctor::operator()( kernel_.setArg(idx++, static_cast(batch)); kernel_.setArg(idx++, static_cast(in_height)); kernel_.setArg(idx++, static_cast(in_width)); - kernel_.setArg(idx++, in_height_r); - kernel_.setArg(idx++, in_width_r); + kernel_.setArg(idx++, img_size_reciprocal); + kernel_.setArg(idx++, in_width_reciprocal); kernel_.setArg(idx++, static_cast(channel_blocks)); + kernel_.setArg(idx++, channel_blk_reciprocal); kernel_.setArg(idx++, *(output->opencl_image())); input_shape_ = input->shape(); diff --git a/mace/python/tools/converter_tool/base_converter.py b/mace/python/tools/converter_tool/base_converter.py index ac48726c633718b0de7110d44d951b8795e4ed7e..b5ef56b9a3e0ac4eb2e6049fef0f4ed566bb3681 100644 --- a/mace/python/tools/converter_tool/base_converter.py +++ b/mace/python/tools/converter_tool/base_converter.py @@ -149,6 +149,7 @@ class MaceKeyword(object): mace_device = 'device' mace_value_str = 'value' mace_wino_block_size = 'wino_block_size' + mace_output_shape_str = 'output_shape' mace_begin_mask_str = 'begin_mask' mace_end_mask_str = 'end_mask' mace_ellipsis_mask_str = 'ellipsis_mask' diff --git a/mace/python/tools/converter_tool/tensorflow_converter.py b/mace/python/tools/converter_tool/tensorflow_converter.py index 65211cfdfda8f1f284c01a1b30019d54ccc4d11c..63d046bd6ec60c4c71f0c2535569a8282a500915 100644 --- a/mace/python/tools/converter_tool/tensorflow_converter.py +++ b/mace/python/tools/converter_tool/tensorflow_converter.py @@ -57,6 +57,7 @@ TFSupportedOps = [ 'Max', 'Neg', 'Abs', + 'Pow', 'RealDiv', 'Square', 'SquaredDifference', @@ -119,6 +120,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Max.name: EltwiseType.MAX, TFOpType.Neg.name: EltwiseType.NEG, TFOpType.Abs.name: EltwiseType.ABS, + TFOpType.Pow.name: EltwiseType.POW, TFOpType.RealDiv.name: EltwiseType.DIV, TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF, TFOpType.Square.name: EltwiseType.POW, @@ -145,6 +147,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.Max.name: self.convert_elementwise, TFOpType.Neg.name: self.convert_elementwise, TFOpType.Abs.name: self.convert_elementwise, + TFOpType.Pow.name: self.convert_elementwise, TFOpType.RealDiv.name: self.convert_elementwise, TFOpType.SquaredDifference.name: self.convert_elementwise, TFOpType.Square.name: self.convert_elementwise, @@ -327,8 +330,17 @@ class TensorflowConverter(base_converter.ConverterInterface): dilation_val = tf_op.get_attr(tf_dilations_str)[1:3] except ValueError: dilation_val = [1, 1] - dilation_arg.ints.extend(dilation_val) + else: + del op.input[1:] + output_shape_arg = op.arg.add() + output_shape_arg.name = MaceKeyword.mace_output_shape_str + output_shape_value = tf_op.inputs[0].eval().astype(np.int32).flat + output_shape_arg.ints.extend(output_shape_value) + self._skip_tensor.add(tf_op.inputs[0].name) + del op.input[0] + if len(tf_op.inputs) >= 3: + op.input.extend([tf_op.inputs[2].name, tf_op.inputs[1].name]) def convert_elementwise(self, tf_op): op = self.convert_general_op(tf_op) @@ -348,7 +360,6 @@ class TensorflowConverter(base_converter.ConverterInterface): value_arg.f = -0.5 if type_arg.i != EltwiseType.NEG.value \ - and type_arg.i != EltwiseType.POW.value \ and type_arg.i != EltwiseType.ABS.value: if len(tf_op.inputs[0].shape) == 0: value_arg = op.arg.add() @@ -578,18 +589,30 @@ class TensorflowConverter(base_converter.ConverterInterface): op = self.convert_general_op(tf_op) del op.input[1:] - reduce_dims = tf_op.inputs[1].eval() op.type = MaceOp.ReduceMean.name axis_arg = op.arg.add() axis_arg.name = MaceKeyword.mace_axis_str + if len(tf_op.inputs) > 1: + reduce_dims = tf_op.inputs[1].eval() + else: + try: + reduce_dims = tf_op.get_attr('axis') + except ValueError: + try: + reduce_dims = tf_op.get_attr('reduction_indices') + except ValueError: + reduce_dims = [] axis_arg.ints.extend(reduce_dims) + keep_dims_arg = op.arg.add() + keep_dims_arg.name = MaceKeyword.mace_keepdims_str try: - keep_dims = tf_op.get_attr(MaceKeyword.mace_keepdims_str) - keep_dims_arg = op.arg.add() - keep_dims_arg.name = MaceKeyword.mace_keepdims_str - keep_dims_arg.i = keep_dims + keep_dims = tf_op.get_attr('keepdims') except ValueError: - pass + try: + keep_dims = tf_op.get_attr('keep_dims') + except ValueError: + keep_dims = 0 + keep_dims_arg.i = keep_dims self._skip_tensor.add(tf_op.inputs[1].name) diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index 5345129eaf7a3222d37f11dc7937c7373179bd06..4a9e3fbea4248d6febf3b009ef3779bf3a85cb25 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -919,7 +919,10 @@ class Transformer(base_converter.ConverterInterface): filter = self._consts[op.input[1]] filter_data = np.array(filter.float_data).reshape( filter.dims) - filter_data = filter_data.transpose(3, 2, 0, 1) + if op.type == MaceOp.Deconv2D.name: + filter_data = filter_data.transpose(2, 3, 0, 1) + else: + filter_data = filter_data.transpose(3, 2, 0, 1) filter.float_data[:] = filter_data.flat filter.dims[:] = filter_data.shape if op.type == MaceOp.FullyConnected.name: @@ -993,6 +996,13 @@ class Transformer(base_converter.ConverterInterface): self.buffer_to_image(op, 2, OpenCLBufferType.ARGUMENT) elif op.type == MaceOp.BiasAdd.name: self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT) + elif op.type == MaceOp.Eltwise.name and len(op.input) == 2: + if op.input[0] in self._consts \ + and len(self._consts[op.input[0]].dims) == 1: + self.buffer_to_image(op, 0, OpenCLBufferType.ARGUMENT) + if op.input[1] in self._consts \ + and len(self._consts[op.input[1]].dims) == 1: + self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT) elif op.type == MaceOp.FoldedBatchNorm.name: self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT) self.buffer_to_image(op, 2, OpenCLBufferType.ARGUMENT)