提交 d80d88bc 编写于 作者: B Ben Shi

llvm: implement support of closure functions.

上级 cb393db0
# Test the llvm backend. # Test the llvm backend.
# Test anonymous functions and closure functions. # Test anonymous functions and closure functions.
type pair struct {
i :f32
j :f32
}
fn main() { fn main() {
print("Hello, ") print("Hello, ")
fn() { fn() {
println("World!") println("World!")
}() }()
var i: int var i: int = 31
show := fn() { var j: [4]f32 = [4]f32{1, 2.2, 5.5, 9.8}
println("i = ", i) var k: pair = pair{3.14, 2.718}
show := fn(q: int) {
println(i, " + ", q, " = ", i+q)
println("{", j[0], ", ", j[1], ", ", j[2], ", ", j[3], "}")
println("{", k.i, ", ", k.j, "}")
} }
for i = 0; i < 10; i++ { for i := int(0); i < 4; i++ {
show() j[i] += 1.0
k.i += 0.1
k.j -= 0.1
show(i)
} }
} }
...@@ -16,12 +16,22 @@ type FmtStr struct { ...@@ -16,12 +16,22 @@ type FmtStr struct {
size int size int
} }
type Argument struct {
AType string
AName string
}
type InnerFunc struct {
fn *ssa.Function
args []Argument
}
type Compiler struct { type Compiler struct {
target string target string
output strings.Builder output strings.Builder
debug bool debug bool
fmts []FmtStr fmts []FmtStr
anofn []*ssa.Function anofn []InnerFunc
} }
func New(target string, debug bool) *Compiler { func New(target string, debug bool) *Compiler {
...@@ -138,18 +148,18 @@ func (p *Compiler) compilePackage(pkg *ssa.Package) error { ...@@ -138,18 +148,18 @@ func (p *Compiler) compilePackage(pkg *ssa.Package) error {
// Generate LLVM-IR for each global function. // Generate LLVM-IR for each global function.
for _, v := range fns { for _, v := range fns {
if err := p.compileFunction(v); err != nil { if err := p.compileFunction(v, []Argument{}); err != nil {
return err return err
} }
} }
// Generate LLVM-IR for each internal function. // Generate LLVM-IR for each internal function.
for _, v := range p.anofn { for _, v := range p.anofn {
if err := p.compileFunction(v); err != nil { if err := p.compileFunction(v.fn, v.args); err != nil {
return err return err
} }
} }
p.anofn = []*ssa.Function{} p.anofn = []InnerFunc{}
return nil return nil
} }
......
...@@ -11,7 +11,7 @@ import ( ...@@ -11,7 +11,7 @@ import (
"github.com/wa-lang/wa/internal/types" "github.com/wa-lang/wa/internal/types"
) )
func (p *Compiler) compileFunction(fn *ssa.Function) error { func (p *Compiler) compileFunction(fn *ssa.Function, extraArgs []Argument) error {
if isTargetBuiltin(fn.LinkName(), p.target) || isTargetBuiltin(fn.Name(), p.target) { if isTargetBuiltin(fn.LinkName(), p.target) || isTargetBuiltin(fn.Name(), p.target) {
return nil return nil
} }
...@@ -82,6 +82,14 @@ func (p *Compiler) compileFunction(fn *ssa.Function) error { ...@@ -82,6 +82,14 @@ func (p *Compiler) compileFunction(fn *ssa.Function) error {
p.output.WriteString(", ") p.output.WriteString(", ")
} }
} }
// Emit binded values as extra arguments for closure functions.
for _, v := range extraArgs {
p.output.WriteString(", ")
p.output.WriteString(v.AType)
p.output.WriteString(" ")
p.output.WriteString(v.AName)
}
// Finish emitting function header.
if len(fn.Blocks) == 0 { if len(fn.Blocks) == 0 {
p.output.WriteString(")\n\n") p.output.WriteString(")\n\n")
return nil return nil
......
...@@ -114,6 +114,10 @@ func (p *Compiler) compileValue(val ssa.Value) error { ...@@ -114,6 +114,10 @@ func (p *Compiler) compileValue(val ssa.Value) error {
p.output.WriteString(fmt.Sprintf("%d", val.Index)) p.output.WriteString(fmt.Sprintf("%d", val.Index))
p.output.WriteString("\n") p.output.WriteString("\n")
case *ssa.MakeClosure:
// We postpone the process of closure functions to the fist call of them.
break
default: default:
p.output.WriteString(" ; " + val.Name() + " = " + val.String() + "\n") p.output.WriteString(" ; " + val.Name() + " = " + val.String() + "\n")
// panic("unsupported Value '" + val.Name() + " = " + val.String() + "'") // panic("unsupported Value '" + val.Name() + " = " + val.String() + "'")
...@@ -359,7 +363,7 @@ func (p *Compiler) compileBinOp(val *ssa.BinOp) error { ...@@ -359,7 +363,7 @@ func (p *Compiler) compileBinOp(val *ssa.BinOp) error {
func (p *Compiler) compileCall(val *ssa.Call) error { func (p *Compiler) compileCall(val *ssa.Call) error {
switch val.Call.Value.(type) { switch val.Call.Value.(type) {
case *ssa.Function: case *ssa.Function, *ssa.MakeClosure:
// Special process for float32 constants. // Special process for float32 constants.
paf32 := map[int]string{} paf32 := map[int]string{}
for i, v := range val.Call.Args { for i, v := range val.Call.Args {
...@@ -389,16 +393,42 @@ func (p *Compiler) compileCall(val *ssa.Call) error { ...@@ -389,16 +393,42 @@ func (p *Compiler) compileCall(val *ssa.Call) error {
callee := val.Call.StaticCallee() callee := val.Call.StaticCallee()
// This callee is an internal function, whose body will be genereated later. // This callee is an internal function, whose body will be genereated later.
if callee.Parent() != nil { if callee.Parent() != nil {
p.anofn = append(p.anofn, callee) found := false
for _, f := range p.anofn {
if f.fn == callee {
found = true
}
}
if !found {
inner := InnerFunc{callee, []Argument{}}
// Collect all binded values for closure functions, then emit them as implicit arguments.
if cl, ok := val.Call.Value.(*ssa.MakeClosure); ok {
for _, v := range cl.Bindings {
if al, ok := v.(*ssa.Alloc); ok {
arg := Argument{getTypeStr(v.Type(), p.target), "%" + al.Comment}
inner.args = append(inner.args, arg)
}
}
}
p.anofn = append(p.anofn, inner)
}
} }
// Emit link name for external functions.
if len(callee.LinkName()) > 0 { if len(callee.LinkName()) > 0 {
p.output.WriteString(callee.LinkName()) p.output.WriteString(callee.LinkName())
} else { } else {
p.output.WriteString(getNormalName(callee.Pkg.Pkg.Path() + "." + callee.Name())) p.output.WriteString(getNormalName(callee.Pkg.Pkg.Path() + "." + callee.Name()))
} }
p.output.WriteString("(") p.output.WriteString("(")
// Collect all binded values for closure functions, then pass them as implicit parameters.
params := val.Call.Args[0:]
if cl, ok := val.Call.Value.(*ssa.MakeClosure); ok {
for _, v := range cl.Bindings {
params = append(params, v)
}
}
// Emit parameters. // Emit parameters.
for i, v := range val.Call.Args { for i, v := range params {
ty := getRealType(v.Type()) ty := getRealType(v.Type())
tyStr := getTypeStr(ty, p.target) tyStr := getTypeStr(ty, p.target)
switch ty.(type) { switch ty.(type) {
...@@ -414,7 +444,7 @@ func (p *Compiler) compileCall(val *ssa.Call) error { ...@@ -414,7 +444,7 @@ func (p *Compiler) compileCall(val *ssa.Call) error {
} else { } else {
p.output.WriteString(getValueStr(v)) p.output.WriteString(getValueStr(v))
} }
if i < len(val.Call.Args)-1 { if i < len(params)-1 {
p.output.WriteString(", ") p.output.WriteString(", ")
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册