Adding a Rust-inspired try `?` operator to Golang

A few years ago I saw someone post on Reddit that they had added a try keyword to Go, and while I thought it was cool, I find Rust's try operator (?) cleaner. For those unfamiliar, it's syntactic sugar that turns this:

use std::error::Error;

fn some_fn() -> Result<(), Box<dyn Error>> {
    let data = match some_expression() {
        Ok(v) => v,
        Err(e) => return Err(From::from(e)),
    };
    // ... do something with data
    Ok(())
}

Into this:

use std::error::Error;

fn some_fn() -> Result<(), Box<dyn Error>> {
    let data = some_expression()?;
    // ... do something with data
    Ok(())
}

It's eerily similar to Go's data, err := error handling. The difference is that instead of returning multiple values, the Rust function returns a single ADT (Result) that holds either the value or the error. With that context out of the way, I thought it would be fun to try adding this to the compiler myself, though AFAIK the Go team is against it.
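For comparison, here is the hand-written Go that ? is meant to shorten. It's ordinary Go that builds with the standard toolchain. parsePort is just an illustrative name, not something from the fork.

```go
package main

import (
	"fmt"
	"strconv"
)

// parsePort is an illustrative example: it propagates strconv.Atoi's
// error by hand, which is the pattern ? is meant to shorten.
func parsePort(s string) (int, error) {
	port, err := strconv.Atoi(s)
	if err != nil {
		return 0, err // zero value for the int, plus the error
	}
	return port, nil
}

func main() {
	if port, err := parsePort("3000"); err == nil {
		fmt.Println("port:", port)
	}
	if _, err := parsePort("not-a-port"); err != nil {
		fmt.Println("error:", err)
	}
}
```

With the fork, the first four lines of parsePort's body would collapse into port := strconv.Atoi(s)?.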

This post is a walkthrough of how I made this change in a fork of the Go compiler.

Scanner/Lexer

The first change was in the scanner (lexer), whose job is to turn the raw characters of your .go file into a stream of tokens. All that was needed here was declaring the new token type in tokens.go:

const (
	_    token = iota
	_EOF
	// ...
	_Question // ?
	// ...
)

Then we tell the scanner how to recognize the new token in the next() function of scanner.go:

func (s *scanner) next() {

	// ...

	switch s.ch {

	// ...

	case '?':
		s.nextch()       // consume the '?'
		s.nlsemi = true  // a newline right after ? becomes an implicit ;
		s.tok = _Question

	// ...

	}
}

This says that when the scanner sees a question mark, it consumes it, emits a _Question token, and moves on to the next character. Setting s.nlsemi = true means a newline directly after ? is turned into an implicit semicolon (? \n is read as ?; \n). As a result, a line ending in f()? terminates the statement, and we keep Go's implicit semicolons.

The AST

The next step in compilation is turning that token stream into a tree representation of your program. To accommodate the new ? operator, we need a new AST node in nodes.go:

// X?
QuestionExpr struct {
	X Expr
	expr
}

Since the ? operator evaluates to a value, it's an expression, and its result type depends on the expression it's applied to. X holds whatever expression is immediately to the left of the ?. Embedding expr is what makes QuestionExpr satisfy the Expr interface, as it does for every other expression node.

Now that we have a QuestionExpr node, we need to tell the compiler how to traverse and format it. That means extending the walker in walk.go and the printer in printer.go:

func (w walker) node(n Node) {
	switch n := n.(type) {

	// ...

	case *QuestionExpr:
		w.node(n.X) // the operand is the only child

	// ...

	}
}

func (p *printer) printRawNode(n Node) {
	switch n := n.(type) {

	// ...

	case *QuestionExpr:
		p.print(n.X, _Question) // the operand, then the ? token
	}

	// ...

}
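Both new cases are one-line recursions: the walker visits X, and the printer prints X followed by the ? token. The sketch below mimics that over a hypothetical three-node AST (Name, CallExpr, QuestionExpr). These types only imitate the shape of the syntax package's nodes; they are not its real definitions.

```go
package main

import (
	"fmt"
	"strings"
)

// Hypothetical miniature AST; these are not the syntax package's real types.
type Node any

type Name struct{ Value string }

type CallExpr struct {
	Fun  Node
	Args []Node
}

// QuestionExpr represents X?, mirroring the node added to nodes.go.
type QuestionExpr struct{ X Node }

// walk calls visit on n and then on its children, like walker.node.
func walk(n Node, visit func(Node)) {
	visit(n)
	switch n := n.(type) {
	case *CallExpr:
		walk(n.Fun, visit)
		for _, a := range n.Args {
			walk(a, visit)
		}
	case *QuestionExpr:
		walk(n.X, visit) // the operand is the only child
	}
}

// render writes n as source text, like printRawNode.
func render(b *strings.Builder, n Node) {
	switch n := n.(type) {
	case *Name:
		b.WriteString(n.Value)
	case *CallExpr:
		render(b, n.Fun)
		b.WriteString("(")
		for i, a := range n.Args {
			if i > 0 {
				b.WriteString(", ")
			}
			render(b, a)
		}
		b.WriteString(")")
	case *QuestionExpr:
		render(b, n.X)     // the operand first...
		b.WriteString("?") // ...then the ? token
	}
}

// toSource renders a whole tree to a string.
func toSource(n Node) string {
	var b strings.Builder
	render(&b, n)
	return b.String()
}

func main() {
	// Atoi(s)? built by hand
	tree := &QuestionExpr{X: &CallExpr{Fun: &Name{"Atoi"}, Args: []Node{&Name{"s"}}}}
	fmt.Println(toSource(tree)) // Atoi(s)?

	count := 0
	walk(tree, func(Node) { count++ })
	fmt.Println("nodes visited:", count) // nodes visited: 4
}
```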

With these changes, the scanner can produce a token for ? and the AST can represent it. However, we can't yet parse the token stream into that AST.

Parser

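To bridge the gap between these two steps, we need to edit the expression parser. Since ? is a postfix operator, it belongs in pexpr, whose loop already handles postfix forms such as selectors (x.f), index expressions (x[i]), and calls (f()):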

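func (p *parser) pexpr(x Expr, keep_parens bool) Expr {

loop:
	for {
		pos := p.pos()
		switch p.tok {

		// ...

		case _Question:
			// Wrap the expression parsed so far: x? -> QuestionExpr{X: x}
			t := new(QuestionExpr)
			t.pos = pos
			t.X = x
			p.next() // consume the ?
			x = t

		// ...

		default:
			break loop
		}
	}

	return x
}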

The logic is that we grab the expression parsed so far (x), wrap it in a new QuestionExpr{X: x}, consume the ? token, and keep looping in case more postfix operators follow.

Type Checking

We can't allow ? just anywhere, so the type checker has to enforce where it's valid. The rules for using ? are:

  • X (the operator's target) is a function call
  • The call returns at least one value, and its last value implements Go's error interface
  • The enclosing function returns at least one value, and its last one also implements Go's error interface

Once these checks pass, we also need to make two type alterations:

  • Strip off the error, so that the QuestionExpr evaluates to the remaining values
  • Treat calls that return only an error as statements rather than expressions

As an example for the first alteration, if a function returns (string, int, error), we strip the error so the rest of the compiler sees the QuestionExpr as evaluating to (string, int). This way a, b := something()? type-checks correctly with two variables on the left and two values on the right.

For the second alteration, consider a function that returns only an error, like os.Remove(path). After stripping the error, no values remain. If we gave the expression an empty type but still treated it as an expression, the compiler would complain that the result is not used when you write os.Remove(path)? as a standalone statement. So in this case, we tell the type checker that it's a statement, not an expression.
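The rules and type alterations can be tried out with the standard library's go/types package, whose API closely mirrors the compiler's internal types2. This is a standalone sketch, not the fork's code. checkQuestion is a hypothetical helper: it takes the callee's and the enclosing function's signatures, applies the second and third rules, and returns the types x? would evaluate to. An empty result means the expression is used as a statement. The first rule is purely syntactic, so it's left out, and the error messages are shortened versions of the fork's.

```go
package main

import (
	"errors"
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"go/types"
)

// errorIface is the underlying interface of the predeclared error type.
var errorIface = types.Universe.Lookup("error").Type().Underlying().(*types.Interface)

// lastIsError reports whether t has at least one element and its last
// element implements error (the predeclared error type or a custom one).
func lastIsError(t *types.Tuple) bool {
	return t.Len() > 0 && types.Implements(t.At(t.Len()-1).Type(), errorIface)
}

// checkQuestion is a hypothetical helper: it applies the second and third
// rules to a callee and its enclosing function, then strips the error and
// returns the types x? evaluates to (empty means "use it as a statement").
func checkQuestion(callee, enclosing *types.Signature) ([]types.Type, error) {
	if !lastIsError(callee.Results()) {
		return nil, errors.New("? operator requires last return value to implement error")
	}
	if !lastIsError(enclosing.Results()) {
		return nil, errors.New("cannot use ? in function that does not return error")
	}
	res := callee.Results()
	rest := make([]types.Type, 0, res.Len()-1)
	for i := 0; i < res.Len()-1; i++ {
		rest = append(rest, res.At(i).Type())
	}
	return rest, nil
}

// src declares a few sample functions to type-check.
const src = `package p

type MyErr struct{}

func (MyErr) Error() string { return "my error" }

func pair() (string, int, error) { return "", 0, nil }
func custom() (int, MyErr)        { return 0, MyErr{} }
func only() error                 { return nil }
func flag() bool                  { return true }
func noResults()                  {}
`

// sigs type-checks src and returns the signature of each top-level function.
func sigs() map[string]*types.Signature {
	fset := token.NewFileSet()
	f, err := parser.ParseFile(fset, "p.go", src, 0)
	if err != nil {
		panic(err)
	}
	pkg, err := new(types.Config).Check("p", fset, []*ast.File{f}, nil)
	if err != nil {
		panic(err)
	}
	m := map[string]*types.Signature{}
	for _, name := range pkg.Scope().Names() {
		if fn, ok := pkg.Scope().Lookup(name).(*types.Func); ok {
			m[name] = fn.Type().(*types.Signature)
		}
	}
	return m
}

func main() {
	s := sigs()
	for _, name := range []string{"pair", "custom", "only", "flag"} {
		rest, err := checkQuestion(s[name], s["only"]) // enclosing func returns error
		fmt.Println(name+"()?", rest, err)
	}
	_, err := checkQuestion(s["pair"], s["noResults"])
	fmt.Println("inside noResults:", err)
}
```

pair()? leaves (string, int). custom()? leaves int, because MyErr satisfies error without being the predeclared type. only()? leaves nothing, which is the statement case.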

// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
// (See rawExpr for an explanation of the parameters.)
func (check *Checker) exprInternal(T *target, x *operand, e syntax.Expr, hint Type) exprKind {

	switch e := e.(type) {

	// ...

	case *syntax.QuestionExpr:
		// Rule 1: The operator's target is a function call
		call, isCall := e.X.(*syntax.CallExpr)
		if !isCall {
			check.errorf(e, InvalidSyntaxTree, "? operator requires a function call")
			goto Error
		}

		// Type-check the call itself
		check.callExpr(x, call)
		if !x.isValid() {
			goto Error
		}

		// Get the result types of the call
		var resultVars []*Var
		switch t := x.typ().(type) {
		case *Tuple:
			resultVars = t.vars
		default:
			if x.typ() != nil {
				resultVars = []*Var{NewVar(x.Pos(), nil, "", x.typ())}
			}
		}

		// Rule 2.1: The function call must return at least one value
		if len(resultVars) == 0 {
			check.errorf(e, InvalidSyntaxTree, "? operator requires call returning at least an error")
			goto Error
		}

		// Rule 2.2: The last value implements Go's error interface
		{
			lastType := resultVars[len(resultVars)-1].typ
			if !Identical(lastType, universeError) { // first, check for the predeclared error type
				if !check.implements(lastType, universeError, false, nil) { // otherwise, accept custom types that satisfy error
					check.errorf(e, InvalidSyntaxTree, "? operator requires last return value to implement error, got %s", lastType)
					goto Error
				}
			}
		}

		// ? only makes sense inside a function body
		if check.sig == nil {
			check.errorf(e, InvalidSyntaxTree, "? operator used outside of function")
			goto Error
		}
		{
			// Rule 3.1: The enclosing function must return at least one value
			encResults := check.sig.results
			if encResults == nil || encResults.Len() == 0 {
				check.errorf(e, InvalidSyntaxTree, "cannot use ? in function that does not return error")
				goto Error
			}
			// Rule 3.2: The enclosing function's last value implements Go's error interface
			encLastType := encResults.vars[encResults.Len()-1].typ
			if !Identical(encLastType, universeError) {
				if !check.implements(encLastType, universeError, false, nil) {
					check.errorf(e, InvalidSyntaxTree, "cannot use ? in function whose last return type is not error (got %s)", encLastType)
					goto Error
				}
			}
		}

		// Type Alteration 1: Strip the last (error) value
		{
			nonErrVars := resultVars[:len(resultVars)-1]
			if len(nonErrVars) == 0 {
				// Type Alteration 2: (error) -> no value, used as a statement
				x.mode_ = novalue
				x.typ_ = nil
				x.expr = e
				return statement
			} else if len(nonErrVars) == 1 {
				// (T, error) -> T
				x.typ_ = nonErrVars[0].typ
				x.mode_ = value
			} else {
				// (T, U, ..., error) -> (T, U, ...)
				x.typ_ = NewTuple(nonErrVars...)
				x.mode_ = value
			}
			x.expr = e
		}

	// ...

	}

	x.expr = e
	return expression

Error:
	// ... (existing error path)
}

Desugaring

At this point, we have a valid, type-checked AST. Since the ? operator is just syntactic sugar, we want to desugar it before the compiler moves on to SSA and machine code generation. The goal is to turn the high-level QuestionExpr into IR that performs the if err != nil { return ..., err } check.

The Go compiler uses a serialization format called pkgbits to pass the syntax tree between the frontend and backend phases. The type-checked AST nodes are serialized into a pkgbits byte stream, which is then deserialized into IR nodes.

In codes.go we register our new node for this serialization process:

const (
	exprConst  codeExpr = iota
	// ...
	exprQuestion
	// ...
)

Then, in writer.go (the serialization side), we plug in our node type:

func (w *writer) expr(expr syntax.Expr) {

	// ...

	switch expr := expr.(type) {

	// ...

	case *syntax.QuestionExpr:
		w.Code(exprQuestion) // which kind of expression follows
		w.pos(expr)          // its source position
		w.expr(expr.X)       // then the operand, recursively

	// ...

	}
}

In reader.go (the deserialization side), we do the actual desugaring. When the reader encounters an exprQuestion code, it constructs the equivalent of if err != nil { return zero..., err } as IR nodes. The logic has three steps:

  • Allocate temporary variables for every value the call returns (tmpA, tmpB, tmpErr := f())
  • Build the error check if tmpErr != nil { return ..., tmpErr }, using ir.NewZero() to generate a zero value for each of the enclosing function's other result types
  • Wrap everything in the right IR node, depending on how many non-error values the call returns. There are three cases:
    • Zero non-error values, f() -> error
    • One non-error value, f() -> (T, error)
    • Multiple non-error values, f() -> (T, U, ..., error)

For zero non-error values, we wrap the statements in an InlinedCallExpr with no results. It produces no value, which matches the statement kind we set in the type checker. For one non-error value, we use InitExpr, which attaches init statements to a single value expression. Finally, for multiple non-error values, we construct an InlinedCallExpr whose result is a tuple of the non-error temps.
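Written back out as ordinary Go, the shapes the reader builds look roughly like this. It's a hand-written approximation for illustration, not compiler output. validate is a hypothetical error-only function standing in for something like os.Remove, and the tmp names are made up. Since run returns (string, int, error), every early return fills the non-error slots with zero values, which is the job ir.NewZero does in the reader.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"strconv"
)

// validate is a hypothetical error-only function (like os.Remove).
func validate(s string) error {
	if s == "" {
		return errors.New("empty input")
	}
	return nil
}

// run approximates, by hand, what a body using ? becomes after desugaring.
func run(addr string) (string, int, error) {
	// validate(addr)?  -> zero non-error values, used as a statement
	tmpErr := validate(addr)
	if tmpErr != nil {
		return "", 0, tmpErr
	}

	// host, strPort := net.SplitHostPort(addr)?  -> multiple non-error values
	tmpA, tmpB, tmpErr2 := net.SplitHostPort(addr)
	if tmpErr2 != nil {
		return "", 0, tmpErr2
	}
	host, strPort := tmpA, tmpB

	// port := strconv.Atoi(strPort)?  -> one non-error value
	tmpC, tmpErr3 := strconv.Atoi(strPort)
	if tmpErr3 != nil {
		return "", 0, tmpErr3
	}
	port := tmpC

	return host, port, nil
}

func main() {
	fmt.Println(run("localhost:3000"))
	fmt.Println(run("localhost:http"))
	fmt.Println(run(""))
}
```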

// expr reads and returns a typechecked expression.
func (r *reader) expr() (res ir.Node) {
	defer func() {
		if res != nil && res.Typecheck() == 0 {
			base.FatalfAt(res.Pos(), "%v missed typecheck", res)
		}
	}()

	switch tag := codeExpr(r.Code(pkgbits.SyncExpr)); tag {

	// ...

	case exprQuestion:
		pos := r.pos()
		x := r.expr()

		var init ir.Nodes
		typ := x.Type()
		var errTmp *ir.Name
		var nonErrTmps []ir.Node

		// Step 1: Allocate temps and build the assignment.
		//
		// In the compiler, a call with multiple results has a FuncArgStruct type.
		// If the type isn't a FuncArgStruct, the call has exactly one result,
		// since the type checker already rejected calls with no results.
		if typ.IsFuncArgStruct() {
			results := typ.Fields()
			var lhs []ir.Node
			for i, f := range results {
				tmp := r.temp(pos, f.Type)
				init.Append(typecheck.Stmt(ir.NewDecl(pos, ir.ODCL, tmp)))
				lhs = append(lhs, tmp)
				// The last result is the error
				if i == len(results)-1 {
					errTmp = tmp
				} else {
					nonErrTmps = append(nonErrTmps, tmp)
				}
			}

			// tmpA, tmpB, ..., tmpErr := f()
			as := ir.NewAssignListStmt(pos, ir.OAS2, lhs, []ir.Node{x})
			as.Def = true
			init.Append(typecheck.Stmt(as))
		} else {
			// Single-result case: f() -> error
			errTmp = r.temp(pos, typ)
			init.Append(typecheck.Stmt(ir.NewDecl(pos, ir.ODCL, errTmp)))

			// tmpErr := f()
			as := ir.NewAssignStmt(pos, errTmp, x)
			as.Def = true
			init.Append(typecheck.Stmt(as))
		}

		// Step 2: Build the error check: if tmpErr != nil { return ..., tmpErr }
		nilNode := ir.NewNilExpr(pos, errTmp.Type())
		cond := typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, errTmp, nilNode))

		ret := ir.NewReturnStmt(pos, nil)
		funcResults := r.curfn.Type().Results()
		for i, f := range funcResults {
			if i == len(funcResults)-1 {
				// Propagate the error in the last position
				ret.Results.Append(errTmp)
			} else {
				// Fill every other result with its zero value
				ret.Results.Append(ir.NewZero(pos, f.Type))
			}
		}

		ifStmt := ir.NewIfStmt(pos, cond, []ir.Node{ret}, nil)
		init.Append(typecheck.Stmt(ifStmt))

		// Step 3: Wrap in the right IR node, depending on the three cases
		if len(nonErrTmps) == 0 {
			// Step 3.1: Zero non-error values
			inlcd := ir.NewInlinedCallExpr(pos, init, nil)
			inlcd.SetTypecheck(1)
			inlcd.SetType(nil)
			return inlcd
		} else if len(nonErrTmps) == 1 {
			// Step 3.2: One non-error value
			return ir.InitExpr(init, nonErrTmps[0])
		} else {
			// Step 3.3: Multiple non-error values.
			// Construct an InlinedCallExpr whose type is a tuple of the
			// non-error temps (more on this below).
			inlcd := ir.NewInlinedCallExpr(pos, init, nonErrTmps)
			inlcd.SetTypecheck(1)

			fields := make([]*types.Field, len(nonErrTmps))
			for i, tmp := range nonErrTmps {
				fields[i] = types.NewField(pos, nil, tmp.Type())
			}
			t := types.NewStruct(fields)
			t.StructType().ParamTuple = true // mark the struct as a result tuple
			inlcd.SetType(t)
			return inlcd
		}

	// ...

	}
}

For multiple non-error values we need an InlinedCallExpr, because apart from function calls there's no IR expression node that evaluates to multiple values. InlinedCallExpr is the compiler's existing way to represent the result of an inlined call. I don't think it was really meant for this, but we borrow it here nonetheless, with a little workaround.

A little workaround

In the previous step, we created an InlinedCallExpr for multi-value results, which breaks an assumption in the compiler's IR-level type checker. Multi-value assignments like a, b := f() are handled by the assign function in stmt.go. That function assumes that when a single right-hand side produces several values, it is a CallExpr. In normal Go, a call is the only expression that reaches this path; the comma-ok forms, such as map lookups, type assertions, and channel receives, are handled separately with their own assignment ops. So, for this workaround, we treat the new ? as though it were an inlined function call.

func assign(stmt ir.Node, lhs, rhs []ir.Node) {

	// ...

	// x,y,z = f()
	if cr > len(rhs) {
		stmt := stmt.(*ir.AssignListStmt)
		stmt.SetOp(ir.OAS2FUNC)

		// Originally this did a direct type assertion to CallExpr.
		// We also need to accept InlinedCallExpr, since that's
		// what our ? desugaring produces.
		r := rhs[0]
		var rtyp *types.Type
		if call, ok := r.(*ir.CallExpr); ok {
			rtyp = call.Type()
		} else if icall, ok := r.(*ir.InlinedCallExpr); ok {
			rtyp = icall.Type()
		} else {
			base.Fatalf("assign list unexpected rhs: %v", r.Op())
		}

		mismatched := false
		failed := false
		for i := range lhs {
			result := rtyp.Field(i).Type
			assignType(i, result)

			if lhs[i].Type() == nil || result == nil {
				failed = true
			} else if lhs[i] != ir.BlankNode && !types.Identical(lhs[i].Type(), result) {
				mismatched = true
			}
		}

		// A CallExpr gets the usual implicit conversions via RewriteMultiValueCall.
		// For an InlinedCallExpr we simply fail on type mismatches instead,
		// which is good enough for the scope of this project :)
		if mismatched && !failed {
			if call, ok := r.(*ir.CallExpr); ok {
				RewriteMultiValueCall(stmt, call)
			} else {
				base.Fatalf("assign list mismatched with InlinedCallExpr not supported")
			}
		}
		return
	}

	// ...
}

And that's it! Since ? desugars to the same if err != nil check you'd write by hand, we get a zero-cost abstraction for simple error propagation!

Some Examples

First, some working examples:

func run() error {
	os.WriteFile("test.txt", []byte("hello"), 0644)?

	host, strPort := net.SplitHostPort("localhost:3000")?
	fmt.Println("Host:", host)

	port := strconv.Atoi(strPort)?
	fmt.Println("Port:", port)

	doubled := strconv.Atoi("8080")? * 2
	fmt.Println("Doubled:", doubled)

	return nil
}

Which prints:

Host: localhost
Port: 3000
Doubled: 16160

And now for some compilation errors:

func bad1() error {
	x := 42?
	_ = x
	return nil
}

func isValid(s string) bool {
	return len(s) > 0 
}

func bad2() error {
	v := isValid("hello")?
	_ = v
	return nil
}

func bad3() {
	_ = strconv.Atoi("8080")?
}

Which results in the following compile-time errors:

./errors.go:8:9: invalid syntax tree: ? operator requires a function call
./errors.go:18:23: invalid syntax tree: ? operator requires last return value to implement error, got bool
./errors.go:24:26: invalid syntax tree: cannot use ? in function that does not return error

Code

All the code that was used for this is available at https://github.com/sam-harri/go-question

If you spot any inaccuracies, no-nos, or anything else, please reach out at samharrison@cs.toronto.edu

This project was made entirely for fun, and I think it was :)