Commit 649f1819, authored by 刘琦

Merge branch 'fast-style-transfer' into 'master'

Support the fast style transfer model and fix reduce-mean bugs so that ResNet-V2-50 runs on the GPU

See merge request !592
@@ -74,9 +74,9 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS
   out = diff * diff;
 #elif ELTWISE_TYPE == 9
 #ifdef SWAPPED
-  out = pow(in0, in1);
-#else
   out = pow(in1, in0);
+#else
+  out = pow(in0, in1);
 #endif
 #endif
...
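POW is the only eltwise type in this hunk whose operands do not commute, so the branch taken when the converter marks the inputs as SWAPPED must restore the original base/exponent order; the previous code applied them the wrong way round. A minimal Python sketch of the corrected behaviour (the eltwise_pow helper is illustrative, not part of MACE):

# Illustrative sketch only: mirrors the kernel's SWAPPED branch for POW.
def eltwise_pow(in0, in1, swapped=False):
    # When the converter swapped the operands, in1 carries the original
    # first input, so the power must be taken as in1 ** in0.
    return in1 ** in0 if swapped else in0 ** in1

# pow is not commutative, so the operand order matters:
assert eltwise_pow(2.0, 3.0) == 8.0                 # 2 ** 3
assert eltwise_pow(2.0, 3.0, swapped=True) == 9.0   # 3 ** 2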
@@ -3,46 +3,45 @@
 __kernel void reduce_mean(KERNEL_ERROR_PARAMS
                           GLOBAL_WORK_GROUP_SIZE_DIM3
                           __read_only image2d_t input,
-                          __local float4* group_sum,
+                          __local DATA_TYPE4 *group_sum,
                           __private const int group_size,
                           __private const int partial_len,
                           __private const int remain_index,
                           __private const int batch,
                           __private const int in_height,
                           __private const int in_width,
-                          __private const float in_height_r,
-                          __private const float in_width_r,
+                          __private const float image_size_reciprocal,
+                          __private const float in_width_reciprocal,
                           __private const int channel_blocks,
+                          __private const float channel_blocks_reciprocal,
                           __write_only image2d_t output) {
   const int i = get_local_id(0);
   const int j = get_local_id(1);
   const int k = get_global_id(2);
 #ifndef NON_UNIFORM_WORK_GROUP
-  if (i >= local_size_dim0 || j >= local_size_dim1 || k >= global_size_dim2)
+  if (k >= global_size_dim2)
     return;
-  const int dim0_size = local_size_dim0;
-#else
-  const int dim0_size = get_local_size(0);
 #endif
+  const int dim0_size = get_local_size(0);
   DATA_TYPE4 tmp = (DATA_TYPE4){0, 0, 0, 0};
-  const int index = j * dim0_size + i;
-  const int b = k / channel_blocks;
-  const int ch = k - b * channel_blocks;
+  const int index = mad24(j, dim0_size, i);
+  const int b = floor(k * channel_blocks_reciprocal);
+  const int ch = mad24(b, -channel_blocks, k);
   DATA_TYPE4 in;
   const int valid_part_len = select(partial_len,
                                     partial_len - 1,
                                     remain_index > 0 && index >= remain_index);
-  const int full_offset = index * partial_len;
+  const int full_offset = mul24(index, partial_len);
   const int base_offset = select(full_offset,
                                  full_offset - (index - remain_index),
                                  valid_part_len < partial_len);
 #pragma unroll
   for (int l = 0; l < valid_part_len; ++l) {
     int offset = base_offset + l;
-    int h_id = floor(offset * in_width_r);
-    int w_id = offset - h_id * in_width;
+    int h_id = floor(offset * in_width_reciprocal);
+    int w_id = mad24(h_id, -in_width, offset);
     int pos_x = mad24(ch, in_width, w_id);
     int pos_y = mad24(b, in_height, h_id);
     in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
@@ -60,7 +59,7 @@ __kernel void reduce_mean(KERNEL_ERROR_PARAMS
     for (int l = 0; l < group_size; ++l) {
       out = out + group_sum[l];
     }
-    out = out * in_height_r * in_width_r;
+    out = out * image_size_reciprocal;
     WRITE_IMAGET(output, (int2)(ch, b), out);
   }
 }
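The kernel reduces over the H x W plane of one (batch, channel block) pair: the plane is flattened to image_size = in_height * in_width pixels, each work-item accumulates valid_part_len consecutive pixels starting at base_offset, and each offset is mapped back to (h, w) via the in_width reciprocal. A small Python sketch of that partitioning (names follow the kernel arguments; this is a reference computation, not MACE code):

def partition(index, image_size, group_size):
    # Mirrors partial_len / remain_index / valid_part_len / base_offset
    # from the kernel and its host-side setup.
    partial_len = (image_size + group_size - 1) // group_size  # ceil division
    remain_index = image_size % group_size
    if remain_index > 0 and index >= remain_index:
        # Trailing work-items take one pixel fewer and shift back accordingly.
        return index * partial_len - (index - remain_index), partial_len - 1
    return index * partial_len, partial_len

# Every pixel of a 7 x 5 image is covered exactly once by 8 work-items.
covered = []
for index in range(8):
    base_offset, valid_part_len = partition(index, 7 * 5, 8)
    covered.extend(range(base_offset, base_offset + valid_part_len))
assert covered == list(range(35))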
@@ -17,7 +17,7 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
     Tensor *output,
     StatsFuture *future) {
   MACE_CHECK_NOTNULL(input);
-  MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims.");
+  // MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims.");
   MACE_CHECK(input->dim_size() == 4,
              "reduce mean gpu only support 4-dim input");
   MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2,
@@ -83,8 +83,9 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
   const int group_size = lws[0] * lws[1] * lws[2];
   const int partial_len = (image_size + group_size - 1) / group_size;
   const int remain_index = image_size % group_size;
-  const float in_width_r = 1.f / in_width;
-  const float in_height_r = 1.f / in_height;
+  const float in_width_reciprocal = 1.f / in_width;
+  const float img_size_reciprocal = 1.f / (in_width * in_height);
+  const float channel_blk_reciprocal = 1.f / channel_blocks;
   if (!IsVecEqual(input_shape_, input->shape())) {
     uint32_t idx = 0;
@@ -98,7 +99,7 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
       kernel_.setArg(idx++, gws[2]);
     }
     kernel_.setArg(idx++, *(input->opencl_image()));
-    kernel_.setArg(idx++, (group_size * 4 * sizeof(float)),
+    kernel_.setArg(idx++, (group_size * 4 * sizeof(T)),
                    nullptr);
     kernel_.setArg(idx++, static_cast<int32_t>(group_size));
     kernel_.setArg(idx++, static_cast<int32_t>(partial_len));
@@ -106,9 +107,10 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
     kernel_.setArg(idx++, static_cast<int32_t>(batch));
     kernel_.setArg(idx++, static_cast<int32_t>(in_height));
     kernel_.setArg(idx++, static_cast<int32_t>(in_width));
-    kernel_.setArg(idx++, in_height_r);
-    kernel_.setArg(idx++, in_width_r);
+    kernel_.setArg(idx++, img_size_reciprocal);
+    kernel_.setArg(idx++, in_width_reciprocal);
     kernel_.setArg(idx++, static_cast<int32_t>(channel_blocks));
+    kernel_.setArg(idx++, channel_blk_reciprocal);
     kernel_.setArg(idx++, *(output->opencl_image()));
     input_shape_ = input->shape();
...
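On the host side the reciprocals are now computed once and passed as kernel arguments, so the kernel multiplies instead of dividing, and the __local scratch buffer is sized with sizeof(T) so it matches DATA_TYPE4 when the kernel runs in half precision. A worked example of the launch arithmetic for a hypothetical 64 x 64 x 128 feature map with an assumed local work-group size of 4 x 16 x 8:

in_height, in_width, channels = 64, 64, 128   # hypothetical NHWC feature map
channel_blocks = (channels + 3) // 4          # assuming 4 channels packed per image pixel
lws = (4, 16, 8)                              # assumed local work-group size

group_size = lws[0] * lws[1] * lws[2]                        # 512 work-items per group
image_size = in_height * in_width                            # 4096 pixels to reduce
partial_len = (image_size + group_size - 1) // group_size    # 8 pixels per work-item
remain_index = image_size % group_size                       # 0, divides evenly

img_size_reciprocal = 1.0 / image_size
channel_blk_reciprocal = 1.0 / channel_blocks
print(group_size, partial_len, channel_blocks)               # 512 8 32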
@@ -149,6 +149,7 @@ class MaceKeyword(object):
     mace_device = 'device'
     mace_value_str = 'value'
     mace_wino_block_size = 'wino_block_size'
+    mace_output_shape_str = 'output_shape'
     mace_begin_mask_str = 'begin_mask'
     mace_end_mask_str = 'end_mask'
     mace_ellipsis_mask_str = 'ellipsis_mask'
...
@@ -57,6 +57,7 @@ TFSupportedOps = [
     'Max',
     'Neg',
     'Abs',
+    'Pow',
     'RealDiv',
     'Square',
     'SquaredDifference',
@@ -119,6 +120,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Max.name: EltwiseType.MAX,
             TFOpType.Neg.name: EltwiseType.NEG,
             TFOpType.Abs.name: EltwiseType.ABS,
+            TFOpType.Pow.name: EltwiseType.POW,
             TFOpType.RealDiv.name: EltwiseType.DIV,
             TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
             TFOpType.Square.name: EltwiseType.POW,
@@ -145,6 +147,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Max.name: self.convert_elementwise,
             TFOpType.Neg.name: self.convert_elementwise,
             TFOpType.Abs.name: self.convert_elementwise,
+            TFOpType.Pow.name: self.convert_elementwise,
             TFOpType.RealDiv.name: self.convert_elementwise,
             TFOpType.SquaredDifference.name: self.convert_elementwise,
             TFOpType.Square.name: self.convert_elementwise,
@@ -327,8 +330,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
                 dilation_val = tf_op.get_attr(tf_dilations_str)[1:3]
             except ValueError:
                 dilation_val = [1, 1]
             dilation_arg.ints.extend(dilation_val)
+        else:
+            del op.input[1:]
+            output_shape_arg = op.arg.add()
+            output_shape_arg.name = MaceKeyword.mace_output_shape_str
+            output_shape_value = tf_op.inputs[0].eval().astype(np.int32).flat
+            output_shape_arg.ints.extend(output_shape_value)
+            self._skip_tensor.add(tf_op.inputs[0].name)
+            del op.input[0]
+            if len(tf_op.inputs) >= 3:
+                op.input.extend([tf_op.inputs[2].name, tf_op.inputs[1].name])
     def convert_elementwise(self, tf_op):
         op = self.convert_general_op(tf_op)
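The new else branch handles Conv2DBackpropInput, the op behind tf.nn.conv2d_transpose that fast style transfer models use for upsampling. Its TensorFlow inputs arrive as (output_shape, filter, input data), so the constant output shape is folded into an output_shape arg and the remaining inputs are reordered to the (input, filter) order the rest of the converter expects. A TF 1.x style sketch of how such a node is typically built (shapes and names are illustrative, not taken from the model):

import tensorflow as tf  # TF 1.x graph-mode API, illustrative only

# An upsampling deconvolution as found in fast style transfer networks.
input_data = tf.placeholder(tf.float32, [1, 64, 64, 32], name='input')
# conv2d_transpose filters are stored as HWOI: (height, width, out, in).
filters = tf.Variable(tf.random_normal([3, 3, 16, 32]), name='filters')
# This lowers to a Conv2DBackpropInput node whose inputs are
# (output_shape, filters, input_data), i.e. tf_op.inputs[0..2] above.
deconv = tf.nn.conv2d_transpose(input_data, filters,
                                output_shape=[1, 128, 128, 16],
                                strides=[1, 2, 2, 1], padding='SAME')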
@@ -348,7 +360,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
             value_arg.f = -0.5
         if type_arg.i != EltwiseType.NEG.value \
-                and type_arg.i != EltwiseType.POW.value \
                 and type_arg.i != EltwiseType.ABS.value:
             if len(tf_op.inputs[0].shape) == 0:
                 value_arg = op.arg.add()
@@ -578,18 +589,30 @@ class TensorflowConverter(base_converter.ConverterInterface):
         op = self.convert_general_op(tf_op)
         del op.input[1:]
-        reduce_dims = tf_op.inputs[1].eval()
         op.type = MaceOp.ReduceMean.name
         axis_arg = op.arg.add()
         axis_arg.name = MaceKeyword.mace_axis_str
+        if len(tf_op.inputs) > 1:
+            reduce_dims = tf_op.inputs[1].eval()
+        else:
+            try:
+                reduce_dims = tf_op.get_attr('axis')
+            except ValueError:
+                try:
+                    reduce_dims = tf_op.get_attr('reduction_indices')
+                except ValueError:
+                    reduce_dims = []
         axis_arg.ints.extend(reduce_dims)
+        keep_dims_arg = op.arg.add()
+        keep_dims_arg.name = MaceKeyword.mace_keepdims_str
         try:
-            keep_dims = tf_op.get_attr(MaceKeyword.mace_keepdims_str)
-            keep_dims_arg = op.arg.add()
-            keep_dims_arg.name = MaceKeyword.mace_keepdims_str
-            keep_dims_arg.i = keep_dims
+            keep_dims = tf_op.get_attr('keepdims')
         except ValueError:
-            pass
+            try:
+                keep_dims = tf_op.get_attr('keep_dims')
+            except ValueError:
+                keep_dims = 0
+        keep_dims_arg.i = keep_dims
         self._skip_tensor.add(tf_op.inputs[1].name)
...
@@ -919,7 +919,10 @@ class Transformer(base_converter.ConverterInterface):
                 filter = self._consts[op.input[1]]
                 filter_data = np.array(filter.float_data).reshape(
                     filter.dims)
-                filter_data = filter_data.transpose(3, 2, 0, 1)
+                if op.type == MaceOp.Deconv2D.name:
+                    filter_data = filter_data.transpose(2, 3, 0, 1)
+                else:
+                    filter_data = filter_data.transpose(3, 2, 0, 1)
                 filter.float_data[:] = filter_data.flat
                 filter.dims[:] = filter_data.shape
             if op.type == MaceOp.FullyConnected.name:
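The transpose order depends on how TensorFlow lays the filter out: Conv2D filters are HWIO (height, width, input channels, output channels), while Conv2DBackpropInput filters are HWOI, so reaching MACE's OIHW layout needs (3, 2, 0, 1) for the former and (2, 3, 0, 1) for the latter. A quick NumPy shape check (filter sizes are made up):

import numpy as np

conv_filter = np.zeros((3, 3, 32, 16))     # HWIO: in=32, out=16
assert conv_filter.transpose(3, 2, 0, 1).shape == (16, 32, 3, 3)    # OIHW

deconv_filter = np.zeros((3, 3, 16, 32))   # HWOI: out=16, in=32
assert deconv_filter.transpose(2, 3, 0, 1).shape == (16, 32, 3, 3)  # OIHW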
@@ -993,6 +996,13 @@ class Transformer(base_converter.ConverterInterface):
                 self.buffer_to_image(op, 2, OpenCLBufferType.ARGUMENT)
             elif op.type == MaceOp.BiasAdd.name:
                 self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
+            elif op.type == MaceOp.Eltwise.name and len(op.input) == 2:
+                if op.input[0] in self._consts \
+                        and len(self._consts[op.input[0]].dims) == 1:
+                    self.buffer_to_image(op, 0, OpenCLBufferType.ARGUMENT)
+                if op.input[1] in self._consts \
+                        and len(self._consts[op.input[1]].dims) == 1:
+                    self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
             elif op.type == MaceOp.FoldedBatchNorm.name:
                 self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
                 self.buffer_to_image(op, 2, OpenCLBufferType.ARGUMENT)
...