提交 857ffcbc 编写于 作者: N NazgulLee 提交者: Yanzhan Yang

add reshape2, transpose2, relu6 and scale operator (#1628)

上级 7deff6de
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
/* Begin PBXBuildFile section */ /* Begin PBXBuildFile section */
165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; }; 165F38D72276F4C00088E29F /* ConvAddReluMetal.metal in Sources */ = {isa = PBXBuildFile; fileRef = 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */; };
5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; }; 5CCC0CF6759710BAFE999DB7 /* Pods_paddle_mobile_metallib.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */; };
A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */; };
FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */; }; FCC15DE5221E69E100DC3CB2 /* ReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */; };
FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBD221E69DD00DC3CB2 /* BoxCoder.metal */; }; FCC15DE6221E69E100DC3CB2 /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBD221E69DD00DC3CB2 /* BoxCoder.metal */; };
FCC15DE7221E69E100DC3CB2 /* ConvAddBNReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBE221E69DD00DC3CB2 /* ConvAddBNReluKernel.metal */; }; FCC15DE7221E69E100DC3CB2 /* ConvAddBNReluKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = FCC15DBE221E69DD00DC3CB2 /* ConvAddBNReluKernel.metal */; };
...@@ -56,6 +57,7 @@ ...@@ -56,6 +57,7 @@
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = "<group>"; }; 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.metal; path = ConvAddReluMetal.metal; sourceTree = "<group>"; };
33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = "<group>"; }; 33511F4FF7FE78679BE12DC0 /* Pods-paddle-mobile-metallib.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.release.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.release.xcconfig"; sourceTree = "<group>"; };
5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 5D9D330A035906298947080B /* Pods_paddle_mobile_metallib.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paddle_mobile_metallib.framework; sourceTree = BUILT_PRODUCTS_DIR; };
A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ScaleKernel.metal; sourceTree = "<group>"; };
C6D31B9F9533810DBCA6B28D /* Pods-paddle-mobile-metallib.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.debug.xcconfig"; sourceTree = "<group>"; }; C6D31B9F9533810DBCA6B28D /* Pods-paddle-mobile-metallib.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paddle-mobile-metallib.debug.xcconfig"; path = "../Pods/Target Support Files/Pods-paddle-mobile-metallib/Pods-paddle-mobile-metallib.debug.xcconfig"; sourceTree = "<group>"; };
FCC15D60221E66DE00DC3CB2 /* paddle-mobile-metallib.metallib */ = {isa = PBXFileReference; explicitFileType = "archive.metal-library"; includeInIndex = 0; path = "paddle-mobile-metallib.metallib"; sourceTree = BUILT_PRODUCTS_DIR; }; FCC15D60221E66DE00DC3CB2 /* paddle-mobile-metallib.metallib */ = {isa = PBXFileReference; explicitFileType = "archive.metal-library"; includeInIndex = 0; path = "paddle-mobile-metallib.metallib"; sourceTree = BUILT_PRODUCTS_DIR; };
FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ReluKernel.metal; sourceTree = "<group>"; }; FCC15DBC221E69DD00DC3CB2 /* ReluKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ReluKernel.metal; sourceTree = "<group>"; };
...@@ -192,6 +194,7 @@ ...@@ -192,6 +194,7 @@
FCC15DBF221E69DD00DC3CB2 /* Split.metal */, FCC15DBF221E69DD00DC3CB2 /* Split.metal */,
FCC15DC9221E69DE00DC3CB2 /* TransposeKernel.inc.metal */, FCC15DC9221E69DE00DC3CB2 /* TransposeKernel.inc.metal */,
FCC15DDA221E69E000DC3CB2 /* TransposeKernel.metal */, FCC15DDA221E69E000DC3CB2 /* TransposeKernel.metal */,
A74CAFEF228D9B9B000BBFCA /* ScaleKernel.metal */,
165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */, 165F38D62276F4C00088E29F /* ConvAddReluMetal.metal */,
); );
path = "paddle-mobile-metallib"; path = "paddle-mobile-metallib";
...@@ -284,6 +287,7 @@ ...@@ -284,6 +287,7 @@
FCC15DE8221E69E100DC3CB2 /* Split.metal in Sources */, FCC15DE8221E69E100DC3CB2 /* Split.metal in Sources */,
FCC15DF2221E69E100DC3CB2 /* TransposeKernel.inc.metal in Sources */, FCC15DF2221E69E100DC3CB2 /* TransposeKernel.inc.metal in Sources */,
FCC15DE7221E69E100DC3CB2 /* ConvAddBNReluKernel.metal in Sources */, FCC15DE7221E69E100DC3CB2 /* ConvAddBNReluKernel.metal in Sources */,
A74CAFF0228D9B9B000BBFCA /* ScaleKernel.metal in Sources */,
FCC15E04221E69E100DC3CB2 /* ConvAddPrelu.inc.metal in Sources */, FCC15E04221E69E100DC3CB2 /* ConvAddPrelu.inc.metal in Sources */,
FCC15DF9221E69E100DC3CB2 /* Kernels.metal in Sources */, FCC15DF9221E69E100DC3CB2 /* Kernels.metal in Sources */,
FCC15DF0221E69E100DC3CB2 /* PreluKernel.metal in Sources */, FCC15DF0221E69E100DC3CB2 /* PreluKernel.metal in Sources */,
......
...@@ -15,6 +15,9 @@ ...@@ -15,6 +15,9 @@
#include <metal_stdlib> #include <metal_stdlib>
using namespace metal; using namespace metal;
struct Relu6Param {
float threshold;
};
kernel void relu_half(texture2d_array<half, access::sample> inTexture [[texture(0)]], kernel void relu_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]], texture2d_array<half, access::write> outTexture [[texture(1)]],
...@@ -39,3 +42,31 @@ kernel void relu(texture2d_array<float, access::sample> inTexture [[texture(0)]] ...@@ -39,3 +42,31 @@ kernel void relu(texture2d_array<float, access::sample> inTexture [[texture(0)]]
const float4 relu = fmax((float4)input, 0.0); const float4 relu = fmax((float4)input, 0.0);
outTexture.write(float4(relu), gid.xy, gid.z); outTexture.write(float4(relu), gid.xy, gid.z);
} }
kernel void relu6_half(texture2d_array<half, access::sample> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant Relu6Param &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float threshold = pm.threshold;
const float4 relu = fmin(fmax((float4)input, 0.0), threshold);
outTexture.write(half4(relu), gid.xy, gid.z);
}
kernel void relu6(texture2d_array<float, access::sample> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant Relu6Param &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float threshold = pm.threshold;
const float4 relu = fmin(fmax((float4)input, 0.0), threshold);
outTexture.write(float4(relu), gid.xy, gid.z);
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <metal_stdlib>
#include "Common.metal"
using namespace metal;
struct ScaleParam {
float scale;
float abias;
};
kernel void scale_before_bias_float(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant ScaleParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float scale = pm.scale;
const float abias = pm.abias;
const float4 output = scale * input + abias;
outTexture.write(output, gid.xy, gid.z);
}
kernel void scale_after_bias_float(texture2d_array<float, access::read> inTexture [[texture(0)]],
texture2d_array<float, access::write> outTexture [[texture(1)]],
constant ScaleParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const float4 input = inTexture.read(gid.xy, gid.z);
const float scale = pm.scale;
const float abias = pm.abias;
const float4 output = scale * (input + abias);
outTexture.write(output, gid.xy, gid.z);
}
kernel void scale_before_bias_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant ScaleParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float scale = pm.scale;
const float abias = pm.abias;
const float4 output = scale * (float4)input + abias;
outTexture.write(half4(output), gid.xy, gid.z);
}
kernel void scale_after_bias_half(texture2d_array<half, access::read> inTexture [[texture(0)]],
texture2d_array<half, access::write> outTexture [[texture(1)]],
constant ScaleParam &pm [[buffer(0)]],
uint3 gid [[thread_position_in_grid]]) {
if (gid.x >= outTexture.get_width() ||
gid.y >= outTexture.get_height() ||
gid.z >= outTexture.get_array_size()) return;
constexpr sampler s(coord::pixel, filter::nearest, address::clamp_to_zero);
const half4 input = inTexture.read(gid.xy, gid.z);
const float scale = pm.scale;
const float abias = pm.abias;
const float4 output = scale * ((float4)input + abias);
outTexture.write(half4(output), gid.xy, gid.z);
}
...@@ -31,7 +31,7 @@ kernel void FUNC(transpose, R, P)(texture2d_array<P, access::read> inTexture [[t ...@@ -31,7 +31,7 @@ kernel void FUNC(transpose, R, P)(texture2d_array<P, access::read> inTexture [[t
for (int n = 0; n < 4; n++) { for (int n = 0; n < 4; n++) {
oxyzn[3] = n; oxyzn[3] = n;
#if R == 4 #if R == 4
xyzn2abcd_4(pm.oC, oxyzn, iabcd); xyzn2abcd_4(pm.oC, oxyzn, oabcd);
#endif // R == 4 #endif // R == 4
#if R == 3 #if R == 3
xyzn2abcd_3(oxyzn, oabcd); xyzn2abcd_3(oxyzn, oabcd);
......
...@@ -19,6 +19,10 @@ ...@@ -19,6 +19,10 @@
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; }; 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA932146661500D0F791 /* ShapeKernel.swift */; };
4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; };
4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; }; 4AA1EAA2214912CD00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */; };
A73DC749227F1C7A001EB663 /* ScaleOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = A73DC748227F1C7A001EB663 /* ScaleOp.swift */; };
A73DC74B227F1EDE001EB663 /* ScaleOpKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */; };
A7F26FDA22842EF200365D47 /* Relu6Op.swift in Sources */ = {isa = PBXBuildFile; fileRef = A7F26FD922842EF200365D47 /* Relu6Op.swift */; };
A7F26FDC2284301500365D47 /* Relu6Kernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = A7F26FDB2284301500365D47 /* Relu6Kernel.swift */; };
C28FE02F21BA68C00054EFAC /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02C21BA68C00054EFAC /* Metal.framework */; settings = {ATTRIBUTES = (Weak, ); }; }; C28FE02F21BA68C00054EFAC /* Metal.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02C21BA68C00054EFAC /* Metal.framework */; settings = {ATTRIBUTES = (Weak, ); }; };
C28FE03021BA68C00054EFAC /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */; settings = {ATTRIBUTES = (Weak, ); }; }; C28FE03021BA68C00054EFAC /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */; settings = {ATTRIBUTES = (Weak, ); }; };
C28FE03121BA68C00054EFAC /* MetalKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02E21BA68C00054EFAC /* MetalKit.framework */; settings = {ATTRIBUTES = (Weak, ); }; }; C28FE03121BA68C00054EFAC /* MetalKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = C28FE02E21BA68C00054EFAC /* MetalKit.framework */; settings = {ATTRIBUTES = (Weak, ); }; };
...@@ -115,6 +119,10 @@ ...@@ -115,6 +119,10 @@
4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; }; 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = "<group>"; };
4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = "<group>"; };
4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; }; 4AA1EAA1214912CC00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = "<group>"; };
A73DC748227F1C7A001EB663 /* ScaleOp.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ScaleOp.swift; sourceTree = "<group>"; };
A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ScaleOpKernel.swift; sourceTree = "<group>"; };
A7F26FD922842EF200365D47 /* Relu6Op.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Relu6Op.swift; sourceTree = "<group>"; };
A7F26FDB2284301500365D47 /* Relu6Kernel.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = Relu6Kernel.swift; sourceTree = "<group>"; };
C28FE02C21BA68C00054EFAC /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; }; C28FE02C21BA68C00054EFAC /* Metal.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Metal.framework; path = System/Library/Frameworks/Metal.framework; sourceTree = SDKROOT; };
C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = SDKROOT; }; C28FE02D21BA68C00054EFAC /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = SDKROOT; };
C28FE02E21BA68C00054EFAC /* MetalKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalKit.framework; path = System/Library/Frameworks/MetalKit.framework; sourceTree = SDKROOT; }; C28FE02E21BA68C00054EFAC /* MetalKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalKit.framework; path = System/Library/Frameworks/MetalKit.framework; sourceTree = SDKROOT; };
...@@ -328,6 +336,8 @@ ...@@ -328,6 +336,8 @@
FCE3A1A82153DE5100C37CDE /* ConvAddAddPreluOp.swift */, FCE3A1A82153DE5100C37CDE /* ConvAddAddPreluOp.swift */,
FCE3A1AC2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift */, FCE3A1AC2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift */,
165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */, 165F38D22276CDEA0088E29F /* ConvAddReluOp.swift */,
A73DC748227F1C7A001EB663 /* ScaleOp.swift */,
A7F26FD922842EF200365D47 /* Relu6Op.swift */,
); );
path = Operators; path = Operators;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -383,6 +393,8 @@ ...@@ -383,6 +393,8 @@
FC2BFD4521DF685F00C262B2 /* Scale.swift */, FC2BFD4521DF685F00C262B2 /* Scale.swift */,
FCB40E5821E0DCAB0075EC91 /* FetchKernel.swift */, FCB40E5821E0DCAB0075EC91 /* FetchKernel.swift */,
165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */, 165F38D42276CE7D0088E29F /* ConvAddReluKernel.swift */,
A73DC74A227F1EDE001EB663 /* ScaleOpKernel.swift */,
A7F26FDB2284301500365D47 /* Relu6Kernel.swift */,
); );
path = Kernels; path = Kernels;
sourceTree = "<group>"; sourceTree = "<group>";
...@@ -530,6 +542,7 @@ ...@@ -530,6 +542,7 @@
files = ( files = (
FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */, FC9D038020E22FBB000F735A /* FeedOp.swift in Sources */,
FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */, FC039B9F20E11CB20081E9F8 /* Tensor.swift in Sources */,
A73DC74B227F1EDE001EB663 /* ScaleOpKernel.swift in Sources */,
4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */, 4AA1EA942146661500D0F791 /* ShapeKernel.swift in Sources */,
FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */, FC0E2DBC20EE45FE009C1FAC /* ConvKernel.swift in Sources */,
FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */, FC039BAA20E11CBC0081E9F8 /* ElementwiseAddOp.swift in Sources */,
...@@ -594,6 +607,7 @@ ...@@ -594,6 +607,7 @@
FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */, FC0E2DBE20EE460D009C1FAC /* BatchNormKernel.swift in Sources */,
FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */, FC039BAB20E11CBC0081E9F8 /* Operator.swift in Sources */,
FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */, FCD04E6A20F319EC0007374F /* SoftmaxOp.swift in Sources */,
A7F26FDA22842EF200365D47 /* Relu6Op.swift in Sources */,
FCBCCC612122FBDF00D94F7E /* PriorBoxKernel.swift in Sources */, FCBCCC612122FBDF00D94F7E /* PriorBoxKernel.swift in Sources */,
FCBCCC5F2122FB3B00D94F7E /* PriorBoxOp.swift in Sources */, FCBCCC5F2122FB3B00D94F7E /* PriorBoxOp.swift in Sources */,
FC9D038220E2312E000F735A /* FetchOp.swift in Sources */, FC9D038220E2312E000F735A /* FetchOp.swift in Sources */,
...@@ -607,6 +621,7 @@ ...@@ -607,6 +621,7 @@
FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */, FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */,
FCE3A1AD2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift in Sources */, FCE3A1AD2153E8BA00C37CDE /* ElementwiseAddPreluOp.swift in Sources */,
FC039BC020E11CC20081E9F8 /* PMBlockDesc.swift in Sources */, FC039BC020E11CC20081E9F8 /* PMBlockDesc.swift in Sources */,
A7F26FDC2284301500365D47 /* Relu6Kernel.swift in Sources */,
FCD04E6820F315020007374F /* PoolKernel.swift in Sources */, FCD04E6820F315020007374F /* PoolKernel.swift in Sources */,
FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */, FC039BAD20E11CBC0081E9F8 /* ReluOp.swift in Sources */,
FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */, FCBCCC572122F41300D94F7E /* DwConvBNReluOp.swift in Sources */,
...@@ -615,6 +630,7 @@ ...@@ -615,6 +630,7 @@
4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */, 4AA1EA88214662BD00D0F791 /* BilinearInterpKernel.swift in Sources */,
FC2BFD4621DF685F00C262B2 /* Scale.swift in Sources */, FC2BFD4621DF685F00C262B2 /* Scale.swift in Sources */,
FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */, FC039B9720E11C9A0081E9F8 /* Extensions.swift in Sources */,
A73DC749227F1C7A001EB663 /* ScaleOp.swift in Sources */,
); );
runOnlyForDeploymentPostprocessing = 0; runOnlyForDeploymentPostprocessing = 0;
}; };
......
...@@ -69,7 +69,13 @@ class OpCreator<P: PrecisionProtocol> { ...@@ -69,7 +69,13 @@ class OpCreator<P: PrecisionProtocol> {
gConvAddAddPreluType : ConvAddAddPreluOp<P>.creat, gConvAddAddPreluType : ConvAddAddPreluOp<P>.creat,
gElementwiseAddPreluType : ElementwiseAddPreluOp<P>.creat, gElementwiseAddPreluType : ElementwiseAddPreluOp<P>.creat,
gFusionConvAddType : ConvAddOp<P>.creat, gFusionConvAddType : ConvAddOp<P>.creat,
gConvAddReluType : ConvAddReluOp<P>.creat] gConvAddReluType : ConvAddReluOp<P>.creat,
gReshape2Type : ReshapeOp<P>.creat,
gTranspose2Type : TransposeOp<P>.creat,
gScaleType : ScaleOp<P>.creat,
gRelu6Type : Relu6Op<P>.creat
]
private init(){} private init(){}
} }
...@@ -181,6 +181,10 @@ let gConvAddPreluType = "conv_add_prelu" ...@@ -181,6 +181,10 @@ let gConvAddPreluType = "conv_add_prelu"
let gConvAddAddPreluType = "conv_add_add_prelu" let gConvAddAddPreluType = "conv_add_add_prelu"
let gElementwiseAddPreluType = "elementwise_add_prelu" let gElementwiseAddPreluType = "elementwise_add_prelu"
let gFusionConvAddType = "fusion_conv_add" let gFusionConvAddType = "fusion_conv_add"
let gReshape2Type = "reshape2"
let gTranspose2Type = "transpose2"
let gScaleType = "scale"
let gRelu6Type = "relu6"
let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]), let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Output"]),
...@@ -211,5 +215,9 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out ...@@ -211,5 +215,9 @@ let opInfos = [gConvType : (inputs: ["Input"], outputs: ["Out
gConvAddPreluType : (inputs: ["Input"], outputs: ["Out"]), gConvAddPreluType : (inputs: ["Input"], outputs: ["Out"]),
gConvAddAddPreluType : (inputs: ["Input"], outputs: ["Out"]), gConvAddAddPreluType : (inputs: ["Input"], outputs: ["Out"]),
gElementwiseAddPreluType : (inputs: ["X"], outputs: ["Out"]), gElementwiseAddPreluType : (inputs: ["X"], outputs: ["Out"]),
gFusionConvAddType : (inputs: ["Input"], outputs: ["Out"]) gFusionConvAddType : (inputs: ["Input"], outputs: ["Out"]),
gReshape2Type : (inputs: ["X"], outputs: ["Out"]),
gTranspose2Type : (inputs: ["X"], outputs: ["Out"]),
gScaleType : (inputs: ["X"], outputs: ["Out"]),
gRelu6Type : (inputs: ["X"], outputs: ["Out"]),
] ]
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct Relu6MetalParam {
let threshold: Float32
}
class Relu6Kernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: Relu6MetalParam
func compute(commandBuffer: MTLCommandBuffer, param: Relu6Param<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encode is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<Relu6MetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
required init(device: MTLDevice, param: Relu6Param<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = Relu6MetalParam(threshold: param.threshold)
if GlobalConfig.shared.computePrecision == .Float32 {
super.init(device: device, inFunctionName: "relu6", initContext: initContext)
} else if GlobalConfig.shared.computePrecision == .Float16 {
super.init(device: device, inFunctionName: "relu6_half", initContext: initContext)
} else {
fatalError()
}
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
struct ScaleMetalParam {
let scale: Float32
let abias: Float32
}
class ScaleOpKernel<P: PrecisionProtocol>: Kernel, Computable{
var metalParam: ScaleMetalParam
required init(device: MTLDevice, param: ScaleParam<P>, initContext: InitContext) throws {
do {
try param.output.initTexture(device: device, inTranspose: param.input.transpose, computePrecision: GlobalConfig.shared.computePrecision)
} catch let error {
throw error
}
metalParam = ScaleMetalParam(scale: param.scale, abias: param.bias)
if GlobalConfig.shared.computePrecision == .Float32 {
if param.biasAfterScale {
super.init(device: device, inFunctionName: "scale_before_bias_float", initContext: initContext)
} else {
super.init(device: device, inFunctionName: "scale_after_bias_float", initContext: initContext)
}
} else if GlobalConfig.shared.computePrecision == .Float16 {
if param.biasAfterScale {
super.init(device: device, inFunctionName: "scale_before_bias_half", initContext: initContext)
} else {
super.init(device: device, inFunctionName: "scale_after_bias_half", initContext: initContext)
}
} else {
fatalError()
}
}
func compute(commandBuffer: MTLCommandBuffer, param: ScaleParam<P>) throws {
guard let encoder = commandBuffer.makeComputeCommandEncoder() else {
throw PaddleMobileError.predictError(message: " encoder is nil")
}
encoder.setTexture(param.input.metalTexture, index: 0)
encoder.setTexture(param.output.metalTexture, index: 1)
encoder.setBytes(&metalParam, length: MemoryLayout<PoolMetalParam>.size, index: 0)
encoder.dispatch(computePipline: pipline, outTexture: param.output.metalTexture)
encoder.endEncoding()
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class Relu6Param<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try Relu6Param.inputX(inputs: opDesc.inputs, from: inScope)
output = try Relu6Param.outputOut(outputs: opDesc.outputs, from: inScope)
threshold = try Relu6Param.getAttr(key: "threshold", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture
var output: Texture
let threshold: Float32
}
class Relu6Op<P: PrecisionProtocol>: Operator<Relu6Kernel<P>, Relu6Param<P>>, Runable, Creator, InferShaperable {
typealias OpType = Relu6Op<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
}
}
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
import Foundation
class ScaleParam<P: PrecisionProtocol>: OpParam {
required init(opDesc: PMOpDesc, inScope: Scope) throws {
do {
input = try ScaleParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ScaleParam.outputOut(outputs: opDesc.outputs, from: inScope)
scale = try ScaleParam.getAttr(key: "scale", attrs: opDesc.attrs)
bias = try ScaleParam.getAttr(key: "bias", attrs: opDesc.attrs)
biasAfterScale = try ScaleParam.getAttr(key: "bias_after_scale", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture
var output: Texture
let scale: Float32
let bias: Float32
let biasAfterScale: Bool
}
class ScaleOp<P: PrecisionProtocol>: Operator<ScaleOpKernel<P>, ScaleParam<P>>, Runable, Creator, InferShaperable{
typealias OpType = ScaleOp<P>
func inferShape() {
para.output.dim = para.input.dim
}
func runImpl(device: MTLDevice, buffer: MTLCommandBuffer) throws {
do {
try kernel.compute(commandBuffer: buffer, param: para)
} catch let error {
throw error
}
}
func delogOutput() {
print(" \(type) output: ")
print(para.output.metalTexture)
print(para.output.metalTexture.toTensor(dim: (n: para.output.tensorDim[0], c: para.output.tensorDim[1], h: para.output.tensorDim[2], w: para.output.tensorDim[3])).strideArray())
}
}
...@@ -33,6 +33,10 @@ class PMOpDesc { ...@@ -33,6 +33,10 @@ class PMOpDesc {
return map return map
} }
guard let _ = opInfos[protoOpDesc.type] else {
fatalError()
}
inputs = creator(protoOpDesc.inputsArray as! [OpDesc_Var]) { inputs = creator(protoOpDesc.inputsArray as! [OpDesc_Var]) {
opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false opInfos[protoOpDesc.type]?.inputs.contains($0) ?? false
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册