提交 b8f3d025 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #626 from dolphin8/metal

fix shader
......@@ -146,37 +146,34 @@ kernel void pool(texture2d_array<float, access::read> inTexture [[texture(0)]],
// reshape: compute kernel that reads from a texture array and writes into outTexture.
// NOTE(review): this listing appears to be a rendered diff whose +/- markers were
// lost, so lines from two revisions are interleaved (two `outTexture` parameters,
// two `float4 r` definitions, a dangling guard continuation). It does not compile
// as shown — confirm against the repository which line of each pair is current.
kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]],
// Output binding declared as a plain 2D texture (first of two interleaved declarations).
texture2d<float, access::write> outTexture [[texture(1)]],
// Output binding declared as a 2D texture array (second of two interleaved declarations).
texture2d_array<float, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
// Bounds guard: threads whose x/y position lies outside the output do nothing.
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) return;
// Variant A: gid.x is split into an input slice index (zz) and a component index (cc).
int zz = gid.x / 4;
int cc = gid.x % 4;
// Reads texel (0, 0) of slice zz.
float4 r = inTexture.read(uint2(0, 0), zz);
// Moves component cc into r[0] and zeroes the other three components.
r[0] = r[cc];
r[1] = 0;
r[2] = 0;
r[3] = 0;
// Variant B guard continuation: additionally rejects gid.z beyond the output's array size.
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
// Variant B: reads texel (0, 0) of input slice gid.z with all four components unchanged.
float4 r = inTexture.read(uint2(0, 0), gid.z);
// Stores r at position (gid.x, gid.y) of output slice gid.z.
outTexture.write(r, gid.xy, gid.z);
}
// softmax: compute kernel that writes exp(value - max) / sum(exp(value - max)) for
// the texel at this thread's position.
// NOTE(review): this listing appears to be a rendered diff whose +/- markers were
// lost, so two revisions are interleaved (two signatures, an x-loop variant and a
// z-loop variant of each reduction, two normalization lines). It does not compile
// as shown — confirm against the repository which line of each pair is current.
// Signature variant A: input and output are plain 2D textures.
kernel void softmax(texture2d<float, access::read> inTexture [[texture(0)]],
texture2d<float, access::write> outTexture [[texture(1)]],
// Signature variant B: input and output are 2D texture arrays.
kernel void softmax(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
uint3 gid [[thread_position_in_grid]]) {
// Bounds guard: threads whose x/y position lies outside the output do nothing.
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height()) return;
// Variant A maximum: scans component 0 of texels (x, 0) across the input width in slice gid.z.
int xsize = inTexture.get_width();
float maxv = inTexture.read(uint2(0, 0), gid.z)[0];
for (int x = 0; x < xsize; x++) {
float r = inTexture.read(uint2(x, 0), gid.z)[0];
maxv = max(maxv, r);
// Variant B guard continuation: additionally rejects gid.z beyond the output's array size.
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
// Variant B maximum: scans all four components of texel (0, 0) in every input slice.
int zsize = inTexture.get_array_size();
float maxv = inTexture.read(uint2(0, 0), 0)[0];
for (int z = 0; z < zsize; z++) {
float4 r = inTexture.read(uint2(0, 0), z);
maxv = max(maxv, max(max(r[0], r[1]), max(r[2], r[3])));
}
// Accumulates exp(value - maxv) over the same elements the maximum was taken over;
// subtracting maxv first is presumably for numerical stability.
float sum = 0;
// Variant A sum: component 0 of texels (x, 0) in slice gid.z.
for (int x = 0; x < xsize; x++) {
float r = inTexture.read(uint2(x, 0), gid.z)[0];
sum += exp(r - maxv);
// Variant B sum: all four components of texel (0, 0) in every slice.
for (int z = 0; z < zsize; z++) {
float4 r = inTexture.read(uint2(0, 0), z);
sum += exp(r[0] - maxv) + exp(r[1] - maxv) + exp(r[2] - maxv) + exp(r[3] - maxv);
}
// Reads the texel this thread is responsible for.
float4 rr = inTexture.read(gid.xy, gid.z);
// Variant A: normalizes only component 0, leaving rr[1..3] as read.
rr[0] = exp(rr[0] - maxv) / sum;
// Variant B: normalizes all four components at once.
rr = exp(rr - maxv) / sum;
// Stores the result at position (gid.x, gid.y) of output slice gid.z.
outTexture.write(rr, gid.xy, gid.z);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册