diff --git a/mace/ops/opencl/cl/eltwise.cl b/mace/ops/opencl/cl/eltwise.cl index 6f352d4f9429a605fc49d6809709c3f6cae34948..5c38a7dc505150ac6f663f4e508e42093e1c0ec6 100644 --- a/mace/ops/opencl/cl/eltwise.cl +++ b/mace/ops/opencl/cl/eltwise.cl @@ -83,21 +83,22 @@ __kernel void eltwise(OUT_OF_RANGE_PARAMS #endif #endif -#if INPUT_TYPE == 1 || INPUT_TYPE == 4 - #if ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || \ - ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8 || ELTWISE_TYPE == 9 - const int remain_channel = channel - 4 * chan_idx; - if (remain_channel < 4) { - switch (remain_channel) { - case 1: - out.y = 0; - case 2: - out.z = 0; - case 3: - out.w = 0; - } +#if ((INPUT_TYPE == 1 || INPUT_TYPE == 4) && \ + (ELTWISE_TYPE == 0 || ELTWISE_TYPE == 1 || ELTWISE_TYPE == 4 || \ + ELTWISE_TYPE == 5 || ELTWISE_TYPE == 8 || ELTWISE_TYPE == 9)) || \ + ((INPUT_TYPE != 1 || INPUT_TYPE != 4) && \ + (ELTWISE_TYPE == 3 || ELTWISE_TYPE == 9)) + const int remain_channel = channel - 4 * chan_idx; + if (remain_channel < 4) { + switch (remain_channel) { + case 1: + out.y = 0; + case 2: + out.z = 0; + case 3: + out.w = 0; } - #endif + } #endif WRITE_IMAGET(output, (int2)(pos, hb), out);