未验证 提交 5a351346 编写于 作者: Y Yanzhan Yang 提交者: GitHub

1. Check slice and nearest op results. 2. Reshape output should have the same layout as its input. (#1640)

上级 f7ebe337
......@@ -67,8 +67,8 @@ struct ConcatParam {
#undef R
#undef V
// lens: (R=4, N=3, V=x)
#define V VX
// lens: (R=4, N=3, V=y)
#define V VY
#define R 4
#define N 3
#define P float
......
......@@ -28,8 +28,8 @@ kernel void nearest_interp(texture2d_array<float, access::sample> inTexture [[te
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float scale = param.scale;
uint x = uint(round(float(gid.x) / scale));
uint y = uint(round(float(gid.y) / scale));
uint x = uint(floor(float(gid.x) / scale));
uint y = uint(floor(float(gid.y) / scale));
const float4 input = inTexture.read(uint2(x, y), gid.z);
outTexture.write(input, gid.xy, gid.z);
}
......@@ -43,8 +43,8 @@ kernel void nearest_interp_half(texture2d_array<half, access::sample> inTexture
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float scale = param.scale;
uint x = uint(round(float(gid.x) / scale));
uint y = uint(round(float(gid.y) / scale));
uint x = uint(floor(float(gid.x) / scale));
uint y = uint(floor(float(gid.y) / scale));
const half4 input = inTexture.read(uint2(x, y), gid.z);
outTexture.write(input, gid.xy, gid.z);
}
......@@ -24,6 +24,8 @@ struct MetalSliceParam {
short end1;
short end2;
short end3;
int iC;
int oC;
};
kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)]],
......@@ -35,9 +37,14 @@ kernel void slice(texture2d_array<float, access::sample> inTexture [[texture(0)]
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output;
for (int i = 0; i < 4; i++) {
int input_c = gid.z * 4 + i + param.start1;
int input_z = input_c / 4;
for (int i = 0; i < 4; ++i) {
int tmp = gid.z * 4 + i;
int output_c = tmp % param.oC;
int output_n = tmp / param.oC;
int c = output_c + param.start1;
tmp = output_n * param.iC + c;
int input_z = tmp / 4;
int input_c = tmp % 4;
const float4 input = inTexture.read(gid.xy, input_z);
output[i] = input[input_c % 4];
}
......@@ -52,12 +59,17 @@ kernel void slice_half(texture2d_array<half, access::sample> inTexture [[texture
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
float4 output;
for (int i = 0; i < 4; i++) {
int input_c = gid.z * 4 + i + param.start1;
int input_z = input_c / 4;
const float4 input = float4(inTexture.read(gid.xy, input_z));
half4 output;
for (int i = 0; i < 4; ++i) {
int tmp = gid.z * 4 + i;
int output_c = tmp % param.oC;
int output_n = tmp / param.oC;
int c = output_c + param.start1;
tmp = output_n * param.iC + c;
int input_z = tmp / 4;
int input_c = tmp % 4;
const half4 input = inTexture.read(gid.xy, input_z);
output[i] = input[input_c % 4];
}
outTexture.write(half4(output), gid.xy, gid.z);
outTexture.write(output, gid.xy, gid.z);
}
......@@ -534,7 +534,7 @@ public extension MTLTexture {
for c in 0..<4{
for h in 0..<dim.h {
for w in 0..<dim.w {
if (s * 4 + c) < dim.c {
if (s * 4 + c) < (dim.c * dim.n) {
let textureValue = textureArray[dim.w * dim.h * 4 * s + h * dim.w * 4 + w * 4 + c]
output.append(textureValue)
}
......
......@@ -34,7 +34,7 @@ class ReshapeKernel<P: PrecisionProtocol>: Kernel, Computable{
required init(device: MTLDevice, param: ReshapeParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, computePrecision: GlobalConfig.shared.computePrecision)
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
......
......@@ -23,6 +23,8 @@ public struct SliceMetalParam {
let end1: Int16
let end2: Int16
let end3: Int16
let iC: Int32
let oC: Int32
}
class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
......@@ -40,7 +42,7 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
required init(device: MTLDevice, param: SliceParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
try param.output.initTexture(device: device, inTranspose: [0, 2, 3, 1], computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
......@@ -60,7 +62,11 @@ class SliceKernel<P: PrecisionProtocol>: Kernel, Computable {
let end1 = ranges[1][1]
let end2 = ranges[2][1]
let end3 = ranges[3][1]
metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3)
let iC = Int32(param.input.tensorDim[1])
let oC = Int32(param.output.tensorDim[1])
metalParam = SliceMetalParam.init(start0: start0, start1: start1, start2: start2, start3: start3, end0: end0, end1: end1, end2: end2, end3: end3, iC: iC, oC: oC)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "slice", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册