diff --git a/src/operators/kernel/cl/cl_kernel/softmax.cl b/src/operators/kernel/cl/cl_kernel/softmax.cl index 60f0cf409596632b67817cd236f9621010522571..ba5cee7358ca1c784a97134bb81b8f753cb9776c 100644 --- a/src/operators/kernel/cl/cl_kernel/softmax.cl +++ b/src/operators/kernel/cl/cl_kernel/softmax.cl @@ -12,6 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void softmax(__read_only image2d_t input, __write_only image2d_t output, __private const int d0, @@ -24,18 +26,19 @@ __kernel void softmax(__read_only image2d_t input, const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - half4 maxv = read_imageh(input, sampler, int2(z * d3, y)); - half4 buf[d3] = {piece}; - for (int i = 1; i < d3; i++) { - buf[i] = read_imageh(input, sampler, int2(z * d3 + i, y)); - maxv = max(maxv, buf[i]); + half4 cv = read_imageh(input, sampler, (int2)(x, y)); + half4 maxv = cv; + for (int i = 0; i < d3; i++) { + half4 temp = read_imageh(input, sampler, (int2)(z * d3 + i, y)); + maxv = max(maxv, temp); } - float4 sum = 0; + half4 sum = (half4)0.0f; + // half4 x = = (half4)0.0f; for (int i = 0; i < d3; i++) { - buf[i] = exp(buf[i] - maxv); - sum += buf[i]; + half4 temp = read_imageh(input, sampler, (int2)(z * d3 + i, y)); + sum += exp(temp - maxv); } - half4 r = buf[x] / sum; + half4 r = exp(cv - maxv) / sum; - write_imageh(output, int2(z * d3 + x, y), r); + write_imageh(output, (int2)(z * d3 + x, y), r); }