提交 6054df4f 编写于 作者: L liuqi

Fix write-overflow bugs in the cwise kernel and in the filter buffer-to-image (b2i) conversion.

上级 859fbc90
......@@ -87,7 +87,7 @@ __kernel void conv_2d(KERNEL_ERROR_PARAMS
#undef READ_INPUT
// int filter_idx = (hb_idx * filter_width + width_idx) * rounded_in_ch + (in_ch_blk << 2);
// int filter_idx = (hb_idx * filter_width + width_idx) * in_ch + (in_ch_blk << 2);
weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx));
weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx));
weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx));
......
......@@ -87,7 +87,7 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS
#undef READ_INPUT
// int filter_idx = (hb_idx * 3 + width_idx) * rounded_in_ch + (in_ch_blk << 2);
// int filter_idx = (hb_idx * 3 + width_idx) * in_ch + (in_ch_blk << 2);
weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx));
weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx));
weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx));
......
......@@ -3,6 +3,8 @@
__kernel void cwise(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__private const int width,
__private const int channel,
__private const float value,
__write_only image2d_t output) {
const int w = get_global_id(0);
......@@ -12,6 +14,8 @@ __kernel void cwise(KERNEL_ERROR_PARAMS
if (w >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
const int remain_chan = channel - mul24((w / width), 4);
DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value};
DATA_TYPE4 out;
......@@ -21,15 +25,9 @@ __kernel void cwise(KERNEL_ERROR_PARAMS
#elif CWISE_TYPE == 1
out = in0 + in1;
#elif CWISE_TYPE == 2
out.x = fmax(in0.x, value);
out.y = fmax(in0.y, value);
out.z = fmax(in0.z, value);
out.z = fmax(in0.w, value);
out = fmax(in0, in1);
#elif CWISE_TYPE == 3
out.x = fmin(in0.x, value);
out.y = fmin(in0.y, value);
out.z = fmin(in0.z, value);
out.z = fmin(in0.w, value);
out = fmin(in0, in1);
#elif CWISE_TYPE == 4
out = in0 - in1;
#elif CWISE_TYPE == 5
......@@ -38,10 +36,20 @@ __kernel void cwise(KERNEL_ERROR_PARAMS
in1 = (DATA_TYPE4)(0, 0, 0, 0);
out = in1 - in0;
#elif CWISE_TYPE == 7
out.x = fabs(in0.x);
out.y = fabs(in0.y);
out.z = fabs(in0.z);
out.w = fabs(in0.w);
out = fabs(in0);
#endif
#if CWISE_TYPE == 1 || CWISE_TYPE == 2 || CWISE_TYPE == 3 || CWISE_TYPE == 4
if (remain_chan < 4) {
switch (remain_chan) {
case 1:
out.y = 0;
case 2:
out.z = 0;
case 3:
out.w = 0;
}
}
#endif
WRITE_IMAGET(output, (int2)(w, hb), out);
......
......@@ -71,6 +71,8 @@ void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
kernel_.setArg(idx++, gws[1]);
}
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, static_cast<int32_t>(channels));
kernel_.setArg(idx++, static_cast<float>(coeff_));
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
......
......@@ -34,12 +34,12 @@ void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
(*image_shape)[1] = shape[0] * shape[1];
}
// Computes the 2-D image shape used to hold a conv2d filter.
//
// shape:       filter dimensions in HWOI order; must contain exactly 4 entries
//              (enforced by MACE_CHECK).
// image_shape: output; resized to 2 entries and filled as
//              [Ic, H * W * (Oc + 3) / 4]
//              - entry 0 is the input-channel count as-is (not rounded up to
//                a multiple of 4),
//              - entry 1 is H * W * RoundUpDiv4(Oc).
void CalConv2dFilterImageShape(const std::vector<index_t> &shape, /* HWOI */
                               std::vector<size_t> *image_shape) {
  MACE_CHECK(shape.size() == 4);
  image_shape->resize(2);
  // A single store: the previous RoundUp<index_t>(shape[3], 4) assignment was
  // immediately overwritten by this one and has been dropped as a dead store.
  (*image_shape)[0] = shape[3];
  (*image_shape)[1] = shape[0] * shape[1] * RoundUpDiv4(shape[2]);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册