diff --git a/mace/kernels/opencl/cl/conv_2d.cl b/mace/kernels/opencl/cl/conv_2d.cl index 5289e513301d41d2684be090597636f263b2b601..fae002e4340edee0a4f14932db5e35a0beec840e 100644 --- a/mace/kernels/opencl/cl/conv_2d.cl +++ b/mace/kernels/opencl/cl/conv_2d.cl @@ -87,7 +87,7 @@ __kernel void conv_2d(KERNEL_ERROR_PARAMS #undef READ_INPUT - // int filter_idx = (hb_idx * filter_width + width_idx) * rounded_in_ch + (in_ch_blk << 2); + // int filter_idx = (hb_idx * filter_width + width_idx) * in_ch + (in_ch_blk << 2); weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx)); weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx)); weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx)); diff --git a/mace/kernels/opencl/cl/conv_2d_3x3.cl b/mace/kernels/opencl/cl/conv_2d_3x3.cl index f7c1149f8fd2bbeddc0ba9efccd9f35cd019196c..1389af6654172009bccea78ac60c58b8b50a771c 100644 --- a/mace/kernels/opencl/cl/conv_2d_3x3.cl +++ b/mace/kernels/opencl/cl/conv_2d_3x3.cl @@ -87,7 +87,7 @@ __kernel void conv_2d_3x3(KERNEL_ERROR_PARAMS #undef READ_INPUT - // int filter_idx = (hb_idx * 3 + width_idx) * rounded_in_ch + (in_ch_blk << 2); + // int filter_idx = (hb_idx * 3 + width_idx) * in_ch + (in_ch_blk << 2); weights0 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 0, filter_y_idx)); weights1 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 1, filter_y_idx)); weights2 = READ_IMAGET(filter, SAMPLER, (int2)(filter_x_idx + 2, filter_y_idx)); diff --git a/mace/kernels/opencl/cl/cwise.cl b/mace/kernels/opencl/cl/cwise.cl index e93dfc7cbde706d4d3cb4ead20a6efdb3ea0c5ea..2d3f3105cbddb0dfd9d8b3b208bf400772f60fb4 100644 --- a/mace/kernels/opencl/cl/cwise.cl +++ b/mace/kernels/opencl/cl/cwise.cl @@ -3,6 +3,8 @@ __kernel void cwise(KERNEL_ERROR_PARAMS GLOBAL_WORK_GROUP_SIZE_DIM2 __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */ + __private const int width, + __private const int channel, __private const float value, __write_only image2d_t output) { const int w = get_global_id(0); @@ -12,6 +14,8 @@ __kernel void cwise(KERNEL_ERROR_PARAMS if (w >= global_size_dim0 || hb >= global_size_dim1) return; #endif + const int remain_chan = channel - mul24((w / width), 4); + DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb)); DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value}; DATA_TYPE4 out; @@ -21,15 +25,9 @@ __kernel void cwise(KERNEL_ERROR_PARAMS #elif CWISE_TYPE == 1 out = in0 + in1; #elif CWISE_TYPE == 2 - out.x = fmax(in0.x, value); - out.y = fmax(in0.y, value); - out.z = fmax(in0.z, value); - out.z = fmax(in0.w, value); + out = fmax(in0, in1); #elif CWISE_TYPE == 3 - out.x = fmin(in0.x, value); - out.y = fmin(in0.y, value); - out.z = fmin(in0.z, value); - out.z = fmin(in0.w, value); + out = fmin(in0, in1); #elif CWISE_TYPE == 4 out = in0 - in1; #elif CWISE_TYPE == 5 @@ -38,10 +36,20 @@ __kernel void cwise(KERNEL_ERROR_PARAMS in1 = (DATA_TYPE4)(0, 0, 0, 0); out = in1 - in0; #elif CWISE_TYPE == 7 - out.x = fabs(in0.x); - out.y = fabs(in0.y); - out.z = fabs(in0.z); - out.w = fabs(in0.w); + out = fabs(in0); +#endif + +#if CWISE_TYPE == 1 || CWISE_TYPE == 2 || CWISE_TYPE == 3 || CWISE_TYPE == 4 + if (remain_chan < 4) { + switch (remain_chan) { + case 1: + out.y = 0; + case 2: + out.z = 0; + case 3: + out.w = 0; + } + } #endif WRITE_IMAGET(output, (int2)(w, hb), out); diff --git a/mace/kernels/opencl/cwise_opencl.cc b/mace/kernels/opencl/cwise_opencl.cc index cf716c27937c9c14b8df53f4061e3722aa4e8797..a9565a3d41c41a6f1d1975c6c744aafa5eb5a6e8 100644 --- a/mace/kernels/opencl/cwise_opencl.cc +++ b/mace/kernels/opencl/cwise_opencl.cc @@ -71,6 +71,8 @@ void CWiseFunctor::operator()(const Tensor *input, kernel_.setArg(idx++, gws[1]); } kernel_.setArg(idx++, *(input->opencl_image())); + kernel_.setArg(idx++, static_cast(width)); + kernel_.setArg(idx++, static_cast(channels)); kernel_.setArg(idx++, static_cast(coeff_)); kernel_.setArg(idx++, *(output->opencl_image())); input_shape_ = input->shape(); diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index 06df28895506f64463e9d00451b3162d30a7b2c6..7934b7209f0456edda559044b041f482b8554472 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -34,12 +34,12 @@ void CalInOutputImageShape(const std::vector &shape, /* NHWC */ (*image_shape)[1] = shape[0] * shape[1]; } -// [RoundUp<4>(Ic), H * W * (Oc + 3) / 4] +// [Ic, H * W * (Oc + 3) / 4] void CalConv2dFilterImageShape(const std::vector &shape, /* HWOI */ std::vector *image_shape) { MACE_CHECK(shape.size() == 4); image_shape->resize(2); - (*image_shape)[0] = RoundUp(shape[3], 4); + (*image_shape)[0] = shape[3]; (*image_shape)[1] = shape[0] * shape[1] * RoundUpDiv4(shape[2]); }