提交 b8f3d025 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #626 from dolphin8/metal

fix shader
@@ -146,37 +146,34 @@ kernel void pool(texture2d_array<float, access::read> inTexture [[texture(0)]],
  kernel void reshape(texture2d_array<float, access::read> inTexture [[texture(0)]],
- texture2d<float, access::write> outTexture [[texture(1)]],
+ texture2d_array<float, access::write> outTexture [[texture(1)]],
  uint3 gid [[thread_position_in_grid]]) {
  if (gid.x >= outTexture.get_width() ||
- gid.y >= outTexture.get_height()) return;
- int zz = gid.x / 4;
- int cc = gid.x % 4;
- float4 r = inTexture.read(uint2(0, 0), zz);
- r[0] = r[cc];
- r[1] = 0;
- r[2] = 0;
- r[3] = 0;
+ gid.y >= outTexture.get_height() ||
+ gid.z >= outTexture.get_array_size()) return;
+ float4 r = inTexture.read(uint2(0, 0), gid.z);
  outTexture.write(r, gid.xy, gid.z);
  }
- kernel void softmax(texture2d<float, access::read> inTexture [[texture(0)]],
- texture2d<float, access::write> outTexture [[texture(1)]],
+ kernel void softmax(texture2d_array<float, access::read> inTexture [[texture(0)]],
+ texture2d_array<float, access::write> outTexture [[texture(1)]],
  uint3 gid [[thread_position_in_grid]]) {
  if (gid.x >= outTexture.get_width() ||
- gid.y >= outTexture.get_height()) return;
- int xsize = inTexture.get_width();
- float maxv = inTexture.read(uint2(0, 0), gid.z)[0];
- for (int x = 0; x < xsize; x++) {
- float r = inTexture.read(uint2(x, 0), gid.z)[0];
- maxv = max(maxv, r);
+ gid.y >= outTexture.get_height() ||
+ gid.z >= outTexture.get_array_size()) return;
+ int zsize = inTexture.get_array_size();
+ float maxv = inTexture.read(uint2(0, 0), 0)[0];
+ for (int z = 0; z < zsize; z++) {
+ float4 r = inTexture.read(uint2(0, 0), z);
+ maxv = max(maxv, max(max(r[0], r[1]), max(r[2], r[3])));
  }
  float sum = 0;
- for (int x = 0; x < xsize; x++) {
- float r = inTexture.read(uint2(x, 0), gid.z)[0];
- sum += exp(r - maxv);
+ for (int z = 0; z < zsize; z++) {
+ float4 r = inTexture.read(uint2(0, 0), z);
+ sum += exp(r[0] - maxv) + exp(r[1] - maxv) + exp(r[2] - maxv) + exp(r[3] - maxv);
  }
  float4 rr = inTexture.read(gid.xy, gid.z);
- rr[0] = exp(rr[0] - maxv) / sum;
+ rr = exp(rr - maxv) / sum;
  outTexture.write(rr, gid.xy, gid.z);
  }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册