未验证 提交 5a351346 编写于 作者: Y Yanzhan Yang 提交者: GitHub

1.check slice and nearest op result. 2.reshape output should have same layout as its input. (#1640)

上级 f7ebe337
...@@ -67,8 +67,8 @@ struct ConcatParam { ...@@ -67,8 +67,8 @@ struct ConcatParam {
#undef R #undef R
#undef V #undef V
// lens: (R=4, N=3, V=x) // lens: (R=4, N=3, V=y)
#define V VX #define V VY
#define R 4 #define R 4
#define N 3 #define N 3
#define P float #define P float
......
...@@ -28,8 +28,8 @@ kernel void nearest_interp(texture2d_array<float, access::sample> inTexture [[te ...@@ -28,8 +28,8 @@ kernel void nearest_interp(texture2d_array<float, access::sample> inTexture [[te
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float scale = param.scale; float scale = param.scale;
uint x = uint(round(float(gid.x) / scale)); uint x = uint(floor(float(gid.x) / scale));
uint y = uint(round(float(gid.y) / scale)); uint y = uint(floor(float(gid.y) / scale));
const float4 input = inTexture.read(uint2(x, y), gid.z); const float4 input = inTexture.read(uint2(x, y), gid.z);
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
...@@ -43,8 +43,8 @@ kernel void nearest_interp_half(texture2d_array<half, access::sample> inTexture ...@@ -43,8 +43,8 @@ kernel void nearest_interp_half(texture2d_array<half, access::sample> inTexture
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float scale = param.scale; float scale = param.scale;
uint x = uint(round(float(gid.x) / scale)); uint x = uint(floor(float(gid.x) / scale));
uint y = uint(round(float(gid.y) / scale)); uint y = uint(floor(float(gid.y) / scale));
const half4 input = inTexture.read(uint2(x, y), gid.z); const half4 input = inTexture.read(uint2(x, y), gid.z);
outTexture.write(input, gid.xy, gid.z); outTexture.write(input, gid.xy, gid.z);
} }
...@@ -24,6 +24,8 @@ struct MetalSliceParam { ...@@ -24,6 +24,8 @@ struct MetalSliceParam {
short end1; short end1;
short end2; short end2;
short end3; short end3;
int iC;
int oC;
}; };
kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)]], kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)]],
...@@ -35,9 +37,14 @@ kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)] ...@@ -35,9 +37,14 @@ kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)]
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output; float4 output;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; ++i) {
int input_c = gid.z * 4 + i + param.start1; int tmp = gid.z * 4 + i;
int input_z = input_c / 4; int output_c = tmp % param.oC;
int output_n = tmp / param.oC;
int c = output_c + param.start1;
tmp = output_n * param.iC + c;
int input_z = tmp / 4;
int input_c = tmp % 4;
const float4 input = inTexture.read(gid.xy, input_z); const float4 input = inTexture.read(gid.xy, input_z);
output[i] = input[input_c % 4]; output[i] = input[input_c % 4];
} }
...@@ -52,12 +59,17 @@ kernel void slice_half(texture2d_array<half, access::sample> inTexture [[texture ...@@ -52,12 +59,17 @@ kernel void slice_half(texture2d_array<half, access::sample> inTexture [[texture
gid.y >= outTexture.get_height() || gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return; gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero); constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output; half4 output;
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; ++i) {
int input_c = gid.z * 4 + i + param.start1; int tmp = gid.z * 4 + i;
int input_z = input_c / 4; int output_c = tmp % param.oC;
const float4 input = float4(inTexture.read(gid.xy, input_z)); int output_n = tmp / param.oC;
int c = output_c + param.start1;
tmp = output_n * param.iC + c;
int input_z = tmp / 4;
int input_c = tmp % 4;
const half4 input = inTexture.read(gid.xy, input_z);
output[i] = input[input_c % 4]; output[i] = input[input_c % 4];
} }
outTexture.write(half4(output), gid.xy, gid.z); outTexture.write(output, gid.xy, gid.z);
} }
...@@ -534,7 +534,7 @@ public extension MTLTexture { ...@@ -534,7 +534,7 @@ public extension MTLTexture {
for c in 0..<4{ for c in 0..<4{
for h in 0..<dim.h { for h in 0..<dim.h {
for w in 0..<dim.w { for w in 0..<dim.w {
if (s * 4 + c) < dim.c { if (s * 4 + c) < (dim.c * dim.n) {
let textureValue = textureArray[dim.w * dim.h * 4 * s + h * dim.w * 4 + w * 4 + c] let textureValue = textureArray[dim.w * dim.h * 4 * s + h * dim.w * 4 + w * 4 + c]
output.append(textureValue) output.append(textureValue)
} }
......
...@@ -34,7 +34,7 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{ ...@@ -34,7 +34,7 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) throws { required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) throws {
do { do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision) try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error { } catch let error {
throw error throw error
} }
......
...@@ -23,6 +23,8 @@ public struct SliceMetalParam { ...@@ -23,6 +23,8 @@ public struct SliceMetalParam {
let end1: Int16 let end1: Int16
let end2: Int16 let end2: Int16
let end3: Int16 let end3: Int16
let iC: Int32
let oC: Int32
} }
class SliceKernel<P: PrecisionProtocol>: Kernel, Computable { class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
...@@ -40,7 +42,7 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -40,7 +42,7 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
required init(device: MTLDevice, param: SliceParam<P>, initContext: InitContext) throws { required init(device: MTLDevice, param: SliceParam<P>, initContext: InitContext) throws {
do { do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision) try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error { } catch let error {
throw error throw error
} }
...@@ -60,7 +62,11 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable { ...@@ -60,7 +62,11 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
let end1 = ranges[1][1] let end1 = ranges[1][1]
let end2 = ranges[2][1] let end2 = ranges[2][1]
let end3 = ranges[3][1] let end3 = ranges[3][1]
metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3)
let iC = Int32(param.input.tensorDim[1])
let oC = Int32(param.output.tensorDim[1])
metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3, iC: iC, oC: oC)
if GlobalConfig.shared.computePrecision == .Float32 { if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "slice", initContext: initContext) super.init(device: device, inFunctionName: "slice", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 { } else if GlobalConfig.shared.computePrecision == .Float16 {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册