提交 4259b1ce 作者: Ruilong Liu 提交者: GitHub

Merge pull request #864 from codeWorm2015/metal

Fix PReLU result error: write each output component (x/y/z/w) instead of only x, and scale non-positive inputs by alpha
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import Foundation import Foundation
let testTo = 10 let testTo = 7
public class ResultHolder<P: PrecisionType> { public class ResultHolder<P: PrecisionType> {
public let dim: [Int] public let dim: [Int]
......
...@@ -28,16 +28,16 @@ kernel void prelu_channel(texture2d_array<float, access::sample> inTexture [[tex ...@@ -28,16 +28,16 @@ kernel void prelu_channel(texture2d_array<float, access::sample> inTexture [[tex
} }
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, gid.x, gid.y, gid.z); float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
float4 alpha_value = alpha[gid.z];
float4 output; float4 output;
output.x = input.x > 0 ? input.x : alpha[gid.z].x; output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.x = input.y > 0 ? input.y : alpha[gid.z].y; output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.x = input.z > 0 ? input.z : alpha[gid.z].z; output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.x = input.w > 0 ? input.w : alpha[gid.z].w; output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[texture(0)]], kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]], texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float4 *alpha [[buffer(0)]], const device float4 *alpha [[buffer(0)]],
...@@ -49,19 +49,19 @@ kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[tex ...@@ -49,19 +49,19 @@ kernel void prelu_element(texture2d_array<float, access::sample> inTexture [[tex
} }
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, gid.x, gid.y, gid.z); float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
int alpha_to = (gid.y * inTexture.get_width() + gid.x) * inTexture.get_array_size(); int alpha_to = (gid.y * inTexture.get_width() + gid.x) * inTexture.get_array_size();
float4 alpha_value = alpha[alpha_to + gid.z];
float4 output; float4 output;
output.x = input.x > 0 ? input.x : alpha[alpha_to + gid.z].x; output.x = input.x > 0 ? input.x : (alpha_value.x * input.x);
output.x = input.y > 0 ? input.y : alpha[alpha_to + gid.z].y; output.y = input.y > 0 ? input.y : (alpha_value.y * input.y);
output.x = input.z > 0 ? input.z : alpha[alpha_to + gid.z].z; output.z = input.z > 0 ? input.z : (alpha_value.z * input.z);
output.x = input.w > 0 ? input.w : alpha[alpha_to + gid.z].w; output.w = input.w > 0 ? input.w : (alpha_value.w * input.w);
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[texture(0)]], kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]], texture2d_array<float, access::write> outTexture [[texture(1)]],
const device float *alpha [[buffer(0)]], const device float *alpha [[buffer(0)]],
...@@ -73,12 +73,12 @@ kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[textu ...@@ -73,12 +73,12 @@ kernel void prelu_other(texture2d_array<float, access::sample> inTexture [[textu
} }
constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler sample(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 input = inTexture.sample(sample, gid.x, gid.y, gid.z); float4 input = inTexture.sample(sample, float2(gid.x, gid.y), gid.z);
float alpha_value = alpha[0];
float4 output; float4 output;
output.x = input.x > 0 ? input.x : alpha[0]; output.x = input.x > 0 ? input.x : (alpha_value * input.x);
output.x = input.y > 0 ? input.y : alpha[0]; output.y = input.y > 0 ? input.y : (alpha_value * input.y);
output.x = input.z > 0 ? input.z : alpha[0]; output.z = input.z > 0 ? input.z : (alpha_value * input.z);
output.x = input.w > 0 ? input.w : alpha[0]; output.w = input.w > 0 ? input.w : (alpha_value * input.w);
outTexture.write(output, gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
...@@ -50,8 +50,17 @@ class PreluOp<P: PrecisionType>: Operator<PreluKernel<P>, PreluParam<P>>, Runabl ...@@ -50,8 +50,17 @@ class PreluOp<P: PrecisionType>: Operator<PreluKernel<P>, PreluParam<P>>, Runabl
} }
func delogOutput() { func delogOutput() {
print(" \(type) input: ")
print(para.input.metalTexture.toTensor(dim: (n: para.input.originDim[0], c: para.input.originDim[1], h: para.input.originDim[2], w: para.input.originDim[3])).strideArray())
print(" \(type) Alpha: ")
let _: Float32? = para.alpha.buffer.logDesc(header: " alpha: ", stridable: false)
print(" \(type) output: ")
print(para.output.metalTexture.toTensor(dim: (n: para.output.originDim[0], c: para.output.originDim[1], h: para.output.originDim[2], w: para.output.originDim[3])).strideArray())
}
// print("softmax delog") // print("softmax delog")
// let _: P? = para.input.metalTexture.logDesc(header: "softmax input: ", stridable: false) // let _: P? = para.input.metalTexture.logDesc(header: "softmax input: ", stridable: false)
// let _: P? = para.output.metalTexture.logDesc(header: "softmax output: ", stridable: false) // let _: P? = para.output.metalTexture.logDesc(header: "softmax output: ", stridable: false)
}
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册