From 694515e14494825914752a10a10a93d35a3edd57 Mon Sep 17 00:00:00 2001 From: dolphin8 Date: Wed, 12 Sep 2018 14:11:55 +0800 Subject: [PATCH] reshape --- .../paddle-mobile.xcodeproj/project.pbxproj | 6 +- .../Kernels/metal/ConcatKernel.metal | 17 +- .../Kernels/metal/ConcatKernel.metal.inc | 14 -- .../Kernels/metal/ReshapeKernel.metal | 220 ++++++++++-------- .../Kernels/metal/ReshapeKernel.metal.inc | 129 ++-------- 5 files changed, 153 insertions(+), 233 deletions(-) diff --git a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj index c384195cb4..42e4984dd9 100644 --- a/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj +++ b/metal/paddle-mobile/paddle-mobile.xcodeproj/project.pbxproj @@ -18,6 +18,7 @@ 4AA1EA962146665A00D0F791 /* FlattenKernel.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */; }; 4AA1EA982146666500D0F791 /* FlattenOp.swift in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA972146666500D0F791 /* FlattenOp.swift */; }; 4AA1EA9E2148D6F900D0F791 /* ConcatKernel.metal.inc in Headers */ = {isa = PBXBuildFile; fileRef = 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */; }; + 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */ = {isa = PBXBuildFile; fileRef = 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */; }; 4AF928772133F1DB005B6C3A /* BoxCoder.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928762133F1DB005B6C3A /* BoxCoder.metal */; }; 4AF9287921341661005B6C3A /* Softmax.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF9287821341661005B6C3A /* Softmax.metal */; }; 4AF928822135673D005B6C3A /* ConcatKernel.metal in Sources */ = {isa = PBXBuildFile; fileRef = 4AF928812135673D005B6C3A /* ConcatKernel.metal */; }; @@ -125,7 +126,8 @@ 4AA1EA932146661500D0F791 /* ShapeKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = ShapeKernel.swift; sourceTree = ""; }; 4AA1EA952146665A00D0F791 /* FlattenKernel.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenKernel.swift; sourceTree = ""; }; 4AA1EA972146666500D0F791 /* FlattenOp.swift */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.swift; path = FlattenOp.swift; sourceTree = ""; }; - 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.pascal; path = ConcatKernel.metal.inc; sourceTree = ""; }; + 4AA1EA9D2148D6F900D0F791 /* ConcatKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ConcatKernel.metal.inc; sourceTree = ""; }; + 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */ = {isa = PBXFileReference; explicitFileType = sourcecode.metal; fileEncoding = 4; path = ReshapeKernel.metal.inc; sourceTree = ""; }; 4AF928762133F1DB005B6C3A /* BoxCoder.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = BoxCoder.metal; sourceTree = ""; }; 4AF9287821341661005B6C3A /* Softmax.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = Softmax.metal; sourceTree = ""; }; 4AF928812135673D005B6C3A /* ConcatKernel.metal */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.metal; path = ConcatKernel.metal; sourceTree = ""; }; @@ -450,6 +452,7 @@ FCDDC6CB212FDFDB00E5EF74 /* ReluKernel.metal */, FCDDC6CE212FE14700E5EF74 /* PriorBoxKernel.metal */, FCA3A1622132A4AC00084FE5 /* ReshapeKernel.metal */, + 4AA1EA9F2148DEEE00D0F791 /* ReshapeKernel.metal.inc */, FCA3A1642132A5EB00084FE5 /* Common.metal */, FCA67B1621364EF000BD58AA /* ConvTransposeKernel.metal */, FCA67CD42138272900BD58AA /* ConvAddMetal.metal */, @@ -651,6 +654,7 @@ FCBCCC67212306B000D94F7E /* ConcatOp.swift in Sources */, FCD04E6C20F31A280007374F /* SoftmaxKernel.swift in Sources */, FCEB684A212F00DB00D2448E /* PreluKernel.metal in Sources */, + 4AA1EAA02148DEEE00D0F791 /* ReshapeKernel.metal.inc in Sources */, FCDDC6CF212FE14700E5EF74 /* PriorBoxKernel.metal in Sources */, FC4CB74B20F12C30007C0C6D /* ProgramOptimize.swift in Sources */, FC5163F620EF556E00636C28 /* Texture2DTo2DArrayKernel.swift in Sources */, diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal index b874a74663..65a01182d2 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal @@ -25,33 +25,32 @@ struct ConcatParam { int32_t vdim[6]; }; - #define P float #define D 4 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #define D 3 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #define D 2 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #define D 1 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #undef P #define P half #define D 4 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #define D 3 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #define D 2 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #define D 1 -#include "Concat.metal.inc" +#include "ConcatKernel.metal.inc" #undef D #undef P diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc index 0e8a4e019a..b473ea6c6d 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ConcatKernel.metal.inc @@ -1,17 +1,3 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - #ifndef D #define D 4 #endif diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal index 399287da71..75337990c3 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal @@ -24,114 +24,128 @@ struct ReshapeParam { int32_t otrans[4]; }; -//kernel void reshape(texture2d_array inTexture [[texture(0)]], -// texture2d_array outTexture [[texture(1)]], -// constant ReshapeParam &rp [[buffer(0)]], -// uint3 gid [[thread_position_in_grid]]) { -// if (gid.x >= outTexture.get_width() || -// gid.y >= outTexture.get_height() || -// gid.z >= outTexture.get_array_size()) return; -// -// int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4]; -// ReshapeParam lrp = rp; -// int oC = lrp.odim[lrp.otrans[3]]; -// int iC = lrp.idim[lrp.itrans[3]]; -// int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; -// float4 r; -// for (int n = 0; n < 4; n++) { -// oxyzn[3] = n; -// -// //4 (gid.x gid.y, gid.z, 0~4) -// xyzn2abcd(oC, oxyzn, oabcd); -// int tabcd[4]; -// invtrans(lrp.otrans, oabcd, tabcd); -// int index = abcd2index(lrp.odim, tabcd); -// if (index < count) { -// int c = index % 4; -// -// int temp0 = index % (inTexture.get_array_size() * 4); -// int slice = temp0 / 4; -// -// int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]); -// int w = temp1 / (inTexture.get_array_size() * 4); -// -// int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]); -// -//// index2abcd(lrp.idim, index, tabcd); -//// abcd2xyzn(iC, tabcd, ixyzn); -// r[n] = inTexture.read(uint2(w, h), slice)[c]; -// } else { -// r[n] = 0; -// } -// } -// outTexture.write(r, gid.xy, gid.z); -//} +#define P float +#define DIN 4 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN +#define DIN 3 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN +#define DIN 2 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN +#define DIN 1 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN -kernel void reshape(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], - constant ReshapeParam &rp [[buffer(0)]], - uint3 gid [[thread_position_in_grid]]) { - if (gid.x >= outTexture.get_width() || - gid.y >= outTexture.get_height() || - gid.z >= outTexture.get_array_size()) return; +#undef P - int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4]; - ReshapeParam lrp = rp; - int oC = lrp.odim[lrp.otrans[3]]; - int iC = lrp.idim[lrp.itrans[3]]; - int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; - float4 r; - for (int n = 0; n < 4; n++) { - oxyzn[3] = n; - xyzn2abcd(oC, oxyzn, oabcd); - int tabcd[4]; - invtrans(lrp.otrans, oabcd, tabcd); - int index = abcd2index(lrp.odim, tabcd); - if (index < count) { - index2abcd(lrp.idim, index, tabcd); - trans(lrp.itrans, tabcd, iabcd); - abcd2xyzn(iC, iabcd, ixyzn); - r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]]; - } else { - r[n] = 0; - } - } - outTexture.write(r, gid.xy, gid.z); -} +#define P half +#define DIN 4 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN +#define DIN 3 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN -kernel void reshape_half(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], - constant ReshapeParam &rp [[buffer(0)]], - uint3 gid [[thread_position_in_grid]]) { - if (gid.x >= outTexture.get_width() || - gid.y >= outTexture.get_height() || - gid.z >= outTexture.get_array_size()) return; - - int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4]; - ReshapeParam lrp = rp; - int oC = lrp.odim[lrp.otrans[3]]; - int iC = lrp.idim[lrp.itrans[3]]; - int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; - half4 r; - for (int n = 0; n < 4; n++) { - oxyzn[3] = n; - xyzn2abcd(oC, oxyzn, oabcd); - int tabcd[4]; - invtrans(lrp.otrans, oabcd, tabcd); - int index = abcd2index(lrp.odim, tabcd); - if (index < count) { - index2abcd(lrp.idim, index, tabcd); - trans(lrp.itrans, tabcd, iabcd); - abcd2xyzn(iC, iabcd, ixyzn); - r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]]; - } else { - r[n] = 0; - } - } - outTexture.write(r, gid.xy, gid.z); -} +#define DIN 2 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN +#define DIN 1 +#define DOUT 4 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 3 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 2 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#define DOUT 1 +#include "ReshapeKernel.metal.inc" +#undef DOUT +#undef DIN + +#undef P diff --git a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc index 399287da71..b5e64aa774 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc +++ b/metal/paddle-mobile/paddle-mobile/Operators/Kernels/metal/ReshapeKernel.metal.inc @@ -1,77 +1,18 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ +#ifndef P +#define P float +#endif -#include -#include "Common.metal" +#define CONCAT2(a, b) a ## b +#define CONCAT2_(a, b) a ## _ ## b +#define CONCAT3_(a, b, c) a ## _ ## b ## _ ## c +#define CONCAT4_(a, b, c, d) a ## _ ## b ## _ ## c ## _ ## d -using namespace metal; +#define FUNC(f, d1, d2, p) CONCAT4_(f, d1, d2, p) +#define VECTOR(p, n) CONCAT2(p, n) +#define FUNC_D(f, d) CONCAT2_(f, d) -struct ReshapeParam { - int32_t idim[4]; - int32_t itrans[4]; - int32_t odim[4]; - int32_t otrans[4]; -}; - -//kernel void reshape(texture2d_array inTexture [[texture(0)]], -// texture2d_array outTexture [[texture(1)]], -// constant ReshapeParam &rp [[buffer(0)]], -// uint3 gid [[thread_position_in_grid]]) { -// if (gid.x >= outTexture.get_width() || -// gid.y >= outTexture.get_height() || -// gid.z >= outTexture.get_array_size()) return; -// -// int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4]; -// ReshapeParam lrp = rp; -// int oC = lrp.odim[lrp.otrans[3]]; -// int iC = lrp.idim[lrp.itrans[3]]; -// int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; -// float4 r; -// for (int n = 0; n < 4; n++) { -// oxyzn[3] = n; -// -// //4 (gid.x gid.y, gid.z, 0~4) -// xyzn2abcd(oC, oxyzn, oabcd); -// int tabcd[4]; -// invtrans(lrp.otrans, oabcd, tabcd); -// int index = abcd2index(lrp.odim, tabcd); -// if (index < count) { -// int c = index % 4; -// -// int temp0 = index % (inTexture.get_array_size() * 4); -// int slice = temp0 / 4; -// -// int temp1 = index % (inTexture.get_array_size() * 4 * lrp.idim[2]); -// int w = temp1 / (inTexture.get_array_size() * 4); -// -// int h = index / (inTexture.get_array_size() * 4 * lrp.idim[2]); -// -//// index2abcd(lrp.idim, index, tabcd); -//// abcd2xyzn(iC, tabcd, ixyzn); -// r[n] = inTexture.read(uint2(w, h), slice)[c]; -// } else { -// r[n] = 0; -// } -// } -// outTexture.write(r, gid.xy, gid.z); -//} - - - - -kernel void reshape(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], +kernel void FUNC(reshape, DIN, DOUT, P)(texture2d_array inTexture [[texture(0)]], + texture2d_array outTexture [[texture(1)]], constant ReshapeParam &rp [[buffer(0)]], uint3 gid [[thread_position_in_grid]]) { if (gid.x >= outTexture.get_width() || @@ -83,43 +24,14 @@ kernel void reshape(texture2d_array inTexture [[texture(0)] int oC = lrp.odim[lrp.otrans[3]]; int iC = lrp.idim[lrp.itrans[3]]; int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; - float4 r; - for (int n = 0; n < 4; n++) { - oxyzn[3] = n; - xyzn2abcd(oC, oxyzn, oabcd); - int tabcd[4]; - invtrans(lrp.otrans, oabcd, tabcd); - int index = abcd2index(lrp.odim, tabcd); - if (index < count) { - index2abcd(lrp.idim, index, tabcd); - trans(lrp.itrans, tabcd, iabcd); - abcd2xyzn(iC, iabcd, ixyzn); - r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]]; - } else { - r[n] = 0; - } - } - outTexture.write(r, gid.xy, gid.z); -} - - -kernel void reshape_half(texture2d_array inTexture [[texture(0)]], - texture2d_array outTexture [[texture(1)]], - constant ReshapeParam &rp [[buffer(0)]], - uint3 gid [[thread_position_in_grid]]) { - if (gid.x >= outTexture.get_width() || - gid.y >= outTexture.get_height() || - gid.z >= outTexture.get_array_size()) return; - - int oxyzn[4] = {int(gid.x), int(gid.y), int(gid.z), 0}, oabcd[4], ixyzn[4], iabcd[4]; - ReshapeParam lrp = rp; - int oC = lrp.odim[lrp.otrans[3]]; - int iC = lrp.idim[lrp.itrans[3]]; - int count = lrp.odim[0] * lrp.odim[1] * lrp.odim[2] * lrp.odim[3]; - half4 r; + VECTOR(P, 4) r; for (int n = 0; n < 4; n++) { oxyzn[3] = n; - xyzn2abcd(oC, oxyzn, oabcd); +#if DOUT == 4 + xyzn2abcd_4(oC, oxyzn, oabcd); +#else + FUNC_D(xyzn2abcd, DOUT)(oxyzn, oabcd); +#endif int tabcd[4]; invtrans(lrp.otrans, oabcd, tabcd); int index = abcd2index(lrp.odim, tabcd); @@ -127,6 +39,11 @@ kernel void reshape_half(texture2d_array inTexture [[texture index2abcd(lrp.idim, index, tabcd); trans(lrp.itrans, tabcd, iabcd); abcd2xyzn(iC, iabcd, ixyzn); +#if DIN == 4 + abcd2xyzn_4(iC, iabcd, ixyzn); +#else + FUNC_D(abcd2xyzn, DIN)(iabcd, ixyzn); +#endif r[n] = inTexture.read(uint2(ixyzn[0], ixyzn[1]), ixyzn[2])[ixyzn[3]]; } else { r[n] = 0; -- GitLab