提交 80fbb9f4 编写于 作者: D dolphin8

reshape infer shape

上级 88c6f04b
......@@ -20,14 +20,36 @@ class ReshapeParam<P: PrecisionType>: OpParam {
do {
input = try ReshapeParam.inputX(inputs: opDesc.inputs, from: inScope)
output = try ReshapeParam.outputOut(outputs: opDesc.outputs, from: inScope)
// shape = output.dim
shape = try ReshapeParam.getAttr(key: "shape", attrs: opDesc.attrs)
var s: [Int] = shape.map { Int($0) }
var di = -1
var ml = 1
for i in 0..<s.count {
if s[i] == -1 {
di = i
continue
}
ml *= s[i]
}
if di >= 0 {
s[di] = input.dim.numel() / ml
}
output.tensorDim = Dim.init(inDim: s)
var dim: [Int] = [1, 1, 1, 1]
for i in 0..<s.count {
dim[4-s.count+i] = s[i]
}
output.originDim = Dim.init(inDim: dim)
output.dim = output.originDim
inplace = try ReshapeParam.getAttr(key: "inplace", attrs: opDesc.attrs)
} catch let error {
throw error
}
}
let input: Texture<P>
// let shape: [Int]
let shape: [Int32]
let inplace: Bool
var output: Texture<P>
}
......
......@@ -40,8 +40,8 @@ extension InputTexture {
public class Texture<P: PrecisionType>: Tensorial {
var dim: Dim
private(set) public var tensorDim: Dim
private(set) public var originDim: Dim
public var tensorDim: Dim
public var originDim: Dim
private var textureDesc: MTLTextureDescriptor!
public var metalTexture: MTLTexture!
var transpose: [Int] = [0, 1, 2, 3]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册