mirror of https://github.com/deuill/grawkit.git synced 2024-09-28 08:22:46 +00:00

469 lines
13 KiB
Raw Normal View History

// Resolve function calls and variable types
package parser
import (
	"fmt"
	"reflect"
	"sort"

	. "github.com/benhoyt/goawk/internal/ast"
	. "github.com/benhoyt/goawk/lexer"
)
// varType is the resolved kind of an AWK variable as tracked by the
// resolver.
type varType int

const (
	typeUnknown varType = iota // zero value: type not (yet) determined
	typeScalar
	typeArray
)

// String returns a human-readable name for the type. Any value other
// than typeScalar or typeArray is reported as "Unknown".
func (t varType) String() string {
	switch t {
	case typeScalar:
		return "Scalar"
	case typeArray:
		return "Array"
	default:
		return "Unknown"
	}
}
// typeInfo records type information for a single variable
type typeInfo struct {
typ varType
ref *VarExpr
scope VarScope
index int
callName string
argIndex int
// Used by printVarTypes when debugTypes is turned on
func (t typeInfo) String() string {
var scope string
switch t.scope {
case ScopeGlobal:
scope = "Global"
case ScopeLocal:
scope = "Local"
scope = "Special"
return fmt.Sprintf("typ=%s ref=%p scope=%s index=%d callName=%q argIndex=%d",
t.typ, t.ref, scope, t.index, t.callName, t.argIndex)
// A single variable reference (normally scalar)
type varRef struct {
funcName string
ref *VarExpr
isArg bool
pos Position
// A single array reference
type arrayRef struct {
funcName string
ref *ArrayExpr
pos Position
// Initialize the resolver
func (p *parser) initResolve() {
p.varTypes = make(map[string]map[string]typeInfo)
p.varTypes[""] = make(map[string]typeInfo) // globals
p.functions = make(map[string]int)
p.arrayRef("ARGV", Position{1, 1}) // interpreter relies on ARGV being present
p.arrayRef("ENVIRON", Position{1, 1}) // and ENVIRON
p.multiExprs = make(map[*MultiExpr]Position, 3)
// Signal the start of a function
func (p *parser) startFunction(name string, params []string) {
p.funcName = name
p.varTypes[name] = make(map[string]typeInfo)
// Signal the end of a function
func (p *parser) stopFunction() {
p.funcName = ""
// Add function by name with given index
func (p *parser) addFunction(name string, index int) {
p.functions[name] = index
// Records a call to a user function (for resolving indexes later)
type userCall struct {
call *UserCallExpr
pos Position
inFunc string
// Record a user call site
func (p *parser) recordUserCall(call *UserCallExpr, pos Position) {
p.userCalls = append(p.userCalls, userCall{call, pos, p.funcName})
// After parsing, resolve all user calls to their indexes. Also
// ensures functions called have actually been defined, and that
// they're not being called with too many arguments.
func (p *parser) resolveUserCalls(prog *Program) {
// Number the native funcs (order by name to get consistent order)
nativeNames := make([]string, 0, len(p.nativeFuncs))
for name := range p.nativeFuncs {
nativeNames = append(nativeNames, name)
nativeIndexes := make(map[string]int, len(nativeNames))
for i, name := range nativeNames {
nativeIndexes[name] = i
for _, c := range p.userCalls {
// AWK-defined functions take precedence over native Go funcs
index, ok := p.functions[c.call.Name]
if !ok {
f, haveNative := p.nativeFuncs[c.call.Name]
if !haveNative {
panic(p.posErrorf(c.pos, "undefined function %q", c.call.Name))
typ := reflect.TypeOf(f)
if !typ.IsVariadic() && len(c.call.Args) > typ.NumIn() {
panic(p.posErrorf(c.pos, "%q called with more arguments than declared", c.call.Name))
c.call.Native = true
c.call.Index = nativeIndexes[c.call.Name]
function := prog.Functions[index]
if len(c.call.Args) > len(function.Params) {
panic(p.posErrorf(c.pos, "%q called with more arguments than declared", c.call.Name))
c.call.Index = index
// For arguments that are variable references, we don't know the
// type based on context, so mark the types for these as unknown.
func (p *parser) processUserCallArg(funcName string, arg Expr, index int) {
if varExpr, ok := arg.(*VarExpr); ok {
scope, varFuncName := p.getScope(varExpr.Name)
ref := p.varTypes[varFuncName][varExpr.Name].ref
if ref == varExpr {
// Only applies if this is the first reference to this
// variable (otherwise we know the type already)
p.varTypes[varFuncName][varExpr.Name] = typeInfo{typeUnknown, ref, scope, 0, funcName, index}
// Mark the last related varRef (the most recent one) as a
// call argument for later error handling
p.varRefs[len(p.varRefs)-1].isArg = true
// Determine scope of given variable reference (and funcName if it's
// a local, otherwise empty string)
func (p *parser) getScope(name string) (VarScope, string) {
switch {
case p.locals[name]:
return ScopeLocal, p.funcName
case SpecialVarIndex(name) > 0:
return ScopeSpecial, ""
return ScopeGlobal, ""
// Record a variable (scalar) reference and return the *VarExpr (but
// VarExpr.Index won't be set till later)
func (p *parser) varRef(name string, pos Position) *VarExpr {
scope, funcName := p.getScope(name)
expr := &VarExpr{scope, 0, name}
p.varRefs = append(p.varRefs, varRef{funcName, expr, false, pos})
info := p.varTypes[funcName][name]
if info.typ == typeUnknown {
p.varTypes[funcName][name] = typeInfo{typeScalar, expr, scope, 0, info.callName, 0}
return expr
// Record an array reference and return the *ArrayExpr (but
// ArrayExpr.Index won't be set till later)
func (p *parser) arrayRef(name string, pos Position) *ArrayExpr {
scope, funcName := p.getScope(name)
if scope == ScopeSpecial {
panic(p.errorf("can't use scalar %q as array", name))
expr := &ArrayExpr{scope, 0, name}
p.arrayRefs = append(p.arrayRefs, arrayRef{funcName, expr, pos})
info := p.varTypes[funcName][name]
if info.typ == typeUnknown {
p.varTypes[funcName][name] = typeInfo{typeArray, nil, scope, 0, info.callName, 0}
return expr
// Print variable type information (for debugging) on p.debugWriter
func (p *parser) printVarTypes(prog *Program) {
fmt.Fprintf(p.debugWriter, "scalars: %v\n", prog.Scalars)
fmt.Fprintf(p.debugWriter, "arrays: %v\n", prog.Arrays)
funcNames := []string{}
for funcName := range p.varTypes {
funcNames = append(funcNames, funcName)
for _, funcName := range funcNames {
if funcName != "" {
fmt.Fprintf(p.debugWriter, "function %s\n", funcName)
} else {
fmt.Fprintf(p.debugWriter, "globals\n")
varNames := []string{}
for name := range p.varTypes[funcName] {
varNames = append(varNames, name)
for _, name := range varNames {
info := p.varTypes[funcName][name]
fmt.Fprintf(p.debugWriter, " %s: %s\n", name, info)
// maxResolveIterations bounds the number of type-resolution passes made
// by resolveVars: if we can't finish resolving after this many
// iterations, give up (resolveVars panics with a parse error).
const maxResolveIterations = 10000
// Resolve unknown variables types and generate variable indexes and
// name-to-index mappings for interpreter
func (p *parser) resolveVars(prog *Program) {
// First go through all unknown types and try to determine the
// type from the parameter type in that function definition. May
// need multiple passes depending on the order of functions. This
// is not particularly efficient, but on realistic programs it's
// not an issue.
for i := 0; ; i++ {
progressed := false
for funcName, infos := range p.varTypes {
for name, info := range infos {
if info.scope == ScopeSpecial || info.typ != typeUnknown {
// It's a special var or type is already known
funcIndex, ok := p.functions[info.callName]
if !ok {
// Function being called is a native function
// Determine var type based on type of this parameter
// in the called function (if we know that)
paramName := prog.Functions[funcIndex].Params[info.argIndex]
typ := p.varTypes[info.callName][paramName].typ
if typ != typeUnknown {
if p.debugTypes {
fmt.Fprintf(p.debugWriter, "resolving %s:%s to %s\n",
funcName, name, typ)
info.typ = typ
p.varTypes[funcName][name] = info
progressed = true
if !progressed {
// If we didn't progress we're done (or trying again is
// not going to help)
if i >= maxResolveIterations {
panic(p.errorf("too many iterations trying to resolve variable types"))
// Resolve global variables (iteration order is undefined, so
// assign indexes basically randomly)
prog.Scalars = make(map[string]int)
prog.Arrays = make(map[string]int)
for name, info := range p.varTypes[""] {
_, isFunc := p.functions[name]
if isFunc {
// Global var can't also be the name of a function
panic(p.errorf("global var %q can't also be a function", name))
var index int
if info.scope == ScopeSpecial {
index = SpecialVarIndex(name)
} else if info.typ == typeArray {
index = len(prog.Arrays)
prog.Arrays[name] = index
} else {
index = len(prog.Scalars)
prog.Scalars[name] = index
info.index = index
p.varTypes[""][name] = info
// Fill in unknown parameter types that are being called with arrays,
// for example, as in the following code:
// BEGIN { arr[0]; f(arr) }
// function f(a) { }
for _, c := range p.userCalls {
if c.call.Native {
function := prog.Functions[c.call.Index]
for i, arg := range c.call.Args {
varExpr, ok := arg.(*VarExpr)
if !ok {
funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
argType := p.varTypes[funcName][varExpr.Name]
paramType := p.varTypes[function.Name][function.Params[i]]
if argType.typ == typeArray && paramType.typ == typeUnknown {
paramType.typ = argType.typ
p.varTypes[function.Name][function.Params[i]] = paramType
// Resolve local variables (assign indexes in order of params).
// Also patch up Function.Arrays (tells interpreter which args
// are arrays).
for funcName, infos := range p.varTypes {
if funcName == "" {
scalarIndex := 0
arrayIndex := 0
functionIndex := p.functions[funcName]
function := prog.Functions[functionIndex]
arrays := make([]bool, len(function.Params))
for i, name := range function.Params {
info := infos[name]
var index int
if info.typ == typeArray {
index = arrayIndex
arrays[i] = true
} else {
// typeScalar or typeUnknown: variables may still be
// of unknown type if they've never been referenced --
// default to scalar in that case
index = scalarIndex
info.index = index
p.varTypes[funcName][name] = info
prog.Functions[functionIndex].Arrays = arrays
// Check that variables passed to functions are the correct type
for _, c := range p.userCalls {
// Check native function calls
if c.call.Native {
for _, arg := range c.call.Args {
varExpr, ok := arg.(*VarExpr)
if !ok {
// Non-variable expression, must be scalar
funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
info := p.varTypes[funcName][varExpr.Name]
if info.typ == typeArray {
panic(p.posErrorf(c.pos, "can't pass array %q to native function", varExpr.Name))
// Check AWK function calls
function := prog.Functions[c.call.Index]
for i, arg := range c.call.Args {
varExpr, ok := arg.(*VarExpr)
if !ok {
if function.Arrays[i] {
panic(p.posErrorf(c.pos, "can't pass scalar %s as array param", arg))
funcName := p.getVarFuncName(prog, varExpr.Name, c.inFunc)
info := p.varTypes[funcName][varExpr.Name]
if info.typ == typeArray && !function.Arrays[i] {
panic(p.posErrorf(c.pos, "can't pass array %q as scalar param", varExpr.Name))
if info.typ != typeArray && function.Arrays[i] {
panic(p.posErrorf(c.pos, "can't pass scalar %q as array param", varExpr.Name))
if p.debugTypes {
// Patch up variable indexes (interpreter uses an index instead
// of name for more efficient lookups)
for _, varRef := range p.varRefs {
info := p.varTypes[varRef.funcName][varRef.ref.Name]
if info.typ == typeArray && !varRef.isArg {
panic(p.posErrorf(varRef.pos, "can't use array %q as scalar", varRef.ref.Name))
varRef.ref.Index = info.index
for _, arrayRef := range p.arrayRefs {
info := p.varTypes[arrayRef.funcName][arrayRef.ref.Name]
if info.typ == typeScalar {
panic(p.posErrorf(arrayRef.pos, "can't use scalar %q as array", arrayRef.ref.Name))
arrayRef.ref.Index = info.index
// If name refers to a local (in function inFunc), return that
// function's name, otherwise return "" (meaning global).
func (p *parser) getVarFuncName(prog *Program, name, inFunc string) string {
if inFunc == "" {
return ""
for _, param := range prog.Functions[p.functions[inFunc]].Params {
if name == param {
return inFunc
return ""
// Record a "multi expression" (comma-separated pseudo-expression
// used to allow commas around print/printf arguments).
func (p *parser) multiExpr(exprs []Expr, pos Position) Expr {
expr := &MultiExpr{exprs}
p.multiExprs[expr] = pos
return expr
// Mark the multi expression as used (by a print/printf statement).
func (p *parser) useMultiExpr(expr *MultiExpr) {
delete(p.multiExprs, expr)
// Check that there are no unused multi expressions (syntax error).
func (p *parser) checkMultiExprs() {
if len(p.multiExprs) == 0 {
// Show error on first comma-separated expression
min := Position{1000000000, 1000000000}
for _, pos := range p.multiExprs {
if pos.Line < min.Line || (pos.Line == min.Line && pos.Column < min.Column) {
min = pos
panic(p.posErrorf(min, "unexpected comma-separated expression"))