From d3d811af59f2bd7da42c2c0d03d5306c0a4cc29e Mon Sep 17 00:00:00 2001 From: dolphin8 Date: Mon, 27 Aug 2018 17:58:49 +0800 Subject: [PATCH] reshape infer shape --- .../paddle-mobile/Operators/ReshapeOp.swift | 26 +++++++++++++++++-- .../paddle-mobile/framework/Texture.swift | 4 +-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift index 3bcf9c15a0..4c4e910499 100644 --- a/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift +++ b/metal/paddle-mobile/paddle-mobile/Operators/ReshapeOp.swift @@ -20,14 +20,36 @@ class ReshapeParam: OpParam { do { input = try ReshapeParam.inputX(inputs: opDesc.inputs, from: inScope) output = try ReshapeParam.outputOut(outputs: opDesc.outputs, from: inScope) - // shape = output.dim + shape = try ReshapeParam.getAttr(key: "shape", attrs: opDesc.attrs) + + var s: [Int] = shape.map { Int($0) } + var di = -1 + var ml = 1 + for i in 0..= 0 { + s[di] = input.dim.numel() / ml + } + output.tensorDim = Dim.init(inDim: s) + var dim: [Int] = [1, 1, 1, 1] + for i in 0.. - // let shape: [Int] + let shape: [Int32] let inplace: Bool var output: Texture

} diff --git a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift index 231ef0fb27..9889171dba 100644 --- a/metal/paddle-mobile/paddle-mobile/framework/Texture.swift +++ b/metal/paddle-mobile/paddle-mobile/framework/Texture.swift @@ -40,8 +40,8 @@ extension InputTexture { public class Texture: Tensorial { var dim: Dim - private(set) public var tensorDim: Dim - private(set) public var originDim: Dim + public var tensorDim: Dim + public var originDim: Dim private var textureDesc: MTLTextureDescriptor! public var metalTexture: MTLTexture! var transpose: [Int] = [0, 1, 2, 3] -- GitLab