traefik/vendor/github.com/vulcand/predicate/parse.go
2017-03-09 13:13:02 +01:00

191 lines
4.2 KiB
Go

package predicate
import (
"fmt"
"go/ast"
"go/parser"
"go/token"
"reflect"
"strconv"
)
// NewParser builds a Parser that evaluates expressions against the
// definition d. The returned error is always nil.
func NewParser(d Def) (Parser, error) {
	p := &predicateParser{d: d}
	return p, nil
}
// predicateParser is the value NewParser hands out as a Parser. It turns an
// expression string into a value by walking the parsed Go expression tree.
type predicateParser struct {
// d is the definition consulted for callable functions (d.Functions) and
// for the functions bound to binary operators (d.Operators).
d Def
}
// Parse parses in as a Go expression and evaluates the resulting syntax
// tree with parseNode. A syntax error from the Go parser is returned
// unchanged.
func (p *predicateParser) Parse(in string) (interface{}, error) {
	root, parseErr := parser.ParseExpr(in)
	if parseErr == nil {
		return p.parseNode(root)
	}
	return nil, parseErr
}
// parseNode evaluates a single AST node and returns the value it denotes.
//
// Supported node kinds:
//   - parenthesized expression: evaluates the inner expression;
//   - basic literal: converted with literalToValue;
//   - binary expression: both operands are evaluated (left first) and then
//     combined with the function bound to the operator;
//   - call expression: the callee name is resolved to a registered function,
//     which is invoked with the literal arguments.
//
// Any other node kind results in an "unsupported" error.
func (p *predicateParser) parseNode(node ast.Node) (interface{}, error) {
	switch expr := node.(type) {
	case *ast.ParenExpr:
		return p.parseNode(expr.X)

	case *ast.BasicLit:
		return literalToValue(expr)

	case *ast.BinaryExpr:
		left, err := p.parseNode(expr.X)
		if err != nil {
			return nil, err
		}
		right, err := p.parseNode(expr.Y)
		if err != nil {
			return nil, err
		}
		return p.joinPredicates(expr.Op, left, right)

	case *ast.CallExpr:
		// The callee is expected to be a function that returns a predicate.
		fnName, err := getIdentifier(expr.Fun)
		if err != nil {
			return nil, err
		}
		target, err := p.getFunction(fnName)
		if err != nil {
			return nil, err
		}
		callArgs, err := collectLiterals(expr.Args)
		if err != nil {
			return nil, err
		}
		return callFunction(target, callArgs)

	default:
		return nil, fmt.Errorf("unsupported %T", node)
	}
}
// getFunction looks up name in the definition's Functions and returns the
// registered entry, or an error when no entry exists under that name.
func (p *predicateParser) getFunction(name string) (interface{}, error) {
	if registered, found := p.d.Functions[name]; found {
		return registered, nil
	}
	return nil, fmt.Errorf("unsupported function: %s", name)
}
// joinPredicates combines the already evaluated operands a and b by calling
// the function bound to the binary operator op with (a, b) as arguments.
func (p *predicateParser) joinPredicates(op token.Token, a, b interface{}) (interface{}, error) {
	combine, err := p.getJoinFunction(op)
	if err != nil {
		return nil, err
	}
	operands := []interface{}{a, b}
	return callFunction(combine, operands)
}
// getJoinFunction returns the entry of the definition's Operators that
// corresponds to the binary operator op:
//
//	&&  -> AND    ||  -> OR
//	>   -> GT     >=  -> GE
//	<   -> LT     <=  -> LE
//	==  -> EQ     !=  -> NEQ
//
// An error is returned when op is none of the above or when the selected
// entry compares equal to nil.
func (p *predicateParser) getJoinFunction(op token.Token) (interface{}, error) {
	var joiner interface{}

	switch op {
	case token.LAND:
		joiner = p.d.Operators.AND
	case token.LOR:
		joiner = p.d.Operators.OR
	case token.GTR:
		joiner = p.d.Operators.GT
	case token.GEQ:
		joiner = p.d.Operators.GE
	case token.LSS:
		joiner = p.d.Operators.LT
	case token.LEQ:
		joiner = p.d.Operators.LE
	case token.EQL:
		joiner = p.d.Operators.EQ
	case token.NEQ:
		joiner = p.d.Operators.NEQ
	}

	if joiner != nil {
		return joiner, nil
	}
	return nil, fmt.Errorf("%v is not supported", op)
}
// getIdentifier extracts a name from node.
//
// A plain identifier yields its name. A selector expression whose left-hand
// side is an identifier yields "<left>.<selected>". Every other shape is
// rejected with an error naming the offending node type.
func getIdentifier(node ast.Node) (string, error) {
	switch n := node.(type) {
	case *ast.SelectorExpr:
		qualifier, isIdent := n.X.(*ast.Ident)
		if !isIdent {
			return "", fmt.Errorf("expected selector identifier, got: %T", n.X)
		}
		return qualifier.Name + "." + n.Sel.Name, nil
	case *ast.Ident:
		return n.Name, nil
	}
	return "", fmt.Errorf("expected identifier, got: %T", node)
}
// collectLiterals converts every expression in nodes into a Go value, in
// order. Each expression must be a basic literal; the first one that is not,
// or that literalToValue cannot convert, aborts the conversion with an error.
func collectLiterals(nodes []ast.Expr) ([]interface{}, error) {
	values := make([]interface{}, 0, len(nodes))
	for _, expr := range nodes {
		lit, isLiteral := expr.(*ast.BasicLit)
		if !isLiteral {
			return nil, fmt.Errorf("expected literal, got %T", expr)
		}
		value, err := literalToValue(lit)
		if err != nil {
			return nil, err
		}
		values = append(values, value)
	}
	return values, nil
}
// literalToValue converts a basic literal into the Go value it denotes:
//
//   - FLOAT literals become float64;
//   - INT literals become int;
//   - STRING literals become string, with quoting and escapes removed.
//
// Any other literal kind is rejected with an error, as is a literal whose
// text cannot be converted.
func literalToValue(a *ast.BasicLit) (interface{}, error) {
	switch a.Kind {
	case token.FLOAT:
		value, err := strconv.ParseFloat(a.Value, 64)
		if err != nil {
			return nil, fmt.Errorf("failed to parse argument: %s, error: %s", a.Value, err)
		}
		return value, nil
	case token.INT:
		// Plain decimal is tried first. This keeps the long-standing result
		// for zero-prefixed literals: "010" is read as decimal 10.
		value, err := strconv.Atoi(a.Value)
		if err == nil {
			return value, nil
		}
		// The expression parser also accepts prefixed literals ("0x1F",
		// "0b101", "0o17") and digit separators ("1_000"), which Atoi
		// rejects. Base 0 makes ParseInt follow Go's integer literal syntax;
		// bit size 0 limits the result to the range of int.
		wide, fallbackErr := strconv.ParseInt(a.Value, 0, 0)
		if fallbackErr != nil {
			// Report the original decimal error so the message stays the
			// same for input that is invalid in every notation.
			return nil, fmt.Errorf("failed to parse argument: %s, error: %s", a.Value, err)
		}
		return int(wide), nil
	case token.STRING:
		value, err := strconv.Unquote(a.Value)
		if err != nil {
			return nil, fmt.Errorf("failed to parse argument: %s, error: %s", a.Value, err)
		}
		return value, nil
	}
	return nil, fmt.Errorf("unsupported function argument type: '%v'", a.Kind)
}
// callFunction invokes f through reflection, passing args as positional
// arguments, and normalizes what f returns.
//
// f is expected to be a function returning either (value) or
// (value, error):
//   - one result: it is returned with a nil error;
//   - two results: the first is returned as the value; the second, when
//     non-nil, must implement error and is returned as the error;
//   - any other number of results is reported as an error.
//
// Any panic raised while performing the call (for example f not being a
// function, an argument count or type mismatch, or a panic inside f itself)
// is recovered and returned as an error instead of propagating to the
// caller.
func callFunction(f interface{}, args []interface{}) (v interface{}, err error) {
	defer func() {
		if r := recover(); r != nil {
			// A panic value may be of any type, so format it with %v; %s
			// garbles non-string values (e.g. "%!s(int=42)").
			err = fmt.Errorf("%v", r)
		}
	}()

	arguments := make([]reflect.Value, len(args))
	for i, a := range args {
		arguments[i] = reflect.ValueOf(a)
	}

	fn := reflect.ValueOf(f)
	ret := fn.Call(arguments)
	switch len(ret) {
	case 1:
		return ret[0].Interface(), nil
	case 2:
		// Distinct local names avoid shadowing the named results that the
		// deferred recover handler assigns to.
		result, second := ret[0].Interface(), ret[1].Interface()
		if second == nil {
			return result, nil
		}
		callErr, ok := second.(error)
		if !ok {
			return nil, fmt.Errorf("expected error as a second return value, got %T", second)
		}
		return result, callErr
	}
	return nil, fmt.Errorf("expected at least one return argument for '%v'", fn)
}