Commit ffdae79f authored by Y yejianwu

Merge branch 'master' of v9.git.n.xiaomi.com:deep-learning/mace into bm_to_image

......@@ -141,6 +141,7 @@ const std::map<std::string, std::string>
OpenCLRuntime::program_map_ = {
{"addn", "addn.cl"},
{"batch_norm", "batch_norm.cl"},
{"conv_2d", "conv_2d.cl"},
{"conv_2d_1x1", "conv_2d_1x1.cl"},
{"conv_2d_3x3", "conv_2d_3x3.cl"},
{"depthwise_conv_3x3", "depthwise_conv_3x3.cl"},
......
......@@ -58,19 +58,27 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input_tensor,
Tensor *output_tensor) {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape(4);
filter_shape[0] = input_tensor->shape()[1];
filter_shape[1] = input_tensor->shape()[1];
filter_shape[2] = kernels_[0];
filter_shape[3] = kernels_[1];
kernels::CalcPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), this->dilations_,
strides_, this->padding_, output_shape.data(),
paddings.data());
output_tensor->Resize(output_shape);
const float *input = input_tensor->data<float>();
float *output = output_tensor->mutable_data<float>();
const index_t *input_shape = input_tensor->shape().data();
const index_t *output_shape = output_tensor->shape().data();
int paddings[2];
std::vector<index_t> filter_shape = {input_shape[1], input_shape[0],
kernels_[0], kernels_[1]};
kernels::CalPaddingSize(input_shape, filter_shape.data(), this->dilations_,
strides_, this->padding_, paddings);
#ifdef __COPY_MAKE_PADDING
Tensor padded_input;
ConstructInputWithPadding(input_tensor, paddings, &padded_input);
ConstructInputWithPadding(input_tensor, paddings.data(), &padded_input);
input = padded_input.data<float>();
input_shape = padded_input.shape().data();
#endif
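For orientation, the padding helpers referenced above (CalcPaddingAndOutputSize / CalPaddingSize) appear to follow the standard convolution output-size arithmetic; a minimal host-side sketch of the assumed relation (illustrative only, not the MACE implementation):
// Assumed standard conv/pool arithmetic (sketch); pad_total is the sum of
// the padding on both sides of a dimension.
inline int OutputDimSketch(int in, int pad_total, int kernel, int dilation,
                           int stride) {
  return (in + pad_total - (kernel - 1) * dilation - 1) / stride + 1;
}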
......@@ -80,17 +88,17 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
// kernel_size: 2x2, strides: 2x2
if (pooling_type_ == MAX) { // MAX_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING
PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
PoolingMaxNeonK2x2S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape,
paddings);
PoolingMaxNeonK2x2S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
} else { // AVG_POOL_2x2s2x2
#ifdef __COPY_MAKE_PADDING
PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape);
PoolingAvgNeonK2x2S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingAvgNeonK2x2S2x2(input, input_shape, output, output_shape,
paddings);
PoolingAvgNeonK2x2S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
}
} else if (kernels_[0] == 3 && kernels_[1] == 3 && strides_[0] == 2 &&
......@@ -98,17 +106,17 @@ void PoolingFunctor<DeviceType::NEON, float>::operator()(
// kernel_size: 3x3, strides: 2x2
if (pooling_type_ == MAX) { // MAX_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING
PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
PoolingMaxNeonK3x3S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape,
paddings);
PoolingMaxNeonK3x3S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
} else { // AVG_POOL_3x3s2x2
#ifdef __COPY_MAKE_PADDING
PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape);
PoolingAvgNeonK3x3S2x2Padded(input, input_shape, output, output_shape.data());
#else
PoolingAvgNeonK3x3S2x2(input, input_shape, output, output_shape,
paddings);
PoolingAvgNeonK3x3S2x2(input, input_shape, output, output_shape.data(),
paddings.data());
#endif
}
} else { // not implemented yet
......
......@@ -18,4 +18,7 @@
#define READ_IMAGET CMD_TYPE(read_image, CMD_DATA_TYPE)
#define WRITE_IMAGET CMD_TYPE(write_image, CMD_DATA_TYPE)
__constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#endif // MACE_KERNELS_OPENCL_CL_COMMON_H_
#include <common.h>
__kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin * kw * kh, cout/4 */
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
__private const int out_height,
__private const int out_width,
__private const int filter_height,
__private const int filter_width,
__private const int padding_top,
__private const int padding_left) {
const int out_ch_blk = get_global_id(0);
const int out_w_blk = get_global_id(1);
const int out_w_blks = get_global_size(1);
const int out_hb = get_global_id(2);
const int rounded_in_ch = in_ch_blks * 4;
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#ifdef BIAS
DATA_TYPE4 out0 =
READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0));
DATA_TYPE4 out1 = out0;
DATA_TYPE4 out2 = out0;
DATA_TYPE4 out3 = out0;
#else
DATA_TYPE4 out0 = 0;
DATA_TYPE4 out1 = 0;
DATA_TYPE4 out2 = 0;
DATA_TYPE4 out3 = 0;
#endif
#if STRIDE == 1
int in_width0 = out_w_blk - padding_left;
int in_width1 = in_width0 + out_w_blks;
int in_width2 = in_width1 + out_w_blks;
int in_width3 = in_width2 + out_w_blks;
const int height_idx = (out_hb % out_height) - padding_top;
#else
int in_width0 = out_w_blk * 2 - padding_left;
int in_width1 = (out_w_blk + out_w_blks) * 2 - padding_left;
int in_width2 = (out_w_blk + 2 * out_w_blks) * 2 - padding_left;
int in_width3 = (out_w_blk + 3 * out_w_blks) * 2 - padding_left;
const int height_idx = (out_hb % out_height) * 2 - padding_top;
#endif
const int batch_idx = (out_hb / out_height) * in_height;
DATA_TYPE4 in0, in1, in2, in3;
DATA_TYPE4 weights0, weights1, weights2, weights3;
int in_idx, in_width_idx;
// Unrolling this loop hurts performance
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
for (short hb_idx = 0; hb_idx < filter_height; ++hb_idx) {
int in_hb_value = height_idx + hb_idx;
in_hb_value = select(in_hb_value + batch_idx,
-1,
(in_hb_value < 0 || in_hb_value >= in_height));
for (short width_idx = 0; width_idx < filter_width; ++width_idx) {
in_idx = in_ch_blk * in_width;
int in_width_value;
#define READ_INPUT(i) \
in_width_value = in_width##i + width_idx; \
in_width_value = select(in_idx + in_width_value, \
-1, \
(in_width_value < 0 || in_width_value >= in_width)); \
in##i = READ_IMAGET(input, sampler, (int2)(in_width_value, in_hb_value));
READ_INPUT(0);
READ_INPUT(1);
READ_INPUT(2);
READ_INPUT(3);
#undef READ_INPUT
int filter_idx = (in_ch_blk << 2) + (hb_idx * filter_width + width_idx) * rounded_in_ch;
weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk));
weights1 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk));
weights2 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk));
weights3 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk));
// Would prefetching into L2 improve performance? How to prefetch image data?
// Interleaving load and mul does not improve performance as expected
out0 += in0.x * weights0;
out0 += in0.y * weights1;
out0 += in0.z * weights2;
out0 += in0.w * weights3;
out1 += in1.x * weights0;
out1 += in1.y * weights1;
out1 += in1.z * weights2;
out1 += in1.w * weights3;
out2 += in2.x * weights0;
out2 += in2.y * weights1;
out2 += in2.z * weights2;
out2 += in2.w * weights3;
out3 += in3.x * weights0;
out3 += in3.y * weights1;
out3 += in3.z * weights2;
out3 += in3.w * weights3;
}
}
}
#ifdef FUSED_RELU
// TODO: support relux
out0 = fmax(out0, 0);
out1 = fmax(out1, 0);
out2 = fmax(out2, 0);
out3 = fmax(out3, 0);
#endif
const int out_x_base = out_ch_blk * out_width;
int w = out_w_blk;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
out0);
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
out1);
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
out2);
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
out3);
}
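For readers decoding the layout comments in the kernel signature: tensors are packed into 2D images with four consecutive channels per texel, and the in_idx/batch_idx arithmetic above implies the coordinate mapping sketched below (a host-side illustration; the helper name is hypothetical, not part of the MACE API).
#include <cstdint>

// Assumed NHWC-to-image2d mapping for the input image:
//   x = (c / 4) * width + w, y = n * height + h,
// with channels c..c+3 stored in the texel's four components.
struct ImageCoord { int32_t x; int32_t y; };

inline ImageCoord NhwcToImageCoord(int32_t n, int32_t h, int32_t w, int32_t c,
                                   int32_t height, int32_t width) {
  return ImageCoord{(c / 4) * width + w, n * height + h};
}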
......@@ -59,15 +59,13 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
// Unrolling this loop hurts performance
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
for (short hb_idx = 0; hb_idx < 3; ++hb_idx) {
int in_hb_value = height_idx + hb_idx;
in_hb_value = select(in_hb_value + batch_idx,
-1,
(in_hb_value < 0 || in_hb_value >= in_height));
for (short width_idx = 0; width_idx < 3; ++width_idx) {
in_idx = in_ch_blk * in_width;
int in_hb_value = height_idx + hb_idx;
in_hb_value = select(in_hb_value + batch_idx,
-1,
(in_hb_value < 0 || in_hb_value >= in_height));
int in_width_value;
#define READ_INPUT(i) \
in_width_value = in_width##i + width_idx; \
......
#include <common.h>
VEC_DATA_TYPE(DATA_TYPE,4) vec_pooling_3_s1(const DATA_TYPE *input_ptr, const int in_width) {
VEC_DATA_TYPE(DATA_TYPE,4) row00 = vload4(0, input_ptr);
VEC_DATA_TYPE(DATA_TYPE,2) row01 = vload2(0, input_ptr + 4);
VEC_DATA_TYPE(DATA_TYPE,4) row10 = vload4(0, input_ptr + in_width);
VEC_DATA_TYPE(DATA_TYPE,2) row11 = vload2(0, input_ptr + in_width + 4);
VEC_DATA_TYPE(DATA_TYPE,4) row20 = vload4(0, input_ptr + in_width * 2);
VEC_DATA_TYPE(DATA_TYPE,2) row21 = vload2(0, input_ptr + in_width * 2 + 4);
VEC_DATA_TYPE(DATA_TYPE,8) data00 = (VEC_DATA_TYPE(DATA_TYPE,8))(row00.s01212323);
VEC_DATA_TYPE(DATA_TYPE,4) data01 = (VEC_DATA_TYPE(DATA_TYPE,4))(row01.s0, row00.s3, row01.s01);
VEC_DATA_TYPE(DATA_TYPE,8) data10 = (VEC_DATA_TYPE(DATA_TYPE,8))(row10.s01212323);
VEC_DATA_TYPE(DATA_TYPE,4) data11 = (VEC_DATA_TYPE(DATA_TYPE,4))(row11.s0, row10.s3, row11.s01);
VEC_DATA_TYPE(DATA_TYPE,8) data20 = (VEC_DATA_TYPE(DATA_TYPE,8))(row20.s01212323);
VEC_DATA_TYPE(DATA_TYPE,4) data21 = (VEC_DATA_TYPE(DATA_TYPE,4))(row21.s0, row20.s3, row21.s01);
VEC_DATA_TYPE(DATA_TYPE,8) left = fmax(fmax(data00, data10), data20);
VEC_DATA_TYPE(DATA_TYPE,4) right = fmax(fmax(data01, data11), data21);
VEC_DATA_TYPE(DATA_TYPE,4) res = fmax((VEC_DATA_TYPE(DATA_TYPE,4))(left.s036, right.s1),
(VEC_DATA_TYPE(DATA_TYPE,4))(left.s147, right.s2));
res = fmax(res, (VEC_DATA_TYPE(DATA_TYPE,4))(left.s25, right.s03));
return res;
}
VEC_DATA_TYPE(DATA_TYPE,4) vec_pooling_3_s2(const DATA_TYPE *input_ptr, const int in_width) {
VEC_DATA_TYPE(DATA_TYPE,8) row00 = vload8(0, input_ptr);
DATA_TYPE row01 = *(input_ptr + 8);
VEC_DATA_TYPE(DATA_TYPE,8) row10 = vload8(0, input_ptr + in_width);
DATA_TYPE row11 = *(input_ptr + in_width + 8);
VEC_DATA_TYPE(DATA_TYPE,8) row20 = vload8(0, input_ptr + in_width * 2);
DATA_TYPE row21 = *(input_ptr + in_width * 2 + 8);
VEC_DATA_TYPE(DATA_TYPE,8) data00 = (VEC_DATA_TYPE(DATA_TYPE,8))(row00.s01223445);
VEC_DATA_TYPE(DATA_TYPE,4) data01 = (VEC_DATA_TYPE(DATA_TYPE,4))(row00.s667, row01);
VEC_DATA_TYPE(DATA_TYPE,8) data10 = (VEC_DATA_TYPE(DATA_TYPE,8))(row10.s01223445);
VEC_DATA_TYPE(DATA_TYPE,4) data11 = (VEC_DATA_TYPE(DATA_TYPE,4))(row10.s667, row11);
VEC_DATA_TYPE(DATA_TYPE,8) data20 = (VEC_DATA_TYPE(DATA_TYPE,8))(row20.s01223445);
VEC_DATA_TYPE(DATA_TYPE,4) data21 = (VEC_DATA_TYPE(DATA_TYPE,4))(row20.s667, row21);
VEC_DATA_TYPE(DATA_TYPE,8) left = fmax(fmax(data00, data10), data20);
VEC_DATA_TYPE(DATA_TYPE,4) right = fmax(fmax(data01, data11), data21);
VEC_DATA_TYPE(DATA_TYPE,4) res = fmax((VEC_DATA_TYPE(DATA_TYPE,4))(left.s036, right.s1),
(VEC_DATA_TYPE(DATA_TYPE,4))(left.s147, right.s2));
res = fmax(res, (VEC_DATA_TYPE(DATA_TYPE,4))(left.s25, right.s03));
return res;
}
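The swizzles above are dense; a scalar reference makes their intent explicit. The stride-2 variant reduces three input rows and then takes four 3-wide max windows starting at columns 0, 2, 4, 6 (the stride-1 variant is analogous with windows at columns 0, 1, 2, 3). A plain C++ sketch for cross-checking:
#include <algorithm>
#include <array>

// Scalar reference for vec_pooling_3_s2: four 3x3 max windows at stride 2,
// reduced over three input rows (illustration only, not MACE code).
std::array<float, 4> MaxPool3RowRefS2(const float *p, int in_width) {
  std::array<float, 4> out;
  for (int i = 0; i < 4; ++i) {
    float m = p[2 * i];
    for (int r = 0; r < 3; ++r)
      for (int k = 0; k < 3; ++k)
        m = std::max(m, p[r * in_width + 2 * i + k]);
    out[i] = m;
  }
  return out;
}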
DATA_TYPE inner_pooling_3(const DATA_TYPE *input_ptr, const int in_width) {
VEC_DATA_TYPE(DATA_TYPE,3) row0 = vload3(0, input_ptr);
VEC_DATA_TYPE(DATA_TYPE,3) row1 = vload3(0, input_ptr + in_width);
VEC_DATA_TYPE(DATA_TYPE,3) row2 = vload3(0, input_ptr + in_width * 2);
VEC_DATA_TYPE(DATA_TYPE,3) data = fmax(fmax(row0, row1), row2);
DATA_TYPE res = fmax(fmax(data.s0, data.s1), data.s2);
return res;
}
// Supported data type: half/float
__kernel void pooling3(__global const DATA_TYPE *input, /* n, c, h, w */
__private const int in_height,
__private const int in_width,
__private const int out_chan_num,
__private const int out_height,
__private const int out_width,
__private const int stride,
__global DATA_TYPE *output) {
int batch = get_global_id(0);
int out_chan_blk = get_global_id(1);
int out_pixel_blk = get_global_id(2);
const int round_out_width = (out_width + 3) / 4;
const int out_pixel_height = out_pixel_blk / round_out_width;
const int out_pixel_width = out_pixel_blk % round_out_width;
const int out_chan_begin = out_chan_blk * 4;
const int out_chan_end = min(out_chan_begin + 4, out_chan_num);
const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4;
const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width);
const int in_pixel_begin = out_pixel_height * stride * in_width + out_pixel_width * stride * 4;
const int in_pixel = in_height * in_width;
const int out_pixel = out_height * out_width;
const int in_offset = batch * out_chan_num * in_pixel;
const int out_offset = batch * out_chan_num * out_pixel;
const DATA_TYPE *input_base = input + in_offset + in_pixel_begin;
DATA_TYPE *output_base = output + out_offset + out_pixel_begin;
const int pixels = out_pixel_end - out_pixel_begin;
for (int i = out_chan_begin; i < out_chan_end; ++i) {
const DATA_TYPE *input_ptr = input_base + i * in_pixel;
DATA_TYPE *output_ptr = output_base + i * out_pixel;
if (pixels == 4) {
VEC_DATA_TYPE(DATA_TYPE,4) res;
#ifdef FP16
#define MIN_VALUE -USHRT_MAX
#else
#define MIN_VALUE -FLT_MAX
#endif
#ifdef STRIDE_1
res = vec_pooling_3_s1(input_ptr, in_width);
#else
res = vec_pooling_3_s2(input_ptr, in_width);
#endif
vstore4(res, 0, output_ptr);
} else {
for (int p = 0; p < pixels; ++p) {
output_ptr[p] = inner_pooling_3(input_ptr, in_width);
input_ptr += stride;
}
}
}
}
int calculate_avg_block_size(const int pos_h,
const int pos_w,
const int pool_size,
const int pad_h,
const int pad_w,
const int h_size,
const int w_size) {
const int h_start = max(0, pos_h - pad_h);
const int w_start = max(0, pos_w - pad_w);
const int h_end = min(pos_h + pool_size - pad_h, h_size);
const int w_end = min(pos_w + pool_size - pad_w, w_size);
inline int calculate_avg_block_size(const int pool_size,
const int pos_h,
const int pos_w,
const int h_size,
const int w_size) {
const int h_start = max(0, pos_h);
const int w_start = max(0, pos_w);
const int h_end = min(pos_h + pool_size, h_size);
const int w_end = min(pos_w + pool_size, w_size);
return (h_end - h_start) * (w_end - w_start);
}
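A quick worked example of the divisor logic in the new calculate_avg_block_size: near a border the averaging window is clipped to the image, so the divisor shrinks below pool_size squared. A host-side mirror (C++, for illustration):
#include <algorithm>
#include <cassert>

// Host mirror of calculate_avg_block_size (pool-size-first signature).
int AvgBlockSizeSketch(int pool, int pos_h, int pos_w, int h_size, int w_size) {
  const int h_start = std::max(0, pos_h), w_start = std::max(0, pos_w);
  const int h_end = std::min(pos_h + pool, h_size);
  const int w_end = std::min(pos_w + pool, w_size);
  return (h_end - h_start) * (w_end - w_start);
}

int main() {
  assert(AvgBlockSizeSketch(3, -1, -1, 32, 32) == 4);  // corner: only 2x2 inside
  assert(AvgBlockSizeSketch(3, 10, 10, 32, 32) == 9);  // interior: full 3x3
  return 0;
}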
// Supported data type: half/float
__kernel void poolingn(__global const DATA_TYPE *input, /* n, c, h, w */
__private const int in_height,
__private const int in_width,
__private const int out_chan_num,
__private const int out_height,
__private const int out_width,
__private const int stride,
__private const int pad_h,
__private const int pad_w,
__private const int pooling_size,
__global DATA_TYPE *output) {
int batch = get_global_id(0);
int out_chan_idx = get_global_id(1);
int out_pixel_idx = get_global_id(2);
const int out_pixel_height = out_pixel_idx / out_width;
const int out_pixel_width = out_pixel_idx % out_width;
const int out_chan_begin = out_chan_idx * 4;
const int out_chan_end = min(out_chan_begin + 4, out_chan_num);
const int in_pixel_idx = out_pixel_height * stride * in_width
+ out_pixel_width * stride;
const int in_pixel = in_height * in_width;
const int out_pixel = out_height * out_width;
const int in_offset = batch * out_chan_num * in_pixel;
const int out_offset = batch * out_chan_num * out_pixel;
const DATA_TYPE *input_base = input + in_offset + in_pixel_idx;
DATA_TYPE *output_base = output + out_offset + out_pixel_idx;
const int block_size = calculate_avg_block_size(
out_pixel_height * stride,
out_pixel_width * stride,
pooling_size,
pad_h/2,
pad_w/2,
in_height - pad_h,
in_width - pad_w);
for (int i = out_chan_begin; i < out_chan_end; ++i) {
VEC_DATA_TYPE(DATA_TYPE,8) sum8 = 0.0f;
DATA_TYPE sum1 = 0.0f;
DATA_TYPE *output_ptr = output_base + i * out_pixel;
for (int y = 0; y < pooling_size; ++y) {
const DATA_TYPE *input_ptr = input_base + i * in_pixel + y * in_width;
int x = 0;
for (; x < (pooling_size-8); x += 8) {
VEC_DATA_TYPE(DATA_TYPE,8) data = vload8(0, input_ptr);
sum8 += data;
input_ptr += 8;
}
for (; x < pooling_size; ++x) {
sum1 += *input_ptr;
input_ptr++;
}
}
VEC_DATA_TYPE(DATA_TYPE,4) sum4 = sum8.s0123 + sum8.s4567;
VEC_DATA_TYPE(DATA_TYPE,2) sum2 = sum4.s01 + sum4.s23;
*output_ptr = (sum2.s0 + sum2.s1 + sum1) / block_size;
}
}
__kernel void pooling(__read_only image2d_t input,
__private const int in_height,
__private const int in_width,
__private const int out_height,
__private const int pad_top,
__private const int pad_left,
__private const int stride,
__private const int pooling_size,
__write_only image2d_t output) {
const int out_chan_idx = get_global_id(0);
const int out_width_idx = get_global_id(1);
const int out_width = get_global_size(1);
const int out_hb_idx = get_global_id(2);
const int batch_idx = (out_hb_idx / out_height) * in_height;
const int in_height_start = (out_hb_idx % out_height) * stride - pad_top;
const int in_width_start = out_width_idx * stride - pad_left;
const int in_channel_offset = out_chan_idx * in_width;
#ifdef POOL_AVG
DATA_TYPE4 res = 0;
for (int height = 0; height < pooling_size; ++height) {
int in_height_idx = in_height_start + height;
in_height_idx = select(batch_idx + in_height_idx,
-1,
(in_height_idx < 0 || in_height_idx >= in_height));
for (int width = 0; width < pooling_size; ++width) {
int in_width_idx = in_width_start + width;
in_width_idx = select(in_channel_offset + in_width_idx,
-1,
(in_width_idx < 0 || in_width_idx >= in_width));
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(in_width_idx, in_height_idx));
res = res + in;
}
}
const int block_size = calculate_avg_block_size(pooling_size,
in_height_start, in_width_start,
in_height, in_width);
res /= block_size;
#else
DATA_TYPE4 res = (DATA_TYPE4)(MIN_VALUE);
for (int height = 0; height < pooling_size; ++height) {
int in_height_idx = in_height_start + height;
in_height_idx = select(batch_idx + in_height_idx,
-1,
(in_height_idx < 0 || in_height_idx >= in_height));
if (in_height_idx != -1) {
for (int width = 0; width < pooling_size; ++width) {
int in_width_idx = in_width_start + width;
in_width_idx = select(in_channel_offset + in_width_idx,
-1,
(in_width_idx < 0 || in_width_idx >= in_width));
if (in_width_idx != -1) {
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(in_width_idx, in_height_idx));
res = fmax(res, in);
}
}
}
}
#endif
WRITE_IMAGET(output, (int2)(out_chan_idx * out_width + out_width_idx, out_hb_idx), res);
}
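Note how the kernel stays branch-free on the AVG path: out-of-range taps are redirected to coordinate -1 via select, and with CLK_ADDRESS_CLAMP a float image read at an out-of-range coordinate returns zero, so padded taps add nothing to the sum; the divisor is then fixed up with calculate_avg_block_size rather than using pooling_size squared. The MAX path cannot use a zero border (it would clip negative activations), hence its explicit != -1 guards. Restated in C++ for clarity (a sketch of the idiom, not MACE code):
// Mirrors `select(base + idx, -1, idx < 0 || idx >= limit)`: an invalid tap
// becomes a border read (contributing 0 for AVG); MAX skips it explicitly.
inline int RedirectIfOutOfRange(int base, int idx, int limit) {
  return (idx < 0 || idx >= limit) ? -1 : base + idx;
}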
......@@ -28,6 +28,11 @@ extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
const int *padding, const DataType dt,
Tensor *output);
extern void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const DataType dt, Tensor *output);
template<typename T>
void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *filter,
......@@ -47,17 +52,13 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
if (!input->is_image() || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1) {
LOG(WARNING) << "OpenCL conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
// TODO(heliangliang) The CPU/NEON kernel should map the buffer
Conv2dFunctor<DeviceType::CPU, T>(strides_, paddings_, dilations_)(
input, filter, bias, output);
return;
MACE_NOT_IMPLEMENTED;
}
std::vector<index_t> output_shape(4);
......@@ -66,16 +67,18 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
input->shape().data(), filter->shape().data(), dilations_,
strides_, paddings_, output_shape.data(), paddings.data());
if (input->is_image()) {
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1][strides_[0] - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_func(input, filter, bias, false, paddings.data(), DataTypeToEnum<T>::value, output);
} else {
output->Resize(output_shape);
Conv2dOpencl(input, filter, bias, false, strides_[0], paddings.data(), DataTypeToEnum<T>::value, output);
}
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_func(input, filter, bias, false, paddings.data(), DataTypeToEnum<T>::value, output);
}
template
......
......@@ -38,7 +38,6 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
auto program = runtime->program();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d_3x3", "conv_2d_3x3", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
......
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/conv_2d.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/utils.h"
namespace mace {
namespace kernels {
void Conv2dOpencl(const Tensor *input, const Tensor *filter,
const Tensor *bias, const bool fused_relu,
const uint32_t stride, const int *padding,
const DataType dt, Tensor *output) {
const index_t batch = output->dim(0);
const index_t height = output->dim(1);
const index_t width = output->dim(2);
const index_t channels = output->dim(3);
const index_t input_channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
const index_t width_blocks = RoundUpDiv4(width);
std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace("-DSTRIDE=" + ToString(stride));
if (fused_relu) {
built_options.emplace("-DFUSED_RELU");
}
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
auto conv_2d_kernel = runtime->BuildKernel("conv_2d", "conv_2d", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(conv_2d_kernel);
uint32_t idx = 0;
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(filter->buffer())));
if (bias != nullptr) {
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
}
conv_2d_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(1)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input->dim(2)));
conv_2d_kernel.setArg(idx++, static_cast<int>(input_channel_blocks));
conv_2d_kernel.setArg(idx++, static_cast<int>(height));
conv_2d_kernel.setArg(idx++, static_cast<int>(width));
conv_2d_kernel.setArg(idx++, static_cast<int>(filter->dim(0)));
conv_2d_kernel.setArg(idx++, static_cast<int>(filter->dim(1)));
conv_2d_kernel.setArg(idx++, padding[0] / 2);
conv_2d_kernel.setArg(idx++, padding[1] / 2);
auto command_queue = runtime->command_queue();
cl_int error;
error = command_queue.enqueueNDRangeKernel(
conv_2d_kernel, cl::NullRange,
cl::NDRange(static_cast<uint32_t>(channel_blocks), static_cast<uint32_t>(width_blocks),
static_cast<uint32_t>(height * batch)),
cl::NDRange(16, 16, 4),
NULL, OpenCLRuntime::Get()->GetDefaultEvent());
MACE_CHECK(error == CL_SUCCESS, error);
}
} // namespace kernels
} // namespace mace
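The NDRange here pairs with the kernel's 4-wide output blocking: gws dimension 1 is width_blocks = RoundUpDiv4(width), each work-item writes up to four output pixels spaced out_w_blks apart, and the kernel re-checks `if (w >= out_width) return;` before each extra store. RoundUpDiv4 is assumed to be ceiling division by four; a minimal equivalent:
#include <cstdint>

// Assumed semantics of RoundUpDiv4 (sketch): ceil(x / 4).
inline int64_t RoundUpDiv4Sketch(int64_t x) { return (x + 3) >> 2; }
// e.g. width 33 -> 9 blocks; a work-item whose later strided pixel index
// (w + k * out_w_blks) reaches 33 bails out before that store.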
......@@ -10,131 +10,94 @@
namespace mace {
namespace kernels {
static void Pooling3(const Tensor *input,
const int *stride,
const PoolingType type,
Tensor *output) {
if (type != MAX) {
MACE_NOT_IMPLEMENTED;
}
static void Pooling(const Tensor *input,
const int *stride,
const int *paddings,
const int pooling_size,
const PoolingType type,
const DataType dt,
Tensor *output) {
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t out_height = output->dim(2);
index_t out_width = output->dim(3);
index_t out_height = output->dim(1);
index_t out_width = output->dim(2);
index_t channels = output->dim(3);
index_t channel_blk = (channels + 3) / 4;
const index_t pixel_width = (out_width + 3) / 4 ;
index_t channel_blocks = (channels + 3) / 4;
const uint32_t gws[3] = {
static_cast<uint32_t>(batch),
static_cast<uint32_t>(channel_blk),
static_cast<uint32_t>(pixel_width * out_height),
static_cast<uint32_t>(channel_blocks),
static_cast<uint32_t>(out_width),
static_cast<uint32_t>(batch * out_height),
};
auto runtime = OpenCLRuntime::Get();
std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype()));
built_options.emplace(stride[0] == 1 ? "-DSTRIDE_1" : "");
auto pooling_kernel = runtime->BuildKernel("pooling", "pooling3", built_options);
if (type == MAX && input->dtype() == output->dtype()) {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(dt));
built_options.emplace(dt == DT_HALF ? "-DFP16" : "");
} else {
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
}
if (type == AVG) {
built_options.emplace("-DPOOL_AVG");
}
auto pooling_kernel = runtime->BuildKernel("pooling", "pooling", built_options);
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(pooling_kernel);
const uint32_t lws[3] = {1, 8, 128};
uint32_t lws[3];
lws[0] = std::min<uint32_t>(channel_blocks, kwg_size);
lws[1] = std::min<uint32_t>(out_width, kwg_size / lws[0]);
lws[2] = std::min<uint32_t>(out_height * batch, kwg_size / (lws[0] * lws[1]));
uint32_t idx = 0;
pooling_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
pooling_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(1)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(2)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(3)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(channels));
pooling_kernel.setArg(idx++, static_cast<int32_t>(out_height));
pooling_kernel.setArg(idx++, static_cast<int32_t>(out_width));
pooling_kernel.setArg(idx++, paddings[0] / 2);
pooling_kernel.setArg(idx++, paddings[1] / 2);
pooling_kernel.setArg(idx++, stride[0]);
pooling_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
pooling_kernel.setArg(idx++, pooling_size);
pooling_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
pooling_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]),
NULL, OpenCLRuntime::Get()->GetDefaultEvent());
MACE_CHECK(error == CL_SUCCESS);
MACE_CHECK(error == CL_SUCCESS) << error;
}
static void PoolingN(const Tensor *input,
const int *stride,
const int *paddings,
const int pooling_size,
const PoolingType type,
Tensor *output) {
if (type != AVG) {
MACE_NOT_IMPLEMENTED;
}
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t out_height = output->dim(2);
index_t out_width = output->dim(3);
index_t channel_blk = (channels + 3) / 4;
const uint32_t gws[3] = {
static_cast<uint32_t>(batch),
static_cast<uint32_t>(channel_blk),
static_cast<uint32_t>(out_height * out_width),
template<typename T>
void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output) {
MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1) << "Pooling OpenCL kernel does not support dilation yet";
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape = {
kernels_[0], kernels_[1],
input->dim(3), input->dim(3)
};
auto runtime = OpenCLRuntime::Get();
std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(input->dtype()));
auto pooling_kernel = runtime->BuildKernel("pooling", "poolingn", built_options);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter_shape.data(),
dilations_, strides_, this->padding_,
output_shape.data(), paddings.data());
const uint32_t lws[3] = {1, 8, 128};
std::vector<size_t> output_image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
uint32_t idx = 0;
pooling_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input->buffer())));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(2)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(input->dim(3)));
pooling_kernel.setArg(idx++, static_cast<int32_t>(channels));
pooling_kernel.setArg(idx++, static_cast<int32_t>(out_height));
pooling_kernel.setArg(idx++, static_cast<int32_t>(out_width));
pooling_kernel.setArg(idx++, stride[0]);
pooling_kernel.setArg(idx++, paddings[0]);
pooling_kernel.setArg(idx++, paddings[1]);
pooling_kernel.setArg(idx++, pooling_size);
pooling_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
Pooling(input, strides_, paddings.data(), kernels_[0], pooling_type_,
DataTypeToEnum<T>::value, output);
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
pooling_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1], gws[2]),
cl::NDRange(lws[0], lws[1], lws[2]),
NULL, OpenCLRuntime::Get()->GetDefaultEvent());
MACE_CHECK(error == CL_SUCCESS);
}
template <>
void PoolingFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
Tensor *output) {
int paddings[2];
std::vector<index_t> filter_shape = {input->dim(1), input->dim(0),
kernels_[0], kernels_[1]};
kernels::CalPaddingSize(input->shape().data(), filter_shape.data(), this->dilations_,
strides_, this->padding_, paddings);
#define POOLING_HELPER \
switch(kernels_[0]) { \
case 3: \
Pooling3(input, strides_, pooling_type_, output); \
break; \
default: \
PoolingN(input, strides_, paddings, kernels_[0], \
pooling_type_, output); \
break; \
}
if (paddings[0] > 0 || paddings[1] > 0) {
Tensor padded_input(GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum<float>::v());
ConstructInputWithPadding(input, paddings, &padded_input, pooling_type_ == MAX);
input = &padded_input;
POOLING_HELPER
} else {
POOLING_HELPER
}
#undef POOLING_HELPER
}
template
struct PoolingFunctor<DeviceType::OPENCL, float>;
template
struct PoolingFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
......@@ -18,36 +18,66 @@ enum PoolingType {
namespace kernels {
template <DeviceType D, typename T>
struct PoolingFunctor {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding,
const int *dilations)
struct PoolingFunctorBase {
PoolingFunctorBase(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding,
const int *dilations)
: pooling_type_(pooling_type),
kernels_(kernels),
strides_(strides),
padding_(padding),
dilations_(dilations) {}
const PoolingType pooling_type_;
const int *kernels_;
const int *strides_;
const Padding padding_;
const int *dilations_;
};
template<DeviceType D, typename T>
struct PoolingFunctor : PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding,
const int *dilations)
: PoolingFunctorBase(pooling_type, kernels,
strides, padding,
dilations) {}
void operator()(const Tensor *input_tensor,
Tensor *output_tensor) {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape = {
kernels_[0], kernels_[1],
input_tensor->dim(3), input_tensor->dim(3)
};
kernels::CalcNHWCPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(),
dilations_, strides_, this->padding_,
output_shape.data(), paddings.data());
output_tensor->Resize(output_shape);
Tensor::MappingGuard in_guard(input_tensor);
Tensor::MappingGuard out_guard(output_tensor);
const T *input = input_tensor->data<T>();
T *output = output_tensor->mutable_data<T>();
const index_t *input_shape = input_tensor->shape().data();
const index_t *output_shape = output_tensor->shape().data();
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t height = output_shape[1];
index_t width = output_shape[2];
index_t channels = output_shape[3];
index_t out_image_size = height * width;
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t input_height = input_shape[1];
index_t input_width = input_shape[2];
index_t input_channels = input_shape[3];
index_t in_image_size = input_height * input_width;
int kernel_h = kernels_[0];
......@@ -59,11 +89,6 @@ struct PoolingFunctor {
int dilation_h = dilations_[0];
int dilation_w = dilations_[1];
int paddings[2];
std::vector<index_t> filter_shape = {input_shape[1], input_shape[0],
kernels_[0], kernels_[1]};
kernels::CalPaddingSize(input_shape, filter_shape.data(), this->dilations_,
strides_, this->padding_, paddings);
// The top-left-most offset of the padded input
int padded_h_start = 0 - paddings[0] / 2;
int padded_w_start = 0 - paddings[1] / 2;
......@@ -71,25 +96,24 @@ struct PoolingFunctor {
if (pooling_type_ == MAX) {
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
index_t out_offset = (b * channels + c) * out_image_size;
index_t in_offset = (b * input_channels + c) * in_image_size;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
T max = std::numeric_limits<T>::lowest();
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
index_t in_offset = b * in_image_size * input_channels + c;
T res = std::numeric_limits<T>::lowest();
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width) {
index_t input_offset = in_offset + inh * input_width + inw;
max = std::max(max, input[input_offset]);
index_t input_offset = in_offset + (inh * input_width + inw) * input_channels;
res = std::max(res, input[input_offset]);
}
}
}
output[out_offset] = max;
out_offset += 1;
*output = res;
output++;
}
}
}
......@@ -97,11 +121,10 @@ struct PoolingFunctor {
} else if (pooling_type_ == AVG) {
#pragma omp parallel for collapse(2)
for (int b = 0; b < batch; ++b) {
for (int c = 0; c < channels; ++c) {
index_t out_offset = (b * channels + c) * out_image_size;
index_t in_offset = (b * input_channels + c) * in_image_size;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) {
index_t in_offset = b * in_image_size * input_channels + c;
T sum = 0;
int block_size = 0;
for (int kh = 0; kh < kernel_h; ++kh) {
......@@ -110,14 +133,14 @@ struct PoolingFunctor {
int inw = padded_w_start + w * stride_w + dilation_w * kw;
if (inh >= 0 && inh < input_height && inw >= 0 &&
inw < input_width) {
index_t input_offset = in_offset + inh * input_width + inw;
index_t input_offset = in_offset + (inh * input_width + inw) * input_channels;
sum += input[input_offset];
block_size += 1;
}
}
}
output[out_offset] = sum / block_size;
out_offset += 1;
*output = sum / block_size;
output++;
}
}
}
......@@ -125,22 +148,26 @@ struct PoolingFunctor {
}
}
const PoolingType pooling_type_;
const int *kernels_;
const int *strides_;
const Padding padding_;
const int *dilations_;
};
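The loop bodies above encode the NHWC layout this change migrates the CPU pooling path to; the pointer arithmetic reduces to the usual flattening, restated here as a reference (assuming index_t is a 64-bit signed integer, matching MACE's typedef):
#include <cstdint>

// Flat NHWC offset used by the pooling loops above:
//   offset(b, h, w, c) = ((b * H + h) * W + w) * C + c
inline std::int64_t NhwcOffset(std::int64_t b, std::int64_t h, std::int64_t w,
                               std::int64_t c, std::int64_t H, std::int64_t W,
                               std::int64_t C) {
  return ((b * H + h) * W + w) * C + c;
}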
template <>
template<>
void PoolingFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input_tensor,
Tensor *output_tensor);
template <>
void PoolingFunctor<DeviceType::OPENCL, float>::operator()(
const Tensor *input_tensor,
Tensor *output_tensor);
template<typename T>
struct PoolingFunctor<DeviceType::OPENCL, T> : PoolingFunctorBase {
PoolingFunctor(const PoolingType pooling_type,
const int *kernels,
const int *strides,
const Padding padding,
const int *dilations)
: PoolingFunctorBase(pooling_type, kernels,
strides, padding,
dilations) {}
void operator()(const Tensor *input_tensor,
Tensor *output_tensor);
};
} // namespace kernels
} // namespace mace
......
......@@ -118,14 +118,13 @@ void TestDiffTypeBidirectionTransform(const int type, const std::vector<index_t>
.Input("B2IOutput")
.Output("I2BOutput")
.AddIntArg("buffer_type", type)
.AddIntArg("T", DataTypeToEnum<T>::value)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
// Check
ExpectTensorNear<float, T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-2);
ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), 1e-3);
}
TEST(BufferToImageTest, ArgFloatToHalfSmall) {
......
......@@ -558,18 +558,20 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
}
template<DeviceType D>
static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
static void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
const std::vector<index_t> &filter_shape) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
srand(time(NULL));
auto func = [&](int stride_h, int stride_w, Padding padding) {
// generate random input
index_t batch = 3 + (rand() % 10);
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2] + (rand() % 10);
index_t output_channels = shape[3] + (rand() % 10);
index_t height = input_shape[0];
index_t width = input_shape[1];
index_t kernel_h = filter_shape[0];
index_t kernel_w = filter_shape[1];
index_t input_channels = filter_shape[2] + (rand() % 10);
index_t output_channels = filter_shape[3] + (rand() % 10);
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
......@@ -578,7 +580,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
......@@ -611,7 +613,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
......@@ -620,20 +622,46 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
ImageToBuffer<D, float>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.2);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.5);
};
for (int kernel_size : {1, 3}) {
for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
}
for (int stride : {1, 2}) {
func(stride, stride, VALID);
func(stride, stride, SAME);
}
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConvNxNS12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 32, 64});
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{1, 1, 32, 64});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv3x3S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{3, 3, 32, 64});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv15x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{15, 1, 256, 2});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv1x15S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{1, 15, 256, 2});
}
TEST_F(Conv2dOpTest, OPENCLHalfAlignedConv7x7S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({32, 32},
{7, 7, 3, 64});
}
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv1x1S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113},
{1, 1, 5, 7});
}
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConvNxNS12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113, 5, 7});
TEST_F(Conv2dOpTest, OPENCLHalfUnalignedConv3x3S12) {
TestHalfComplexConvNxNS12<DeviceType::OPENCL>({107, 113},
{3, 3, 5, 7});
}
......@@ -10,6 +10,10 @@ REGISTER_CPU_OPERATOR(OpKeyBuilder("Pooling")
.TypeConstraint<float>("T")
.Build(),
PoolingOp<DeviceType::CPU, float>);
REGISTER_CPU_OPERATOR(OpKeyBuilder("Pooling")
.TypeConstraint<half>("T")
.Build(),
PoolingOp<DeviceType::CPU, half>);
#if __ARM_NEON
REGISTER_NEON_OPERATOR(OpKeyBuilder("Pooling")
......@@ -22,5 +26,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("Pooling")
.TypeConstraint<float>("T")
.Build(),
PoolingOp<DeviceType::OPENCL, float>);
REGISTER_OPENCL_OPERATOR(OpKeyBuilder("Pooling")
.TypeConstraint<half>("T")
.Build(),
PoolingOp<DeviceType::OPENCL, half>);
} // namespace mace
......@@ -27,21 +27,6 @@ class PoolingOp : public ConvPool2dOpBase<D, T> {
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
std::vector<index_t> filter_shape(4);
// TODO(chenghui): is it kind of a hack?
filter_shape[0] = input->shape()[1];
filter_shape[1] = input->shape()[0];
filter_shape[2] = kernels_[0];
filter_shape[3] = kernels_[1];
kernels::CalcPaddingAndOutputSize(
input->shape().data(), filter_shape.data(), this->dilations_.data(),
this->strides_.data(), this->padding_, output_shape.data(),
paddings.data());
output->Resize(output_shape);
functor_(input, output);
return true;
};
......
......@@ -28,48 +28,20 @@ TEST_F(PoolingOpTest, MAX_VALID) {
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
// Run
net.RunOp();
// Check
auto expected =
CreateTensor<float>({1, 2, 2, 2}, {5, 7, 13, 15, 21, 23, 29, 31});
CreateTensor<float>({1, 2, 2, 2}, {5, 21, 7, 23, 13, 29, 15, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, AVG_VALID) {
// Construct graph
auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("pooling_type", PoolingType::AVG)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 2, 4, 4},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>(
{1, 2, 2, 2}, {2.5, 4.5, 10.5, 12.5, 18.5, 20.5, 26.5, 28.5});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, MAX_SAME) {
// Construct graph
......@@ -85,14 +57,14 @@ TEST_F(PoolingOpTest, MAX_SAME) {
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 1, 3, 3},
{0, 1, 2, 3, 4, 5, 6, 7, 8});
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 1, 2, 2}, {4, 5, 7, 8});
auto expected = CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
......@@ -112,14 +84,14 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 1, 4, 4},
"Input", {1, 4, 4, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
// Run
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 1, 2, 2}, {10, 11, 14, 15});
auto expected = CreateTensor<float>({1, 2, 2, 1}, {10, 11, 14, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
......@@ -139,42 +111,57 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 1, 2, 9},
"Input", {1, 2, 9, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
// Run
net.RunOp(DeviceType::NEON);
net.RunOp();
// Check
auto expected = CreateTensor<float>({1, 1, 1, 5}, {10, 12, 14, 16, 17});
auto expected = CreateTensor<float>({1, 1, 5, 1}, {10, 12, 14, 16, 17});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
template <DeviceType D>
template<DeviceType D>
static void SimpleMaxPooling3S2() {
// Construct graph
OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 1, 3, 9},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
"Input", {1, 3, 9, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});
// Run
net.RunOp(D);
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
OpDefBuilder("Pooling", "PoolingTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
ImageToBuffer<D, float>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else {
// Run
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
}
// Check
auto expected = CreateTensor<float>({1, 1, 1, 4}, {20, 22, 24, 26});
auto expected = CreateTensor<float>({1, 1, 4, 1}, {20, 22, 24, 26});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
......@@ -182,15 +169,15 @@ static void SimpleMaxPooling3S2() {
TEST_F(PoolingOpTest, CPUSimpleMaxPooling3S2) {
SimpleMaxPooling3S2<CPU>();
}
TEST_F(PoolingOpTest, NEONSimpleMaxPooling3S2) {
SimpleMaxPooling3S2<NEON>();
}
TEST_F(PoolingOpTest, OPENCLSimpleMaxPooling3S2) {
SimpleMaxPooling3S2<OPENCL>();
}
template <DeviceType D>
static void AlignedMaxPooling3S2(Padding padding) {
template<DeviceType D, typename T>
static void MaxPooling3S2(const std::vector<index_t> &input_shape,
const std::vector<int> strides,
Padding padding) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
......@@ -198,22 +185,35 @@ static void AlignedMaxPooling3S2(Padding padding) {
.Output("Output")
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", {2, 2})
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, float>("Input", {3, 128, 64, 64});
// Run
net.RunOp(D);
net.AddRandomInput<D, T>("Input", input_shape);
// run on cpu
net.RunOp();
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on cpu
net.RunOp();
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
OpDefBuilder("Pooling", "PoolingTest")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
net.RunOp(D);
ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(*net.GetOutput("Output"), expected, 0.001);
ExpectTensorNear<T>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
}
// TODO(chenghui) : there is a bug.
......@@ -223,152 +223,158 @@ static void AlignedMaxPooling3S2(Padding padding) {
//}
TEST_F(PoolingOpTest, OPENCLAlignedMaxPooling3S2) {
AlignedMaxPooling3S2<OPENCL>(Padding::VALID);
AlignedMaxPooling3S2<OPENCL>(Padding::SAME);
MaxPooling3S2<OPENCL, float>({3, 64, 32, 32}, {1, 1}, Padding::VALID);
MaxPooling3S2<OPENCL, float>({3, 64, 32, 32}, {2, 2}, Padding::VALID);
MaxPooling3S2<OPENCL, float>({3, 64, 32, 32}, {1, 1}, Padding::SAME);
MaxPooling3S2<OPENCL, float>({3, 64, 32, 32}, {2, 2}, Padding::SAME);
}
TEST_F(PoolingOpTest, OPENCLHalfAlignedMaxPooling3S2) {
MaxPooling3S2<OPENCL, half>({3, 64, 32, 32}, {1, 1}, Padding::VALID);
MaxPooling3S2<OPENCL, half>({3, 64, 32, 32}, {2, 2}, Padding::VALID);
MaxPooling3S2<OPENCL, half>({3, 64, 32, 32}, {1, 1}, Padding::SAME);
MaxPooling3S2<OPENCL, half>({3, 64, 32, 32}, {2, 2}, Padding::SAME);
}
template <DeviceType D>
static void UnalignedMaxPooling3S2(Padding padding) {
TEST_F(PoolingOpTest, OPENCLUnalignedMaxPooling3S2) {
MaxPooling3S2<OPENCL, half>({3, 41, 43, 47}, {1, 1}, Padding::VALID);
MaxPooling3S2<OPENCL, half>({3, 41, 43, 47}, {2, 2}, Padding::VALID);
MaxPooling3S2<OPENCL, half>({3, 41, 43, 47}, {1, 1}, Padding::SAME);
MaxPooling3S2<OPENCL, half>({3, 41, 43, 47}, {2, 2}, Padding::SAME);
}
TEST_F(PoolingOpTest, AVG_VALID) {
// Construct graph
OpsTestNet net;
auto &net = test_net();
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntArg("pooling_type", PoolingType::MAX)
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", padding)
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("pooling_type", PoolingType::AVG)
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, float>("Input", {3, 113, 43, 47});
// Run
net.RunOp(D);
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
// Run on cpu
// Run
net.RunOp();
ExpectTensorNear<float>(*net.GetOutput("Output"), expected, 0.001);
}
// TODO(chenghui) : there is a bug.
//TEST_F(PoolingOpTest, NEONUnalignedMaxPooling3S2) {
// UnalignedMaxPooling3S2<NEON>();
//}
// Check
auto expected = CreateTensor<float>(
{1, 2, 2, 2}, {2.5, 18.5, 4.5, 20.5, 10.5, 26.5, 12.5, 28.5});
TEST_F(PoolingOpTest, OPENCLUnalignedMaxPooling3S2) {
UnalignedMaxPooling3S2<OPENCL>(Padding::VALID);
UnalignedMaxPooling3S2<OPENCL>(Padding::SAME);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
template <DeviceType D>
template<DeviceType D>
static void SimpleAvgPoolingTest() {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 2, 8, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
BufferToImage<D, float>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("pooling_type", PoolingType::AVG)
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
// Run
net.RunOp(D);
ImageToBuffer<D, float>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
// Check
auto expected = CreateTensor<float>({1, 1, 1, 4}, {4.5, 6.5, 8.5, 10.5});
auto expected = CreateTensor<float>({1, 1, 4, 1}, {4.5, 6.5, 8.5, 10.5});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001);
}
TEST_F(PoolingOpTest, NEONSimpleAvgPooling) {
SimpleAvgPoolingTest<NEON>();
}
TEST_F(PoolingOpTest, OPENCLSimpleAvgPooling) {
SimpleAvgPoolingTest<OPENCL>();
}
template <DeviceType D>
static void AlignedAvgPoolingTest(Padding padding) {
template<DeviceType D, typename T>
static void AvgPoolingTest(const std::vector<index_t> &shape,
const std::vector<int> &kernels,
const std::vector<int> &strides,
Padding padding) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.AddIntArg("pooling_type", PoolingType::AVG)
.AddIntsArg("kernels", {4, 4})
.AddIntsArg("strides", {4, 4})
.AddIntsArg("kernels", kernels)
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, float>("Input", {3, 128, 15, 15});
// Run
net.RunOp(D);
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
net.AddRandomInput<D, float>("Input", shape);
// Run on cpu
// run on cpu
net.RunOp();
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
ExpectTensorNear<float>(*net.GetOutput("Output"), expected, 1e-5);
}
TEST_F(PoolingOpTest, NEONAlignedAvgPooling) {
AlignedAvgPoolingTest<NEON>(Padding::VALID);
AlignedAvgPoolingTest<NEON>(Padding::SAME);
}
TEST_F(PoolingOpTest, OPENCLAlignedAvgPooling) {
AlignedAvgPoolingTest<OPENCL>(Padding::VALID);
AlignedAvgPoolingTest<OPENCL>(Padding::SAME);
}
template <DeviceType D>
static void UnAlignedAvgPoolingTest(Padding padding) {
// Construct graph
OpsTestNet net;
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
OpDefBuilder("Pooling", "PoolingTest")
.Input("Input")
.Output("Output")
.Input("InputImage")
.Output("OutputImage")
.AddIntArg("pooling_type", PoolingType::AVG)
.AddIntsArg("kernels", {7, 7})
.AddIntsArg("strides", {7, 7})
.AddIntsArg("kernels", kernels)
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, float>("Input", {3, 128, 31, 37});
// Run
net.RunOp(D);
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
// Run on cpu
net.RunOp();
ExpectTensorNear<float, T>(expected, *net.GetOutput("OPENCLOutput"), 0.01);
}
ExpectTensorNear<float>(*net.GetOutput("Output"), expected, 1e-5);
TEST_F(PoolingOpTest, OPENCLAlignedAvgPooling) {
AvgPoolingTest<OPENCL, float>({3, 15, 15, 128}, {4, 4}, {4, 4}, Padding::VALID);
AvgPoolingTest<OPENCL, float>({3, 15, 15, 128}, {4, 4}, {4, 4}, Padding::SAME);
}
TEST_F(PoolingOpTest, NEONUnAlignedAvgPooling) {
UnAlignedAvgPoolingTest<NEON>(Padding::VALID);
UnAlignedAvgPoolingTest<NEON>(Padding::SAME);
TEST_F(PoolingOpTest, OPENCLHalfAlignedAvgPooling) {
AvgPoolingTest<OPENCL, half>({3, 15, 15, 128}, {4, 4}, {4, 4}, Padding::VALID);
AvgPoolingTest<OPENCL, half>({3, 15, 15, 128}, {4, 4}, {4, 4}, Padding::SAME);
}
TEST_F(PoolingOpTest, OPENCLAlignedLargeKernelAvgPooling) {
AvgPoolingTest<OPENCL, float>({3, 64, 64, 128}, {16, 16}, {16, 16}, Padding::VALID);
AvgPoolingTest<OPENCL, float>({3, 64, 64, 128}, {16, 16}, {16, 16}, Padding::SAME);
}
TEST_F(PoolingOpTest, OPENCLHalfAlignedLargeKernelAvgPooling) {
AvgPoolingTest<OPENCL, half>({3, 64, 64, 128}, {16, 16}, {16, 16}, Padding::VALID);
AvgPoolingTest<OPENCL, half>({3, 64, 64, 128}, {16, 16}, {16, 16}, Padding::SAME);
}
TEST_F(PoolingOpTest, OPENCLUnAlignedAvgPooling) {
UnAlignedAvgPoolingTest<OPENCL>(Padding::VALID);
UnAlignedAvgPoolingTest<OPENCL>(Padding::SAME);
AvgPoolingTest<OPENCL, float>({3, 31, 37, 128}, {2, 2}, {2, 2}, Padding::VALID);
AvgPoolingTest<OPENCL, float>({3, 31, 37, 128}, {2, 2}, {2, 2}, Padding::SAME);
}
TEST_F(PoolingOpTest, OPENCLUnAlignedLargeKernelAvgPooling) {
AvgPoolingTest<OPENCL, float>({3, 31, 37, 128}, {8, 8}, {8, 8}, Padding::VALID);
AvgPoolingTest<OPENCL, float>({3, 31, 37, 128}, {8, 8}, {8, 8}, Padding::SAME);
}