提交 cb9db5f5 编写于 作者: R Ruilong Liu 提交者: GitHub

Merge pull request #613 from dolphin8/metal

add kernel (pool & reshape & softmax)
...@@ -95,3 +95,73 @@ kernel void texture2d_to_2d_array(texture2d<float, access::read> inTexture [[tex ...@@ -95,3 +95,73 @@ kernel void texture2d_to_2d_array(texture2d<float, access::read> inTexture [[tex
outTexture.write(input, gid.xy, 0); outTexture.write(input, gid.xy, 0);
} }
/// Pools each slice of a 2D texture array with a rectangular window.
///
/// One thread produces one output texel (all four half channels at once) at
/// (gid.x, gid.y) in slice gid.z.
///
/// @param inTexture   Source texture array, read-only.
/// @param outTexture  Destination texture array; its dimensions bound the grid.
/// @param ksize       Window size: ksize[0] along x, ksize[1] along y.
/// @param stride      Window step: stride[0] along x, stride[1] along y.
/// @param padding     Offset subtracted from the window origin:
///                    padding[0] along x, padding[1] along y.
/// @param poolType    0 selects the component-wise maximum, 1 selects the sum
///                    divided by the full window area. Any other value writes
///                    zeros.
/// @param gid         Thread position in the dispatch grid.
kernel void pool(texture2d_array<half, access::read> inTexture [[texture(0)]],
                 texture2d_array<half, access::write> outTexture [[texture(1)]],
                 const device int * ksize [[buffer(0)]],
                 const device int * stride [[buffer(1)]],
                 const device int * padding [[buffer(2)]],
                 const device int * poolType [[buffer(3)]],
                 uint3 gid [[thread_position_in_grid]]) {
    // Threads that fall outside the output have nothing to write.
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height() ||
        gid.z >= outTexture.get_array_size()) return;

    // Window extent along x, clipped to the input. The upper bound is computed
    // from the unclipped origin, so it must be derived before xmin is clamped.
    int xmin = int(gid.x) * stride[0] - padding[0];
    int xmax = min(xmin + ksize[0], int(inTexture.get_width()));
    xmin = max(xmin, 0);

    // Window extent along y. The upper bound is clipped against the input
    // HEIGHT; clipping against the width is only correct for square inputs.
    int ymin = int(gid.y) * stride[1] - padding[1];
    int ymax = min(ymin + ksize[1], int(inTexture.get_height()));
    ymin = max(ymin, 0);

    half4 r = 0;
    if (*poolType == 0) {
        // Seed with the first texel of the window so the running maximum does
        // not start from 0 (which would be wrong for all-negative windows).
        r = inTexture.read(uint2(xmin, ymin), gid.z);
        for (int x = xmin; x < xmax; x++) {
            for (int y = ymin; y < ymax; y++) {
                r = fmax(r, inTexture.read(uint2(x, y), gid.z));
            }
        }
    } else if (*poolType == 1) {
        for (int x = xmin; x < xmax; x++) {
            for (int y = ymin; y < ymax; y++) {
                r += inTexture.read(uint2(x, y), gid.z);
            }
        }
        // The divisor is the full kernel area, not the number of texels
        // actually read, so clipped (padded) positions count as zeros.
        r /= ksize[0] * ksize[1];
    }
    outTexture.write(r, gid.xy, gid.z);
}
/// Unpacks the channels of a texture array into rows of a 2D texture.
///
/// Output row gid.y receives, in channel 0, channel (gid.y % 4) of texel (0, 0)
/// from input slice (gid.y / 4). The remaining output channels carry that
/// texel's original values for channels 1..3.
/// NOTE(review): only texel (0, 0) of each slice is ever read, so this looks
/// like it expects a 1x1xN input — confirm against the caller.
///
/// @param inTexture   Source texture array, read-only.
/// @param outTexture  Destination 2D texture; its dimensions bound the grid.
/// @param gid         Thread position in the dispatch grid.
kernel void reshape(texture2d_array<half, access::read> inTexture [[texture(0)]],
                    texture2d<half, access::write> outTexture [[texture(1)]],
                    uint3 gid [[thread_position_in_grid]]) {
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height()) return;

    // Four channels per slice: row index -> (slice, channel).
    int slice = gid.y / 4;
    int channel = gid.y % 4;

    half4 r = inTexture.read(uint2(0, 0), slice);
    r[0] = r[channel];

    // outTexture is a plain texture2d, so a third argument to write() would be
    // a mip level, not an array slice. Always write to the base level instead
    // of forwarding gid.z.
    outTexture.write(r, gid.xy);
}
/// Softmax over channel 0 of the first column of a 2D texture.
///
/// Each thread scans column 0 of the input over its full height to find the
/// maximum and the sum of exp(value - maximum), then writes the normalised
/// value for its own texel. Only channel 0 is transformed; channels 1..3 of
/// the texel are copied through unchanged.
/// NOTE(review): the maximum and sum always come from column 0, so outputs at
/// gid.x > 0 are normalised against column 0 — presumably the texture is one
/// texel wide; confirm against the caller.
///
/// @param inTexture   Source 2D texture, read-only (bound at texture index 1).
/// @param outTexture  Destination 2D texture (bound at texture index 2); its
///                    dimensions bound the grid.
/// @param gid         Thread position in the dispatch grid.
kernel void softmax(texture2d<half, access::read> inTexture [[texture(1)]],
                    texture2d<half, access::write> outTexture [[texture(2)]],
                    uint3 gid [[thread_position_in_grid]]) {
    if (gid.x >= outTexture.get_width() ||
        gid.y >= outTexture.get_height()) return;

    // Both textures are plain texture2d, so a trailing argument to read() or
    // write() would be a mip level. gid.z is a grid coordinate, not a mip
    // level, so every access below uses the base level.
    int ysize = inTexture.get_height();

    // Pass 1: column maximum, subtracted below to keep exp() from overflowing
    // in half precision.
    half maxv = inTexture.read(uint2(0, 0))[0];
    for (int y = 0; y < ysize; y++) {
        half r = inTexture.read(uint2(0, y))[0];
        maxv = max(maxv, r);
    }

    // Pass 2: normalisation denominator.
    half sum = 0;
    for (int y = 0; y < ysize; y++) {
        half r = inTexture.read(uint2(0, y))[0];
        sum += exp(r - maxv);
    }

    half4 rr = inTexture.read(gid.xy);
    rr[0] = exp(rr[0] - maxv) / sum;
    outTexture.write(rr, gid.xy);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册