eval.go 3.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
package expr

import (
	"fmt"
)

// EvalProgram evaluates the root node of program against resolver and
// returns the evaluated node's raw value.
//
// Evaluation reports failures by panicking. Any panic whose value implements
// error is recovered here and returned as err, with v left nil. This includes
// Go runtime errors such as nil dereferences, because runtime.Error implements
// error. A panic with any other value (for example a plain string) is not
// treated as an evaluation error and is re-raised.
func EvalProgram(resolver Resolver, program *Program) (v interface{}, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		asErr, isErr := recovered.(error)
		if !isErr {
			panic(recovered)
		}
		err = asErr
	}()

	return evalnode(resolver, program.root).RawValue(), nil
}

// Eval parses the expression source expr and evaluates it against resolver.
// It returns whatever EvalProgram returns for the parsed program.
//
// NOTE(review): ParseString is called before EvalProgram's recover is
// installed, so any panic raised while parsing propagates to the caller
// instead of being returned as an error — confirm how ParseString reports
// syntax errors.
func Eval(resolver Resolver, expr string) (interface{}, error) {
	program := ParseString(expr)
	return EvalProgram(resolver, program)
}

// evalnode evaluates a single AST node and returns its Value.
//
// Identifiers are looked up through resolver. A name the resolver returns
// nil for panics with an error. Literal nodes become literal values; rune
// literals are represented as signed integer values. Unary, binary and
// ternary nodes are delegated to evalunary, evalbinary and evalternary.
// Any other node type panics with a plain string.
func evalnode(resolver Resolver, n node) Value {
	switch expr := n.(type) {
	case identnode:
		resolved := resolver.Resolve(expr.ident)
		if resolved != nil {
			return resolved
		}
		panic(fmt.Errorf("unresolved name %s", expr.ident))
	case intnode:
		if !expr.sign {
			return literaluintval(expr.uval)
		}
		return literalintval(expr.ival)
	case floatnode:
		return literalfloatval(expr.fval)
	case boolnode:
		return literalboolval(expr.val)
	case strnode:
		return literalstrval(expr.val)
	case runenode:
		// Runes share the signed integer literal representation.
		return literalintval(int64(expr.val))
	case nilnode:
		return literalnilval()
	case unaryexpr:
		return evalunary(resolver, expr)
	case binaryexpr:
		return evalbinary(resolver, expr)
	case ternaryexpr:
		return evalternary(resolver, expr)
	}
	panic("invalid node")
}

// evalunary evaluates the operand of a unary expression and then applies the
// operator to it via the matching Value method. Unary plus returns the
// operand unchanged. The operand is evaluated before the operator is
// inspected, so an unrecognized operator panics (with a plain string) only
// after the operand has been evaluated.
func evalunary(resolver Resolver, expr unaryexpr) Value {
	operand := evalnode(resolver, expr.n)

	switch expr.op {
	case unarynegate:
		return operand.Negate()
	case unarynot:
		return operand.Not()
	case unarybitnot:
		return operand.BitNot()
	case unaryderef:
		return operand.Deref()
	case unaryref:
		return operand.Ref()
	case unaryplus:
		return operand
	}
	panic("invalid unary expression")
}

// flattengroup returns the leaf operands of a possibly nested binarygroup
// expression, in left-to-right order. A node that is not a binarygroup
// binaryexpr is returned as a one-element slice, so the result always has at
// least one element. evalbinary uses this to turn the grouped right-hand
// side of a call into its argument list.
func flattengroup(n node) []node {
	var leaves []node
	var collect func(node)
	collect = func(cur node) {
		if group, ok := cur.(binaryexpr); ok && group.op == binarygroup {
			collect(group.a)
			collect(group.b)
			return
		}
		leaves = append(leaves, cur)
	}
	collect(n)
	return leaves
}

// evalbinary evaluates a binary expression node.
//
// The left operand is always evaluated first. Three operators handle the
// right-hand side specially and return before it is evaluated as a plain
// operand:
//   - binarymember: the right side must be an identnode (otherwise this
//     panics with an error); its name is passed to Dot on the left value.
//   - binarycall: the right side is flattened into call arguments with
//     flattengroup, each evaluated left to right, and passed to Call.
//   - binarygroup: the left value is discarded and the right side's value
//     is returned.
//
// Every other operator evaluates both operands before dispatching to the
// matching Value method. In particular binarylogicalor and binarylogicaland
// do not short-circuit. An unrecognized operator panics with a plain string,
// after both operands have been evaluated.
func evalbinary(resolver Resolver, expr binaryexpr) Value {
	lhs := evalnode(resolver, expr.a)

	switch expr.op {
	case binarymember:
		member, isIdent := expr.b.(identnode)
		if !isIdent {
			panic(fmt.Errorf("expected ident node, got %T", expr.b))
		}
		return lhs.Dot(member.ident)
	case binarycall:
		argnodes := flattengroup(expr.b)
		args := make([]Value, 0, len(argnodes))
		for _, argnode := range argnodes {
			args = append(args, evalnode(resolver, argnode))
		}
		return lhs.Call(args)
	case binarygroup:
		return evalnode(resolver, expr.b)
	}

	rhs := evalnode(resolver, expr.b)

	switch expr.op {
	case binarylogicalor:
		return lhs.LogicalOr(rhs)
	case binarylogicaland:
		return lhs.LogicalAnd(rhs)
	case binaryequal:
		return lhs.Equal(rhs)
	case binarynotequal:
		return lhs.NotEqual(rhs)
	case binarylesser:
		return lhs.Lesser(rhs)
	case binarylesserequal:
		return lhs.LesserEqual(rhs)
	case binarygreater:
		return lhs.Greater(rhs)
	case binarygreaterequal:
		return lhs.GreaterEqual(rhs)
	case binaryadd:
		return lhs.Add(rhs)
	case binarysub:
		return lhs.Sub(rhs)
	case binaryor:
		return lhs.Or(rhs)
	case binaryxor:
		return lhs.Xor(rhs)
	case binarymul:
		return lhs.Mul(rhs)
	case binarydiv:
		return lhs.Div(rhs)
	case binaryrem:
		return lhs.Rem(rhs)
	case binarylsh:
		return lhs.Lsh(rhs)
	case binaryrsh:
		return lhs.Rsh(rhs)
	case binaryand:
		return lhs.And(rhs)
	case binaryandnot:
		return lhs.AndNot(rhs)
	case binarysubscript:
		return lhs.Index(rhs)
	}
	panic("invalid binary expression")
}

// evalternary evaluates a conditional (ternary) expression.
//
// The condition node.a is evaluated and its underlying Go value must be a
// bool. Otherwise this panics with an error naming the condition's dynamic
// type. Only the selected branch is evaluated: node.b when the condition is
// true, node.c when it is false.
func evalternary(resolver Resolver, node ternaryexpr) Value {
	a := evalnode(resolver, node.a).Value().Interface()
	cond, ok := a.(bool)
	if !ok {
		// Report the type of the condition value itself. After a failed
		// assertion cond is just the zero bool, so formatting it would always
		// print "bool".
		panic(fmt.Errorf("unexpected type %T for ternary", a))
	}
	if cond {
		return evalnode(resolver, node.b)
	}
	return evalnode(resolver, node.c)
}