提交 c53025cc 编写于 作者: L liuruilong

fix prelu result error

上级 73a01087
......@@ -14,7 +14,7 @@
import Foundation
let testTo = 10
let testTo = 7
public class ResultHolder<P: PrecisionType> {
public let dim: [Int]
......
......@@ -28,16 +28,16 @@ kernel void prelu_channel(texture2d_array<float, access::sample> inTexture [[tex
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, gid.x, gid.y, gid.z);
float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
float4 alpha_value = alpha[gid.z];
float4 output;
output.x = input.x > 0 ? input.x : alpha[gid.z].x;
output.x = input.y > 0 ? input.y : alpha[gid.z].y;
output.x = input.z > 0 ? input.z : alpha[gid.z].z;
output.x = input.w > 0 ? input.w : alpha[gid.z].w;
output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 *alpha [[buffer(0)]],
......@@ -49,19 +49,19 @@ kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[tex
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, gid.x, gid.y, gid.z);
float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
int alpha_to = (gid.y * inTexture.get_width() + gid.x) * inTexture.get_array_size();
float4 alpha_value = alpha[alpha_to + gid.z];
float4 output;
output.x = input.x > 0 ? input.x : alpha[alpha_to + gid.z].x;
output.x = input.y > 0 ? input.y : alpha[alpha_to + gid.z].y;
output.x = input.z > 0 ? input.z : alpha[alpha_to + gid.z].z;
output.x = input.w > 0 ? input.w : alpha[alpha_to + gid.z].w;
output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z);
}
kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float *alpha [[buffer(0)]],
......@@ -73,12 +73,12 @@ kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[textu
}
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, gid.x, gid.y, gid.z);
float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
float alpha_value = alpha[0];
float4 output;
output.x = input.x > 0 ? input.x : alpha[0];
output.x = input.y > 0 ? input.y : alpha[0];
output.x = input.z > 0 ? input.z : alpha[0];
output.x = input.w > 0 ? input.w : alpha[0];
output.x = input.x > 0 ? input.x : (alpha_value * input.x);
output.y = input.y > 0 ? input.y : (alpha_value * input.y);
output.z = input.z > 0 ? input.z : (alpha_value * input.z);
output.w = input.w > 0 ? input.w : (alpha_value * input.w);
outTexture.write(output, gid.xy, gid.z);
}
......@@ -50,8 +50,17 @@ class PreluOp<P: PrecisionType>: Operator<PreluKernel<P>, PreluParam<P>>, Runabl
}
func delogOutput() {
print(" \(type) input: ")
print(para.input.metalTexture.toTensor(dim: (n: para.input.originDim[0], c: para.input.originDim[1], h: para.input.originDim[2], w: para.input.originDim[3])).strideArray())
print(" \(type) Alpha: ")
let _: Float32? = para.alpha.buffer.logDesc(header: " alpha: ", stridable: false)
print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray())
}
// print("softmax delog")
// let _: P? = para.input.metalTexture.logDesc(header: "softmax input: ", stridable: false)
// let _: P? = para.output.metalTexture.logDesc(header: "softmax output: ", stridable: false)
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册