提交 0bb2c1cd 编写于 作者: D dolphin8 提交者: GitHub

Merge pull request #975 from dolphin8/metal

x
......@@ -21,6 +21,7 @@
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
4AA1EAA4214A295C00D0F791 /* Split.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA3214A295C00D0F791 /* Split.inc.metal */; };
4AA1EAA6214B5F6800D0F791 /* Shape.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA5214B5F6800D0F791 /* Shape.metal */; };
4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */; };
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; };
4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; };
4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; };
......@@ -134,6 +135,7 @@
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Split.inc.metal; sourceTree = "<group>"; };
4AA1EAA5214B5F6800D0F791 /* Shape.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Shape.metal; sourceTree = "<group>"; };
4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BilinearInterp.inc.metal; sourceTree = "<group>"; };
4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = "<group>"; };
4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = "<group>"; };
4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = "<group>"; };
......@@ -456,6 +458,7 @@
4AA1EA8F214664CD00D0F791 /* Split.metal */,
4AA1EAA3214A295C00D0F791 /* Split.inc.metal */,
4AA1EA892146631C00D0F791 /* BilinearInterp.metal */,
4AA1EAA7214B7AFB00D0F791 /* BilinearInterp.inc.metal */,
4AF9287821341661005B6C3A /* Softmax.metal */,
FCEB6849212F00DB00D2448E /* PreluKernel.metal */,
FCDDC6C9212FDF6800E5EF74 /* BatchNormKernel.metal */,
......@@ -609,6 +612,7 @@
4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */,
FC33B0F02147659000714A93 /* MobileNet.swift in Sources */,
FCEB684C212F093800D2448E /* PreluOp.swift in Sources */,
4AA1EAA8214B7AFB00D0F791 /* BilinearInterp.inc.metal in Sources */,
FCA67CD92138287B00BD58AA /* ConvBNReluKernel.metal in Sources */,
FC60DB8920E9AAA500FF203F /* MetalExtension.swift in Sources */,
FCEBC0F620F1FE120099DBAF /* ConvAddBatchNormReluKernel.swift in Sources */,
......@@ -833,7 +837,7 @@
"$(PROJECT_DIR)/paddle-mobile/CPU",
);
MACH_O_TYPE = mh_dylib;
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
MTL_LANGUAGE_REVISION = Metal12;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
......@@ -869,7 +873,7 @@
"$(PROJECT_DIR)/paddle-mobile/CPU",
);
MACH_O_TYPE = mh_dylib;
MTL_LANGUAGE_REVISION = UseDeploymentTarget;
MTL_LANGUAGE_REVISION = Metal12;
PRODUCT_BUNDLE_IDENTIFIER = "orange.paddle-mobile";
PRODUCT_NAME = "$(TARGET_NAME:c99extidentifier)";
SKIP_INSTALL = YES;
......
......@@ -38,7 +38,7 @@ class BilinearInterpKernel<P: PrecisionType>: Kernel, Computable{
required init(device: MTLDevice, param: BilinearInterpParam<P>) {
param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: computePrecision)
if computePrecision == .Float32 {
super.init(device: device, inFunctionName: "bilinear_interp")
super.init(device: device, inFunctionName: "bilinear_interp_float")
} else if computePrecision == .Float16 {
super.init(device: device, inFunctionName: "bilinear_interp_half")
} else {
......
......@@ -32,7 +32,7 @@ class SplitKernel<P: PrecisionType>: Kernel, Computable{
for i in 0..<param.outputList.count {
encoder.setTexture(param.outputList[i].metalTexture, index: i + 1)
}
encoder.setBytes(&smp, length: MemoryLayout<BoxcoderMetalParam>.size, index: 0)
encoder.setBytes(&smp, length: MemoryLayout<SplitMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.input.metalTexture)
encoder.endEncoding()
}
......
#ifdef P
#define CONCAT2(a, b) a ## b
#define CONCAT2_(a, b) a ## _ ## b
#define FUNC(f, p) CONCAT2_(f, p)
#define VECTOR(p, n) CONCAT2(p, n)
kernel void FUNC(bilinear_interp, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z);
} else {
float w = gid.x * pm.ratio_w;
float h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
P w1lambda = w - w0, h1lambda = h - h0;
P w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
VECTOR(P, 4) r0 = input.read(uint2(w0, h0), gid.z);
VECTOR(P, 4) r1 = input.read(uint2(w1, h0), gid.z);
VECTOR(P, 4) r2 = input.read(uint2(w0, h1), gid.z);
VECTOR(P, 4) r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
}
output.write(r, gid.xy, gid.z);
}
#endif
......@@ -22,54 +22,10 @@ struct bilinear_interp_param {
float ratio_w;
};
kernel void bilinear_interp(texture2d_array<float, access::read> input [[texture(0)]],
texture2d_array<float, access::write> output [[texture(1)]],
constant bilinear_interp_param & pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
float4 r;
if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
r = input.read(gid.xy, gid.z);
} else {
float w = gid.x * pm.ratio_w;
float h = gid.y * pm.ratio_h;
uint w0 = w, h0 = h;
uint w1 = w0 + 1, h1 = h0 + 1;
float w1lambda = w - w0, h1lambda = h - h0;
float w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
if (w1 >= input.get_width()) w1 = w0;
if (h1 >= input.get_height()) h1 = h0;
float4 r0 = input.read(uint2(w0, h0), gid.z);
float4 r1 = input.read(uint2(w1, h0), gid.z);
float4 r2 = input.read(uint2(w0, h1), gid.z);
float4 r3 = input.read(uint2(w1, h1), gid.z);
r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
}
output.write(r, gid.xy, gid.z);
}
#define P float
#include "BilinearInterp.inc.metal"
#undef P
//kernel void bilinear_interp_half(texture2d_array<half, access::read> input [[texture(0)]],
// texture2d_array<half, access::write> output [[texture(1)]],
// constant bilinear_interp_param & pm [[buffer(0)]],
// uint3 gid [[thread_position_in_grid]]) {
//
// half4 r;
// if ((input.get_width() == output.get_width()) && (input.get_height() == output.get_height())) {
// r = input.read(gid.xy, gid.z);
// } else {
// half w = gid.x * pm.ratio_w;
// half h = gid.y * pm.ratio_h;
// uint w0 = w, h0 = h;
// uint w1 = w0 + 1, h1 = h0 + 1;
// half w1lambda = w - w0, h1lambda = h - h0;
// half w2lambda = 1.0 - w1lambda, h2lambda = 1.0 - h1lambda;
// if (w1 >= input.get_width()) w1 = w0;
// if (h1 >= input.get_height()) h1 = h0;
// half4 r0 = input.read(uint2(w0, h0), gid.z);
// half4 r1 = input.read(uint2(w1, h0), gid.z);
// half4 r2 = input.read(uint2(w0, h1), gid.z);
// half4 r3 = input.read(uint2(w1, h1), gid.z);
// r = h2lambda * (w2lambda * r0 + w1lambda * r1) + h1lambda * (w2lambda * r2 + w1lambda * r3);
// }
// output.write(r, gid.xy, gid.z);
// output.write(r, gid.xy, gid.z);
//}
#define P half
#include "BilinearInterp.inc.metal"
#undef P
......@@ -20,70 +20,89 @@
#define VV normal
#endif
#if V == VY
kernel void FUNC(split, R, N, VV, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> out1 [[texture(1)]],
texture2d_array<P, access::write> out2 [[texture(2)]],
#if N >= 3
texture2d_array<P, access::write> out3 [[texture(3)]],
#endif
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::write> out4 [[texture(4)]],
#endif
#endif // N >= 4
constant SplitParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r = input.read(gid.xy, gid.z);
#if V == VY
int y = gid.y - sp.offset;
if (y < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
} else {
y -= sp.vdim[0];
if (y < sp.vdim[1]) {
out2.write(r, uint2(gid.x, y), gid.z);
} else {
return;
}
y -= sp.vdim[0];
if (y < sp.vdim[1]) {
out2.write(r, uint2(gid.x, y), gid.z);
return;
}
#if N >= 3
y -= sp.vdim[1];
if (y < sp.vdim[2]) {
out3.write(r, uint2(gid.x, y), gid.z);
} else {
y -= sp.vdim[1];
if (y < sp.vdim[2]) {
out3.write(r, uint2(gid.x, y), gid.z);
return;
}
#endif // N >= 3
#if N >= 4
y -= sp.vdim[2];
if (y < sp.vdim[3]) {
out4.write(r, uint2(gid.x, y), gid.z);
}
#endif
}
#endif
}
y -= sp.vdim[2];
if (y < sp.vdim[3]) {
out4.write(r, uint2(gid.x, y), gid.z);
return;
}
#elif V == VX
#endif // N >= 4
}
#endif // V == VY
#if V == VX
kernel void FUNC(split, R, N, VV, P)(texture2d_array<P, access::read> input [[texture(0)]],
texture2d_array<P, access::write> out1 [[texture(1)]],
texture2d_array<P, access::write> out2 [[texture(2)]],
#if N >= 3
texture2d_array<P, access::write> out3 [[texture(3)]],
#endif // N >= 3
#if N >= 4
texture2d_array<P, access::write> out4 [[texture(4)]],
#endif // N >= 4
constant SplitParam &sp [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
VECTOR(P, 4) r = input.read(gid.xy, gid.z);
int x = gid.x;
if (x < sp.vdim[0]) {
out1.write(r, gid.xy, gid.z);
} else {
x -= sp.vdim[0];
if (x < sp.vdim[1]) {
out2.write(r, uint2(x, gid.y), gid.z);
} else {
return;
}
x -= sp.vdim[0];
if (x < sp.vdim[1]) {
out2.write(r, uint2(x, gid.y), gid.z);
return;
}
#if N >= 3
x -= sp.vdim[1];
if (x < sp.vdim[2]) {
out3.write(r, uint2(x, gid.y), gid.z);
} else {
x -= sp.vdim[1];
if (x < sp.vdim[2]) {
out3.write(r, uint2(x, gid.y), gid.z);
return;
}
#endif // N >= 3
#if N >= 4
x -= sp.vdim[2];
if (x < sp.vdim[3]) {
out4.write(r, uint2(x, gid.y), gid.z);
}
#endif
}
#endif
}
x -= sp.vdim[2];
if (x < sp.vdim[3]) {
out4.write(r, uint2(x, gid.y), gid.z);
return;
}
#else
#endif
#endif // N >= 4
}
#endif // V == VX
#undef VV
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册