提交 649f1819 编写于 作者: 刘琦

Merge branch 'fast-style-transfer' into 'master'

support fast style transfer model and fix reduce mean bugs to support resnet-V2-50 gpu

See merge request !592
......@@ -74,9 +74,9 @@ __kernel void eltwise(KERNEL_ERROR_PARAMS
out = diff * diff;
#elif ELTWISE_TYPE == 9
#ifdef SWAPPED
out = pow(in0, in1);
#else
out = pow(in1, in0);
#else
out = pow(in0, in1);
#endif
#endif
......
......@@ -3,46 +3,45 @@
__kernel void reduce_mean(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__local float4* group_sum,
__local DATA_TYPE4 *group_sum,
__private const int group_size,
__private const int partial_len,
__private const int remain_index,
__private const int batch,
__private const int in_height,
__private const int in_width,
__private const float in_height_r,
__private const float in_width_r,
__private const float image_size_reciprocal,
__private const float in_width_reciprocal,
__private const int channel_blocks,
__private const float channel_blocks_reciprocal,
__write_only image2d_t output) {
const int i = get_local_id(0);
const int j = get_local_id(1);
const int k = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (i >= local_size_dim0 || j >= local_size_dim1 || k >= global_size_dim2)
if (k >= global_size_dim2)
return;
const int dim0_size = local_size_dim0;
#else
const int dim0_size = get_local_size(0);
#endif
const int dim0_size = get_local_size(0);
DATA_TYPE4 tmp = (DATA_TYPE4){0, 0, 0, 0};
const int index = j * dim0_size + i;
const int b = k / channel_blocks;
const int ch = k - b * channel_blocks;
const int index = mad24(j, dim0_size, i);
const int b = floor(k * channel_blocks_reciprocal);
const int ch = mad24(b, -channel_blocks, k);
DATA_TYPE4 in;
const int valid_part_len = select(partial_len,
partial_len - 1,
remain_index > 0 && index >= remain_index);
const int full_offset = index * partial_len;
const int full_offset = mul24(index, partial_len);
const int base_offset = select(full_offset,
full_offset - (index - remain_index),
valid_part_len < partial_len);
#pragma unroll
for (int l = 0; l < valid_part_len; ++l) {
int offset = base_offset + l;
int h_id = floor(offset * in_width_r);
int w_id = offset - h_id * in_width;
int h_id = floor(offset * in_width_reciprocal);
int w_id = mad24(h_id, -in_width, offset);
int pos_x = mad24(ch, in_width, w_id);
int pos_y = mad24(b, in_height, h_id);
in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
......@@ -60,7 +59,7 @@ __kernel void reduce_mean(KERNEL_ERROR_PARAMS
for (int l = 0; l < group_size; ++l) {
out = out + group_sum[l];
}
out = out * in_height_r * in_width_r;
out = out * image_size_reciprocal;
WRITE_IMAGET(output, (int2)(ch, b), out);
}
}
......@@ -17,7 +17,7 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
Tensor *output,
StatsFuture *future) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims.");
// MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims.");
MACE_CHECK(input->dim_size() == 4,
"reduce mean gpu only support 4-dim input");
MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2,
......@@ -83,8 +83,9 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
const int group_size = lws[0] * lws[1] * lws[2];
const int partial_len = (image_size + group_size - 1) / group_size;
const int remain_index = image_size % group_size;
const float in_width_r = 1.f / in_width;
const float in_height_r = 1.f / in_height;
const float in_width_reciprocal = 1.f / in_width;
const float img_size_reciprocal = 1.f / (in_width * in_height);
const float channel_blk_reciprocal = 1.f / channel_blocks;
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
......@@ -98,7 +99,7 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
kernel_.setArg(idx++, gws[2]);
}
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, (group_size * 4 * sizeof(float)),
kernel_.setArg(idx++, (group_size * 4 * sizeof(T)),
nullptr);
kernel_.setArg(idx++, static_cast<int32_t>(group_size));
kernel_.setArg(idx++, static_cast<int32_t>(partial_len));
......@@ -106,9 +107,10 @@ MaceStatus ReduceMeanFunctor<DeviceType::GPU, T>::operator()(
kernel_.setArg(idx++, static_cast<int32_t>(batch));
kernel_.setArg(idx++, static_cast<int32_t>(in_height));
kernel_.setArg(idx++, static_cast<int32_t>(in_width));
kernel_.setArg(idx++, in_height_r);
kernel_.setArg(idx++, in_width_r);
kernel_.setArg(idx++, img_size_reciprocal);
kernel_.setArg(idx++, in_width_reciprocal);
kernel_.setArg(idx++, static_cast<int32_t>(channel_blocks));
kernel_.setArg(idx++, channel_blk_reciprocal);
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
......
......@@ -149,6 +149,7 @@ class MaceKeyword(object):
mace_device = 'device'
mace_value_str = 'value'
mace_wino_block_size = 'wino_block_size'
mace_output_shape_str = 'output_shape'
mace_begin_mask_str = 'begin_mask'
mace_end_mask_str = 'end_mask'
mace_ellipsis_mask_str = 'ellipsis_mask'
......
......@@ -57,6 +57,7 @@ TFSupportedOps = [
'Max',
'Neg',
'Abs',
'Pow',
'RealDiv',
'Square',
'SquaredDifference',
......@@ -119,6 +120,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Max.name: EltwiseType.MAX,
TFOpType.Neg.name: EltwiseType.NEG,
TFOpType.Abs.name: EltwiseType.ABS,
TFOpType.Pow.name: EltwiseType.POW,
TFOpType.RealDiv.name: EltwiseType.DIV,
TFOpType.SquaredDifference.name: EltwiseType.SQR_DIFF,
TFOpType.Square.name: EltwiseType.POW,
......@@ -145,6 +147,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Max.name: self.convert_elementwise,
TFOpType.Neg.name: self.convert_elementwise,
TFOpType.Abs.name: self.convert_elementwise,
TFOpType.Pow.name: self.convert_elementwise,
TFOpType.RealDiv.name: self.convert_elementwise,
TFOpType.SquaredDifference.name: self.convert_elementwise,
TFOpType.Square.name: self.convert_elementwise,
......@@ -327,8 +330,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
dilation_val = tf_op.get_attr(tf_dilations_str)[1:3]
except ValueError:
dilation_val = [1, 1]
dilation_arg.ints.extend(dilation_val)
else:
del op.input[1:]
output_shape_arg = op.arg.add()
output_shape_arg.name = MaceKeyword.mace_output_shape_str
output_shape_value = tf_op.inputs[0].eval().astype(np.int32).flat
output_shape_arg.ints.extend(output_shape_value)
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
if len(tf_op.inputs) >= 3:
op.input.extend([tf_op.inputs[2].name, tf_op.inputs[1].name])
def convert_elementwise(self, tf_op):
op = self.convert_general_op(tf_op)
......@@ -348,7 +360,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
value_arg.f = -0.5
if type_arg.i != EltwiseType.NEG.value \
and type_arg.i != EltwiseType.POW.value \
and type_arg.i != EltwiseType.ABS.value:
if len(tf_op.inputs[0].shape) == 0:
value_arg = op.arg.add()
......@@ -578,18 +589,30 @@ class TensorflowConverter(base_converter.ConverterInterface):
op = self.convert_general_op(tf_op)
del op.input[1:]
reduce_dims = tf_op.inputs[1].eval()
op.type = MaceOp.ReduceMean.name
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
if len(tf_op.inputs) > 1:
reduce_dims = tf_op.inputs[1].eval()
else:
try:
reduce_dims = tf_op.get_attr('axis')
except ValueError:
try:
reduce_dims = tf_op.get_attr('reduction_indices')
except ValueError:
reduce_dims = []
axis_arg.ints.extend(reduce_dims)
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
try:
keep_dims = tf_op.get_attr(MaceKeyword.mace_keepdims_str)
keep_dims_arg = op.arg.add()
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = keep_dims
keep_dims = tf_op.get_attr('keepdims')
except ValueError:
pass
try:
keep_dims = tf_op.get_attr('keep_dims')
except ValueError:
keep_dims = 0
keep_dims_arg.i = keep_dims
self._skip_tensor.add(tf_op.inputs[1].name)
......
......@@ -919,7 +919,10 @@ class Transformer(base_converter.ConverterInterface):
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 2, 0, 1)
if op.type == MaceOp.Deconv2D.name:
filter_data = filter_data.transpose(2, 3, 0, 1)
else:
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
......@@ -993,6 +996,13 @@ class Transformer(base_converter.ConverterInterface):
self.buffer_to_image(op, 2, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.BiasAdd.name:
self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.Eltwise.name and len(op.input) == 2:
if op.input[0] in self._consts \
and len(self._consts[op.input[0]].dims) == 1:
self.buffer_to_image(op, 0, OpenCLBufferType.ARGUMENT)
if op.input[1] in self._consts \
and len(self._consts[op.input[1]].dims) == 1:
self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
elif op.type == MaceOp.FoldedBatchNorm.name:
self.buffer_to_image(op, 1, OpenCLBufferType.ARGUMENT)
self.buffer_to_image(op, 2, OpenCLBufferType.ARGUMENT)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册