diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
index daedf92c6c67cab32acdd183d35bb6a1d629fb60..f56e54362a8dda18db9c1fb199c9eb6af10f9b92 100644
--- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
+++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/Kernels.metal
@@ -95,3 +95,108 @@ kernel void texture2d_to_2d_array(texture2d inTexture [[tex
     outTexture.write(input, gid.xy, 0);
 }
 
+// Pooling over every slice of a texture array; one thread per output texel.
+//
+// inTexture  (texture 0): input slices, read at integer texel coordinates.
+// outTexture (texture 1): output slices; its width/height/array size bound the grid.
+// ksize    (buffer 0): window size, [0] along x and [1] along y.
+// stride   (buffer 1): window step, [0] along x and [1] along y.
+// padding  (buffer 2): offset subtracted from the window origin, [0] for x, [1] for y.
+// poolType (buffer 3): 0 selects max pooling, 1 selects average pooling;
+//                      any other value writes zeros.
+kernel void pool(texture2d_array<half, access::read> inTexture [[texture(0)]],
+                 texture2d_array<half, access::write> outTexture [[texture(1)]],
+                 const device int * ksize [[buffer(0)]],
+                 const device int * stride [[buffer(1)]],
+                 const device int * padding [[buffer(2)]],
+                 const device int * poolType [[buffer(3)]],
+                 uint3 gid [[thread_position_in_grid]]) {
+    // Threads outside the output extent have nothing to write.
+    if (gid.x >= outTexture.get_width() ||
+        gid.y >= outTexture.get_height() ||
+        gid.z >= outTexture.get_array_size()) return;
+
+    // Window bounds in input coordinates, clipped to the input texture.
+    // gid is cast to int first so the subtraction of padding stays signed.
+    int xmin = int(gid.x) * stride[0] - padding[0];
+    int xmax = min(xmin + ksize[0], int(inTexture.get_width()));
+    xmin = max(xmin, 0);
+    int ymin = int(gid.y) * stride[1] - padding[1];
+    // The y extent is limited by the input height, not its width.
+    int ymax = min(ymin + ksize[1], int(inTexture.get_height()));
+    ymin = max(ymin, 0);
+
+    half4 r = 0;
+    if (*poolType == 0) {
+        // Seed with the first texel of the clipped window so the maximum is
+        // taken over real input values rather than the initial zero.
+        r = inTexture.read(uint2(xmin, ymin), gid.z);
+        for (int x = xmin; x < xmax; x++) {
+            for (int y = ymin; y < ymax; y++) {
+                r = fmax(r, inTexture.read(uint2(x, y), gid.z));
+            }
+        }
+    } else if (*poolType == 1) {
+        for (int x = xmin; x < xmax; x++) {
+            for (int y = ymin; y < ymax; y++) {
+                r += inTexture.read(uint2(x, y), gid.z);
+            }
+        }
+        // The divisor is the full window size even when the window was clipped
+        // at a border, so texels that fell outside count as zeros.
+        r /= ksize[0] * ksize[1];
+    }
+    outTexture.write(r, gid.xy, gid.z);
+}
+
+
+// Spreads the four channels of texel (0, 0) of each input slice down the rows
+// of a 2D texture: output row y takes channel (y % 4) of slice (y / 4) and
+// stores it in channel 0. The remaining output channels keep the values read
+// from that slice. Only texel (0, 0) of each slice is read; gid.x does not
+// affect which value is written.
+// NOTE(review): gid.z is passed as the third argument of write() on a
+// non-array texture2d, where it looks like it selects a mip level rather than
+// a slice - confirm the kernel is dispatched with depth 1 so it is always 0.
+kernel void reshape(texture2d_array<half, access::read> inTexture [[texture(0)]],
+                    texture2d<half, access::write> outTexture [[texture(1)]],
+                    uint3 gid [[thread_position_in_grid]]) {
+    if (gid.x >= outTexture.get_width() ||
+        gid.y >= outTexture.get_height()) return;
+    int zz = gid.y / 4;  // source slice
+    int cc = gid.y % 4;  // source channel within that slice
+    half4 r = inTexture.read(uint2(0, 0), zz);
+    r[0] = r[cc];
+    outTexture.write(r, gid.xy, gid.z);
+}
+
+// Softmax over column 0 of a 2D texture, using channel 0 of each texel.
+// Every thread rescans the whole column for the maximum and the sum of
+// exponentials, then normalises channel 0 of its own texel; the other
+// channels are copied through unchanged.
+// NOTE(review): the textures are bound at slots 1 and 2, unlike the 0/1 used
+// by pool and reshape - confirm this matches the host-side encoder.
+// NOTE(review): gid.z is passed as the last argument of read()/write() on
+// texture2d, where it looks like it selects a mip level rather than a slice -
+// confirm the kernel is dispatched with depth 1 so it is always 0.
+kernel void softmax(texture2d<half, access::read> inTexture [[texture(1)]],
+                    texture2d<half, access::write> outTexture [[texture(2)]],
+                    uint3 gid [[thread_position_in_grid]]) {
+    if (gid.x >= outTexture.get_width() ||
+        gid.y >= outTexture.get_height()) return;
+    int ysize = inTexture.get_height();
+    // Subtracting the column maximum before exp() keeps the exponent <= 0.
+    half maxv = inTexture.read(uint2(0, 0), gid.z)[0];
+    for (int y = 0; y < ysize; y++) {
+        half r = inTexture.read(uint2(0, y), gid.z)[0];
+        maxv = max(maxv, r);
+    }
+    half sum = 0;
+    for (int y = 0; y < ysize; y++) {
+        half r = inTexture.read(uint2(0, y), gid.z)[0];
+        sum += exp(r - maxv);
+    }
+    half4 rr = inTexture.read(gid.xy, gid.z);
+    rr[0] = exp(rr[0] - maxv) / sum;
+    outTexture.write(rr, gid.xy, gid.z);
+}