Go语言中常见100问题-#50 comparing an error type

2022-08-15 15:19:38 浏览数 (1)

如何正确地通过类型对error进行比较是值得讨论的一个问题,在前面的问题Go语言中常见100问题-#49 wrap error中,介绍了可以通过%w对error进行wrap。但是一旦对error通过%w进行包装,后续在对包装后的error通过type进行判断的时候,必须采用合适的方法,否则将会出错。

下面通过一个具体的HTTP handler例子进行说明,该例功能是查询某个给定账号的交易金额。handler处理逻辑是从请求中获取账号id,然后从数据库中查询该账号的交易金额。下面的两种情况都会导致查询失败:

  1. 请求的ID无效(id字符串的长度不是5)
  2. 查询数据库失败

情况1,我们想返回一个StatusBadRequest (400)错误。情况2,我们想返回一个ServiceUnavailable (503)错误。为了实现这个目标,我们创建一个transientError类型的错误来暂时标记这个错误。调用方检查返回的error类型,如果是transientError类型,返回503错误码,否则返回400错误码。实例代码如下:

代码语言:javascript复制
type transientError struct {
     err error
}

func (t transientError) Error() string {
     return fmt.Sprintf("transient error: %v", t.err)
}

func getTransactionAmount(transactionID string) (float32, error) {
     if len(transactionID) != 5 {
             return 0, fmt.Errorf("id is invalid: %s", transactionID)
     }

     amount, err := getTransactionAmountFromDB(transactionID)
     if err != nil {
             return 0, transientError{err: err}
     }
     return amount, nil
}

如果是id无效,getTransactionAmount返回一个fmt.Error错误,如果是DB获取数据失败,返回经过包装的transientError错误。调用者handler方法中判断返回的error类型来决定返回合适的状态码。通过switch type,如果是DB查询失败,返回503,否则返回400。

代码语言:javascript复制
func handler(w http.ResponseWriter, r *http.Request) {
        transactionID := r.URL.Query().Get("transaction")

        amount, err := getTransactionAmount(transactionID)
        if err != nil {
                switch err := err.(type) {
                case transientError:
                        http.Error(w, err.Error(), http.StatusServiceUnavailable)
                default:
                        http.Error(w, err.Error(), http.StatusBadRequest)
                }
                return
        }

        // Write response
}

前面的代码没有任何问题。现在我们对getTransactionAmount进行一点重构,transientErrorgetTransactionAmountFromDB返回,不再是getTransactionAmount来返回。重构的代码如下

代码语言:javascript复制
func getTransactionAmount(transactionID string) (float32, error) {
        // Check transaction ID validity

        amount, err := getTransactionAmountFromDB(transactionID)
        if err != nil {
                return 0, fmt.Errorf("failed to get transaction %s: %w",
                        transactionID, err)
        }
        return amount, nil
}

func getTransactionAmountFromDB(transactionID string) (float32, error) {
        // ...
        if err != nil {
                return 0, transientError{err: err}
        }
        // ...
}

重构后的代码运行始终会返回400错误,即程序只会走到default分支,无法走到transientError分支,这是为什么呢?

重构之前,getTransactionAmount函数直接返回transientError

重构之后,getTransactionAmountFromDB函数直接返回transientErrorgetTransactionAmount返回的错误是一个被包装过的transientError类型,它的直接类型不再是transientError

对于上面的问题,Go1.13通过errors.As库函数可以判断一个wrap后的error是否是某种类型的错误。errors.As函数会递归对error进行unwrap,检查每一层的error是否是要比较的错误类型。实现代码如下

代码语言:javascript复制
func handler(w http.ResponseWriter, r *http.Request) {
        // Get transaction ID

        amount, err := getTransactionAmount(transactionID)
        if err != nil {
                if errors.As(err, &transientError{}) {
                        http.Error(w, err.Error(), http.StatusServiceUnavailable)
                } else {
                        http.Error(w, err.Error(), http.StatusBadRequest)
                }
                return
        }

        // Write response
}

errors.As函数的第二个入参必须是一个指针类型对象,否则在运行的时候会产生panic。使用errors.As函数,无论返回的error是一个transientError类型的error,还是一个将transientErrorwrap后的error,都能匹配成功。

总结起来,如果采用Go1.13的方法wrap error, 必须采用errors.As函数检查错误类型,因为errors.As会递归的unwrap error,判断每一层的error类型是否是需要匹配的目标类型,无论返回的error是否是经过wrap的,都可以匹配检查到。

0 人点赞