Adding a Rust-inspired try `?` operator to Golang
A few years ago I saw someone post on Reddit that they had added a try keyword to Go. I thought it was cool, but I think Rust’s try operator is cleaner. For those unfamiliar, it’s syntactic sugar that turns this:
fn some_fn() -> Result<(), Box<dyn Error>> {
    let data = match some_expression() {
        Ok(v) => v,
        Err(e) => return Err(From::from(e)),
    };
    // ... do something with data
    Ok(())
}
Into this:
fn some_fn() -> Result<(), Box<dyn Error>> {
    let data = some_expression()?;
    // ... do something with data
    Ok(())
}
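It’s eerily similar to Go’s data, err := f() style of error handling, except that instead of returning multiple values, the function returns a single ADT (Result) that holds either the value or the error.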
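For comparison, this is the hand-written Go pattern that the operator is meant to shorten. The name parsePort is made up for illustration; with the fork described in this post, its body could shrink to port := strconv.Atoi(s)? followed by return port, nil.

```go
package main

import (
	"fmt"
	"strconv"
)

// parsePort is a hypothetical helper showing Go's explicit error propagation:
// the fallible call is followed by an if err != nil check that returns early.
func parsePort(s string) (int, error) {
	port, err := strconv.Atoi(s)
	if err != nil {
		return 0, err // zero value for the int, then propagate the error
	}
	return port, nil
}

func main() {
	fmt.Println(parsePort("3000"))
	fmt.Println(parsePort("abc"))
}
```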
With that context out of the way, I thought it would be fun to try adding this to the compiler myself, even though, AFAIK, the Go team is against it.
This post is a walkthrough of how I made this change in a fork of the Go compiler.
Scanner/Lexer
The first change was in the scanner/lexer, whose job is to turn raw characters from your .go file into a stream of tokens.
Here, all that was needed was declaring the new token type in tokens.go:
const (
	_ token = iota
	_EOF
	// ...
	_Question
	// ...
)
Then we tell the scanner how to process this new token in scanner.go’s next() function:
func (s *scanner) next() {
	// ...
	switch s.ch {
	// ...
	case '?':
		s.nextch()
		s.nlsemi = true
		s.tok = _Question
	// ...
	}
}
This says: if you see a question mark, emit a _Question token and move on to the next character in the stream.
Setting s.nlsemi means a newline right after the ? is turned into a semicolon, so ?\n is treated as ?;\n and we get to keep Go’s implicit semicolons.
The AST
The next step in compilation is taking that token stream and building a tree representation of your program.
To accommodate the new ? operator, we need to create a new AST node for it in nodes.go:
QuestionExpr struct {
	X Expr
	expr
}
Since the question mark operator evaluates to a value, it’s an expression, and its result type depends on the expression it is applied to.
The X represents whatever expression is immediately to the left of the ?.
Now we have a QuestionExpr node, but we need to tell the compiler how to traverse and format it by changing the walker and printer in walk.go and printer.go:
func (w walker) node(n Node) {
	switch n := n.(type) {
	// ...
	case *QuestionExpr:
		w.node(n.X)
	// ...
	}
}
func (p *printer) printRawNode(n Node) {
	switch n := n.(type) {
	// ...
	case *QuestionExpr:
		p.print(n.X, _Question)
	}
	// ...
}
With these changes we can produce a token for the ? operator and represent it in the AST, but we can’t yet parse the token stream into that AST.
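A toy version of the node, walker, and printer makes the division of labor clearer. The types below are hypothetical miniatures loosely modeled on the syntax package, not its real definitions: the walker only needs to descend into X, and the printer only needs to emit X followed by ?.

```go
package main

import "fmt"

// Hypothetical miniature AST nodes.
type Expr interface{ isExpr() }

// Name is an identifier such as f.
type Name struct{ Value string }

// CallExpr is Fun() (arguments omitted for brevity).
type CallExpr struct{ Fun Expr }

// QuestionExpr is X?, where X is the expression immediately left of '?'.
type QuestionExpr struct{ X Expr }

func (*Name) isExpr()         {}
func (*CallExpr) isExpr()     {}
func (*QuestionExpr) isExpr() {}

// walk calls visit on n, then recurses into its children.
func walk(n Expr, visit func(Expr)) {
	visit(n)
	switch n := n.(type) {
	case *CallExpr:
		walk(n.Fun, visit)
	case *QuestionExpr:
		walk(n.X, visit) // X is a QuestionExpr's only child
	}
}

// render prints n back as source text.
func render(n Expr) string {
	switch n := n.(type) {
	case *Name:
		return n.Value
	case *CallExpr:
		return render(n.Fun) + "()"
	case *QuestionExpr:
		return render(n.X) + "?"
	}
	return "<unknown>"
}

func main() {
	e := &QuestionExpr{X: &CallExpr{Fun: &Name{Value: "f"}}}
	count := 0
	walk(e, func(Expr) { count++ })
	fmt.Println(render(e), count) // f()? 3
}
```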
Parser
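To bridge the gap between these two steps, we need to edit the parser. Since ? is a postfix operator on an expression, it goes in pexpr, the primary-expression parser, whose loop already handles postfix forms such as x.f, x[i], and f(...).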
func (p *parser) pexpr(x Expr, keep_parens bool) Expr {
	// ...
loop:
	for {
		pos := p.pos()
		switch p.tok {
		// ...
		case _Question:
			t := new(QuestionExpr)
			t.pos = pos
			t.X = x
			p.next()
			x = t
		// ...
		default:
			break loop
		}
	}
	return x
}
The logic here is that we grab the expression we just parsed (x), and wrap it in our new QuestionExpr{X: x}, then continue with the parsing.
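Because ? is handled in this postfix loop, it binds tighter than any binary operator, which is what makes strconv.Atoi("8080")? * 2 in the examples at the end work. The toy parser below is a hypothetical illustration of that loop structure over pre-split tokens, not the compiler’s parser; it supports only names, number literals, empty calls, ?, and *.

```go
package main

import "fmt"

// Hypothetical toy AST; String renders the tree structure explicitly.
type Expr interface{ String() string }

type Name struct{ Value string }
type Call struct{ Fun Expr }
type Question struct{ X Expr }
type Binary struct {
	Op   string
	X, Y Expr
}

func (n *Name) String() string     { return n.Value }
func (c *Call) String() string     { return "call(" + c.Fun.String() + ")" }
func (q *Question) String() string { return "?(" + q.X.String() + ")" }
func (b *Binary) String() string {
	return "(" + b.X.String() + " " + b.Op + " " + b.Y.String() + ")"
}

type parser struct {
	toks []string
	pos  int
}

// tok returns the current token, or "EOF" past the end.
func (p *parser) tok() string {
	if p.pos < len(p.toks) {
		return p.toks[p.pos]
	}
	return "EOF"
}

func (p *parser) next() { p.pos++ }

// pexpr parses an operand, then loops over postfix forms until none applies.
// After wrapping x in a Question the loop continues, so later postfix
// forms apply to the wrapped expression.
func (p *parser) pexpr() Expr {
	var x Expr = &Name{Value: p.tok()} // operand: a name or literal
	p.next()
loop:
	for {
		switch p.tok() {
		case "(":
			p.next() // '('
			p.next() // ')' (this toy does not parse arguments)
			x = &Call{Fun: x}
		case "?":
			p.next()
			x = &Question{X: x}
		default:
			break loop
		}
	}
	return x
}

// binaryExpr handles only '*': each operand is a full pexpr, so '?' has
// already been applied by the time '*' is seen.
func (p *parser) binaryExpr() Expr {
	x := p.pexpr()
	for p.tok() == "*" {
		p.next()
		x = &Binary{Op: "*", X: x, Y: p.pexpr()}
	}
	return x
}

func parse(toks ...string) string {
	p := &parser{toks: toks}
	return p.binaryExpr().String()
}

func main() {
	fmt.Println(parse("Atoi", "(", ")", "?", "*", "2")) // (?(call(Atoi)) * 2)
}
```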
Type Checking
We can’t allow ? to be used just anywhere, and we need to encode this in the type checker. The rules for using ? are:
- X (the operator’s target) is a function call.
- The function call returns at least one value, and the last value implements Go’s error interface.
- The enclosing function returns at least one value, and its last result also implements Go’s error interface.
Once these checks pass, we also need to make two type alterations:
- Strip off the trailing error so that the QuestionExpr evaluates to the remaining values.
- Turn calls that return only an error into statements rather than expressions.
As an example for the first alteration, if a function returns (string, int, error), we strip the error so the rest of the compiler sees the QuestionExpr as evaluating to (string, int).
This way a, b := something()? type-checks correctly with two variables on the left and two values on the right.
For the second alteration, consider a function that returns only an error (like os.Remove(path)). After stripping the error there are no values left, so if we gave the expression an empty result type but still reported it as an expression, the compiler would complain that the result is not used when you write os.Remove(path)? as a standalone statement.
In this case, we tell the type checker this is a statement, not an expression.
// exprInternal contains the core of type checking of expressions.
// Must only be called by rawExpr.
// (See rawExpr for an explanation of the parameters.)
func (check *Checker) exprInternal(T *target, x *operand, e syntax.Expr, hint Type) exprKind {
	switch e := e.(type) {
	// ...
	case *syntax.QuestionExpr:
		// Rule 1: the operator's target is a function call
		call, isCall := e.X.(*syntax.CallExpr)
		if !isCall {
			check.errorf(e, InvalidSyntaxTree, "? operator requires a function call")
			goto Error
		}
		// Type-check the call itself
		check.callExpr(x, call)
		if !x.isValid() {
			goto Error
		}
		// Get the result types of the call
		var resultVars []*Var
		switch t := x.typ().(type) {
		case *Tuple:
			resultVars = t.vars
		default:
			if x.typ() != nil {
				resultVars = []*Var{NewVar(x.Pos(), nil, "", x.typ())}
			}
		}
		// Rule 2.1: the function call must return at least one value
		if len(resultVars) == 0 {
			check.errorf(e, InvalidSyntaxTree, "? operator requires call returning at least an error")
			goto Error
		}
		// Rule 2.2: the last value implements Go's error interface
		{
			lastType := resultVars[len(resultVars)-1].typ
			if !Identical(lastType, universeError) { // check for the predeclared error type
				if !check.implements(lastType, universeError, false, nil) { // otherwise, accept custom types that satisfy error
					check.errorf(e, InvalidSyntaxTree, "? operator requires last return value to implement error, got %s", lastType)
					goto Error
				}
			}
		}
		if check.sig == nil {
			check.errorf(e, InvalidSyntaxTree, "? operator used outside of function")
			goto Error
		}
		{
			// Rule 3.1: the enclosing function must return at least one value
			encResults := check.sig.results
			if encResults == nil || encResults.Len() == 0 {
				check.errorf(e, InvalidSyntaxTree, "cannot use ? in function that does not return error")
				goto Error
			}
			// Rule 3.2: the enclosing function's last value implements Go's error interface
			encLastType := encResults.vars[encResults.Len()-1].typ
			if !Identical(encLastType, universeError) {
				if !check.implements(encLastType, universeError, false, nil) {
					check.errorf(e, InvalidSyntaxTree, "cannot use ? in function whose last return type is not error (got %s)", encLastType)
					goto Error
				}
			}
		}
		// Type alteration 1: strip the last (error) value
		{
			nonErrVars := resultVars[:len(resultVars)-1]
			if len(nonErrVars) == 0 {
				// Type alteration 2: change from expression to statement
				x.mode_ = novalue
				x.typ_ = nil
				x.expr = e
				return statement
			} else if len(nonErrVars) == 1 {
				// (T, error) -> T
				x.typ_ = nonErrVars[0].typ
				x.mode_ = value
			} else {
				// (T, U, ..., error) -> (T, U, ...)
				x.typ_ = NewTuple(nonErrVars...)
				x.mode_ = value
			}
			x.expr = e
		}
		// ...
		x.expr = e
		return expression
	}
}
Desugaring
At this point we have a fully valid, type-checked AST. Since the ? operator is just syntactic sugar, we want to desugar it before the compiler moves on to SSA and machine code generation.
So the goal is to turn our high-level QuestionExpr node into IR that performs the if err != nil { return ..., err } check.
The Go compiler uses a serialization format called pkgbits (part of its unified IR) to pass the syntax tree from the frontend to the compiler’s IR: the type-checked AST is serialized, then deserialized into IR nodes.
In codes.go we register our new node for this serialization process.
const (
	exprConst codeExpr = iota
	// ...
	exprQuestion
	// ...
)
Then, in writer.go (the serialization side), we plug in our node type:
func (w *writer) expr(expr syntax.Expr) {
	// ...
	switch expr := expr.(type) {
	// ...
	case *syntax.QuestionExpr:
		w.Code(exprQuestion)
		w.pos(expr)
		w.expr(expr.X)
	// ...
	}
}
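The writer and reader have to agree on the order of codes and fields, since the stream carries no other structure. The toy encoder and decoder below are hypothetical and far simpler than pkgbits (plain Go values in a slice instead of an encoded byte stream), but they show that pairing for a QuestionExpr wrapped around a call.

```go
package main

import "fmt"

// Hypothetical expression codes, in the spirit of codes.go.
type codeExpr int

const (
	exprName codeExpr = iota
	exprCall
	exprQuestion
)

// Toy AST.
type Expr interface{}
type Name struct{ Value string }
type Call struct{ Fun Expr }
type Question struct{ X Expr }

// writer appends a code, then the node's fields in a fixed order.
type writer struct{ out []any }

func (w *writer) expr(e Expr) {
	switch e := e.(type) {
	case *Name:
		w.out = append(w.out, exprName, e.Value)
	case *Call:
		w.out = append(w.out, exprCall)
		w.expr(e.Fun)
	case *Question:
		w.out = append(w.out, exprQuestion)
		w.expr(e.X) // mirrors w.expr(expr.X) in the fork's writer
	default:
		panic(fmt.Sprintf("unexpected node %T", e))
	}
}

// reader consumes the stream in exactly the order the writer produced it.
type reader struct {
	in  []any
	pos int
}

func (r *reader) next() any {
	v := r.in[r.pos]
	r.pos++
	return v
}

func (r *reader) expr() Expr {
	switch code := r.next().(codeExpr); code {
	case exprName:
		return &Name{Value: r.next().(string)}
	case exprCall:
		return &Call{Fun: r.expr()}
	case exprQuestion:
		// The fork desugars at this point; this toy just rebuilds the node.
		return &Question{X: r.expr()}
	default:
		panic(fmt.Sprintf("unknown code %d", code))
	}
}

// render prints a tree back as source text.
func render(e Expr) string {
	switch e := e.(type) {
	case *Name:
		return e.Value
	case *Call:
		return render(e.Fun) + "()"
	case *Question:
		return render(e.X) + "?"
	}
	return "<unknown>"
}

func main() {
	w := &writer{}
	w.expr(&Question{X: &Call{Fun: &Name{Value: "f"}}})
	fmt.Println(w.out) // [2 1 0 f]
	r := &reader{in: w.out}
	fmt.Println(render(r.expr())) // f()?
}
```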
In reader.go (the deserialization side), we do the desugaring work. When the reader encounters an exprQuestion code, it constructs the equivalent of if err != nil { return zero..., err } as IR nodes.
The logic has three steps:
- Allocate temporary variables for every value the call returns (tmpA, tmpB, tmpErr := f()).
- Build the error check if tmpErr != nil { return ..., tmpErr }, using ir.NewZero() to generate a zero value for each non-error result of the enclosing function.
- Wrap everything in the right IR node depending on how many non-error values the call returns. There are three cases:
  - Zero non-error values: f() -> error
  - One non-error value: f() -> (T, error)
  - Multiple non-error values: f() -> (T, U, ..., error)
For zero non-error values, we wrap the statement in an InlinedCallExpr with no results which produces no value, matching the statement kind we set in the type checker.
For one non-error value, we use InitExpr which attaches init statements to a single value expression.
Finally, for multiple non-error values, we construct an InlinedCallExpr whose result is a tuple of the non-error temps.
// expr reads and returns a typechecked expression.
func (r *reader) expr() (res ir.Node) {
	defer func() {
		if res != nil && res.Typecheck() == 0 {
			base.FatalfAt(res.Pos(), "%v missed typecheck", res)
		}
	}()
	switch tag := codeExpr(r.Code(pkgbits.SyncExpr)); tag {
	// ...
	case exprQuestion:
		pos := r.pos()
		x := r.expr()
		var init ir.Nodes
		typ := x.Type()
		var errTmp *ir.Name
		var nonErrTmps []ir.Node
		// Step 1: allocate temps and build the assignment.
		//
		// In the Go compiler, a FuncArgStruct type is used when a call has multiple
		// return values. If the type is not a FuncArgStruct, the call has exactly
		// one result (the type checker already rejected calls with zero results),
		// and that result is the error.
		if typ.IsFuncArgStruct() {
			results := typ.Fields()
			var lhs []ir.Node
			for i, f := range results {
				tmp := r.temp(pos, f.Type)
				init.Append(typecheck.Stmt(ir.NewDecl(pos, ir.ODCL, tmp)))
				lhs = append(lhs, tmp)
				// The last element is the error
				if i == len(results)-1 {
					errTmp = tmp
				} else {
					nonErrTmps = append(nonErrTmps, tmp)
				}
			}
			as := ir.NewAssignListStmt(pos, ir.OAS2, lhs, []ir.Node{x})
			as.Def = true
			init.Append(typecheck.Stmt(as))
		} else {
			// Single-return case: f() -> error
			errTmp = r.temp(pos, typ)
			init.Append(typecheck.Stmt(ir.NewDecl(pos, ir.ODCL, errTmp)))
			as := ir.NewAssignStmt(pos, errTmp, x)
			as.Def = true
			init.Append(typecheck.Stmt(as))
		}
		// Step 2: build the error check
		nilNode := ir.NewNilExpr(pos, errTmp.Type())
		cond := typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, errTmp, nilNode))
		ret := ir.NewReturnStmt(pos, nil)
		funcResults := r.curfn.Type().Results()
		for i, f := range funcResults {
			if i == len(funcResults)-1 {
				// Propagate the error in the last position
				ret.Results.Append(errTmp)
			} else {
				// Everything else is filled with zero values
				ret.Results.Append(ir.NewZero(pos, f.Type))
			}
		}
		ifStmt := ir.NewIfStmt(pos, cond, []ir.Node{ret}, nil)
		init.Append(typecheck.Stmt(ifStmt))
		// Step 3: wrap in the right IR node (remember the three cases)
		if len(nonErrTmps) == 0 {
			// Step 3.1: zero non-error values
			inlcd := ir.NewInlinedCallExpr(pos, init, nil)
			inlcd.SetTypecheck(1)
			inlcd.SetType(nil)
			return inlcd
		} else if len(nonErrTmps) == 1 {
			// Step 3.2: one non-error value
			return ir.InitExpr(init, nonErrTmps[0])
		} else {
			// Step 3.3: multiple non-error values.
			// Construct an InlinedCallExpr whose type is a tuple of the
			// remaining values (more on this below).
			inlcd := ir.NewInlinedCallExpr(pos, init, nonErrTmps)
			inlcd.SetTypecheck(1)
			fields := make([]*types.Field, len(nonErrTmps))
			for i, tmp := range nonErrTmps {
				fields[i] = types.NewField(pos, nil, tmp.Type())
			}
			t := types.NewStruct(fields)
			t.StructType().ParamTuple = true
			inlcd.SetType(t)
			return inlcd
		}
	// ...
	}
}
For multiple non-error values we need an InlinedCallExpr, because apart from a call there is no IR expression node that evaluates to multiple values. InlinedCallExpr is the compiler’s existing way to represent a multi-value result that carries its own statements (the inliner produces it when it inlines a call), though I don’t think it was really meant for this.
We borrow it here nonetheless, with a little workaround.
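To make the result of the desugaring concrete, here is hand-written Go that approximates what the reader builds for each of the three cases, inside a function returning (string, int, error). It is only an approximation: the fork builds IR nodes, not source, and the temp names and the checkAddr helper are made up for illustration.

```go
package main

import (
	"errors"
	"fmt"
	"net"
	"strconv"
)

// checkAddr is a hypothetical helper that returns only an error.
func checkAddr(addr string) error {
	if addr == "" {
		return errors.New("empty address")
	}
	return nil
}

// hostPort approximates the desugaring of:
//
//	checkAddr(addr)?
//	host, portStr := net.SplitHostPort(addr)?
//	port := strconv.Atoi(portStr)?
//	return host, port, nil
func hostPort(addr string) (string, int, error) {
	// Zero non-error values: only the check remains, and no value is produced.
	if tmpErr := checkAddr(addr); tmpErr != nil {
		return "", 0, tmpErr // zero values for the other results, then the error
	}
	// Multiple non-error values: a temp per result, the check, then the tuple.
	tmpA, tmpB, tmpErr := net.SplitHostPort(addr)
	if tmpErr != nil {
		return "", 0, tmpErr
	}
	host, portStr := tmpA, tmpB
	// One non-error value: the same shape with a single remaining temp.
	tmpC, tmpErr2 := strconv.Atoi(portStr)
	if tmpErr2 != nil {
		return "", 0, tmpErr2
	}
	port := tmpC
	return host, port, nil
}

func main() {
	fmt.Println(hostPort("localhost:3000")) // localhost 3000 <nil>
	fmt.Println(hostPort("localhost"))
}
```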
A little workaround
In the previous step, we saw that for multiple non-error values we create an InlinedCallExpr, which breaks an assumption in the compiler’s IR-level type checker.
Multi-value assignments like a, b := f() are handled by the assign function in stmt.go, which assumes that when a single right-hand side produces several values, it is always a CallExpr. In normal Go that holds, since other multi-value forms such as v, ok := m[k] or v, ok := <-ch are handled by different assignment ops in the compiler. So, for this workaround, we treat the new ? as though it were an inlined function call:
func assign(stmt ir.Node, lhs, rhs []ir.Node) {
	// ...
	// x,y,z = f()
	if cr > len(rhs) {
		stmt := stmt.(*ir.AssignListStmt)
		stmt.SetOp(ir.OAS2FUNC)
		// Originally this did a direct type assertion to CallExpr.
		// We also need to accept InlinedCallExpr, since that's
		// what our ? desugaring produces.
		r := rhs[0]
		var rtyp *types.Type
		if call, ok := r.(*ir.CallExpr); ok {
			rtyp = call.Type()
		} else if icall, ok := r.(*ir.InlinedCallExpr); ok {
			rtyp = icall.Type()
		} else {
			base.Fatalf("assign list unexpected rhs: %v", r.Op())
		}
		mismatched := false
		failed := false
		for i := range lhs {
			result := rtyp.Field(i).Type
			assignType(i, result)
			if lhs[i].Type() == nil || result == nil {
				failed = true
			} else if lhs[i] != ir.BlankNode && !types.Identical(lhs[i].Type(), result) {
				mismatched = true
			}
		}
		// Fail on type mismatches instead of doing the implicit conversion
		// a CallExpr would normally get.
		// Good enough for the scope of this project :)
		if mismatched && !failed {
			if call, ok := r.(*ir.CallExpr); ok {
				RewriteMultiValueCall(stmt, call)
			} else {
				base.Fatalf("assign list mismatched with InlinedCallExpr not supported")
			}
		}
		return
	}
	// ...
}
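Stripped of compiler details, the change amounts to accepting a second node kind wherever the code used to assert CallExpr. The toy IR below is hypothetical (types are plain strings and there is no real rewrite step), but it follows the same decision: a type mismatch is tolerated for a real call, where the compiler would fall back to RewriteMultiValueCall, and rejected for an InlinedCallExpr.

```go
package main

import "fmt"

// Hypothetical miniature IR: both node kinds carry a tuple of result types.
type Node interface{ ResultTypes() []string }

type CallExpr struct{ Results []string }
type InlinedCallExpr struct{ Results []string }

func (c *CallExpr) ResultTypes() []string        { return c.Results }
func (c *InlinedCallExpr) ResultTypes() []string { return c.Results }

// assignList checks lhs types against a single multi-value right-hand side.
func assignList(lhs []string, r Node) error {
	var rtyp []string
	switch r := r.(type) {
	case *CallExpr:
		rtyp = r.ResultTypes()
	case *InlinedCallExpr: // the newly accepted case
		rtyp = r.ResultTypes()
	default:
		return fmt.Errorf("assign list unexpected rhs: %T", r)
	}
	// Toy guard; in the real compiler types2 has already checked the count.
	if len(lhs) != len(rtyp) {
		return fmt.Errorf("assignment mismatch: %d variables but %d values", len(lhs), len(rtyp))
	}
	mismatched := false
	for i := range lhs {
		if lhs[i] != rtyp[i] {
			mismatched = true
		}
	}
	if mismatched {
		if _, ok := r.(*CallExpr); ok {
			return nil // the real compiler would rewrite the call here
		}
		return fmt.Errorf("assign list mismatched with InlinedCallExpr not supported")
	}
	return nil
}

func main() {
	strs := []string{"string", "string"}
	fmt.Println(assignList(strs, &InlinedCallExpr{Results: strs}))
	fmt.Println(assignList([]string{"any", "string"}, &InlinedCallExpr{Results: strs}))
}
```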
And that’s it! Since ? desugars into the same if err != nil check you would otherwise write by hand, we get a zero-cost abstraction for simple error propagation!
Some Examples
First, some working examples:
package main

import (
	"fmt"
	"net"
	"os"
	"strconv"
)

func run() error {
	os.WriteFile("test.txt", []byte("hello"), 0644)?
	host, strPort := net.SplitHostPort("localhost:3000")?
	fmt.Println("Host:", host)
	port := strconv.Atoi(strPort)?
	fmt.Println("Port:", port)
	doubled := strconv.Atoi("8080")? * 2
	fmt.Println("Doubled:", doubled)
	return nil
}
Which results in:
Host: localhost
Port: 3000
Doubled: 16160
And now for some compilation errors:
package main

import (
	"strconv"
)

func bad1() error {
	x := 42?
	_ = x
	return nil
}

func isValid(s string) bool {
	return len(s) > 0
}

func bad2() error {
	v := isValid("hello")?
	_ = v
	return nil
}

func bad3() {
	_ = strconv.Atoi("8080")?
}
Which results in the following compile-time errors:
./errors.go:8:9: invalid syntax tree: ? operator requires a function call
./errors.go:18:23: invalid syntax tree: ? operator requires last return value to implement error, got bool
./errors.go:24:26: invalid syntax tree: cannot use ? in function that does not return error
Code
All the code that was used for this is available at https://github.com/sam-harri/go-question
If there are any inaccuracies, no-no’s, or anything else, please reach out at samharrison@cs.toronto.edu
This project was made entirely for fun, and I think it was :)