Golang source code analysis: httpmock

2023-09-06 19:23:48

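https://github.com/jarcoal/httpmock is a package for mocking HTTP requests. It works by replacing http.DefaultTransport, the RoundTripper an http.Client falls back to when its own Transport is nil, with a MockTransport on which a response is registered for each request. When an HTTP request is sent, the mock Transport intercepts it, finds the registered response by matching the request's method and URL, and returns it without touching the network. It is used like this: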

```go
httpmock.Activate()
defer httpmock.DeactivateAndReset()
// api (the URL to mock) and mockResponse (the canned body) are defined by the caller.
httpmock.RegisterResponder("GET", api, httpmock.NewStringResponder(200, string(mockResponse)))
```

After that, HTTP requests can be sent as usual. Now let's analyze the source code.

Before analyzing httpmock, let's first look at the http.Client source. An HTTP request is typically sent like this:

```go
myClient := &http.Client{}
response, err := myClient.Do(request)
```

The code lives in src/net/http/client.go:

```go
func (c *Client) Do(req *Request) (*Response, error) {
	return c.do(req)
}

func (c *Client) do(req *Request) (retres *Response, reterr error) {
	// ... (redirect loop and bookkeeping elided)
	if resp, didTimeout, err = c.send(req, deadline); err != nil {
		// ... (error handling elided)
	}
	// ...
}
```

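Before sending, Client.send obtains the RoundTripper via c.transport() and passes it to the package-level send function, which calls its RoundTrip method: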

```go
func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
	// ... (cookie handling elided)
	resp, didTimeout, err = send(req, c.transport(), deadline)
	// ...
}

func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) {
	// ... (req is derived from ireq; validation and timeout setup elided)
	resp, err = rt.RoundTrip(req)
	// ...
}
```

If the user has not set a Transport on the client, DefaultTransport is used:

```go
func (c *Client) transport() RoundTripper {
	if c.Transport != nil {
		return c.Transport
	}
	return DefaultTransport
}

type Client struct {
	// Transport specifies the mechanism by which individual
	// HTTP requests are made.
	// If nil, DefaultTransport is used.
	Transport RoundTripper
	// ... (other fields elided)
}
```
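
Because transport() reads DefaultTransport on every request rather than once at construction, any client whose Transport is nil picks up a replacement installed later. The following is a minimal standard-library sketch of that effect; stubTransport and getViaDefault are hypothetical names, and the URL uses the reserved .invalid domain since no network access happens.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// stubTransport is a hypothetical RoundTripper that answers every request
// with a fixed body and never touches the network.
type stubTransport struct{ body string }

func (s stubTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	return &http.Response{
		Status:        "200 OK",
		StatusCode:    http.StatusOK,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        make(http.Header),
		Body:          io.NopCloser(strings.NewReader(s.body)),
		ContentLength: int64(len(s.body)),
		Request:       req,
	}, nil
}

// getViaDefault swaps http.DefaultTransport for the stub, performs a GET with a
// client whose Transport is nil, and restores the original transport.
func getViaDefault(url, body string) (string, error) {
	saved := http.DefaultTransport
	http.DefaultTransport = stubTransport{body: body}
	defer func() { http.DefaultTransport = saved }()

	client := &http.Client{} // Transport is nil, so DefaultTransport is used
	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	b, err := io.ReadAll(resp.Body)
	return string(b), err
}

func main() {
	got, err := getViaDefault("http://example.invalid/users", "hello")
	fmt.Println(got, err)
}
```

This is exactly the seam httpmock's Activate relies on: nothing in the client has to change.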

Client's Transport field has type RoundTripper, an interface with a single method, RoundTrip:

```go
type RoundTripper interface {
	RoundTrip(*Request) (*Response, error)
}
```
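
With only one method to implement, RoundTrippers compose easily: a wrapper can observe or alter requests and then delegate. A minimal sketch, assuming hypothetical echoPathTransport and countingTransport types, that counts calls per method and path, which is the same kind of bookkeeping a mock transport does:

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
)

// echoPathTransport is a hypothetical RoundTripper that answers with the path.
type echoPathTransport struct{}

func (echoPathTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	body := "path=" + req.URL.Path
	return &http.Response{
		StatusCode:    http.StatusOK,
		Header:        make(http.Header),
		Body:          io.NopCloser(strings.NewReader(body)),
		ContentLength: int64(len(body)),
		Request:       req,
	}, nil
}

// countingTransport wraps another RoundTripper and counts calls per
// "METHOD path"; the mutex makes it safe for concurrent requests.
type countingTransport struct {
	next   http.RoundTripper
	mu     sync.Mutex
	counts map[string]int
}

func (c *countingTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	c.mu.Lock()
	c.counts[req.Method+" "+req.URL.Path]++
	c.mu.Unlock()
	return c.next.RoundTrip(req)
}

// callsFor sends n GETs for path and returns the recorded count and last body.
func callsFor(path string, n int) (int, string) {
	ct := &countingTransport{next: echoPathTransport{}, counts: map[string]int{}}
	client := &http.Client{Transport: ct}
	last := ""
	for i := 0; i < n; i++ {
		resp, err := client.Get("http://example.invalid" + path)
		if err != nil {
			return -1, err.Error()
		}
		b, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		last = string(b)
	}
	ct.mu.Lock()
	defer ct.mu.Unlock()
	return ct.counts["GET "+path], last
}

func main() {
	fmt.Println(callsFor("/ping", 3))
}
```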

DefaultTransport is defined in src/net/http/transport.go:

```go
var DefaultTransport RoundTripper = &Transport{
	Proxy: ProxyFromEnvironment,
	DialContext: defaultTransportDialContext(&net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}),
	ForceAttemptHTTP2:     true,
	MaxIdleConns:          100,
	IdleConnTimeout:       90 * time.Second,
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second,
}
```

```go
type Transport struct {
	idleMu       sync.Mutex
	closeIdle    bool                                // user has requested to close all idle conns
	idleConn     map[connectMethodKey][]*persistConn // most recently used at end
	idleConnWait map[connectMethodKey]wantConnQueue  // waiting getConns
	idleLRU      connLRU

	reqMu       sync.Mutex
	reqCanceler map[cancelKey]func(error)

	altMu    sync.Mutex   // guards changing altProto only
	altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme

	connsPerHostMu   sync.Mutex
	connsPerHost     map[connectMethodKey]int
	connsPerHostWait map[connectMethodKey]wantConnQueue // waiting getConns

	// Proxy specifies a function to return a proxy for a given
	// Request. If the function returns a non-nil error, the
	// request is aborted with the provided error.
	//
	// The proxy type is determined by the URL scheme. "http",
	// "https", and "socks5" are supported. If the scheme is empty,
	// "http" is assumed.
	//
	// If Proxy is nil or returns a nil *URL, no proxy is used.
	Proxy func(*Request) (*url.URL, error)

	// DialContext specifies the dial function for creating unencrypted TCP connections.
	// If DialContext is nil (and the deprecated Dial below is also nil),
	// then the transport dials using package net.
	//
	// DialContext runs concurrently with calls to RoundTrip.
	// A RoundTrip call that initiates a dial may end up using
	// a connection dialed previously when the earlier connection
	// becomes idle before the later DialContext completes.
	DialContext func(ctx context.Context, network, addr string) (net.Conn, error)

	// Dial specifies the dial function for creating unencrypted TCP connections.
	//
	// Dial runs concurrently with calls to RoundTrip.
	// A RoundTrip call that initiates a dial may end up using
	// a connection dialed previously when the earlier connection
	// becomes idle before the later Dial completes.
	//
	// Deprecated: Use DialContext instead, which allows the transport
	// to cancel dials as soon as they are no longer needed.
	// If both are set, DialContext takes priority.
	Dial func(network, addr string) (net.Conn, error)

	// DialTLSContext specifies an optional dial function for creating
	// TLS connections for non-proxied HTTPS requests.
	//
	// If DialTLSContext is nil (and the deprecated DialTLS below is also nil),
	// DialContext and TLSClientConfig are used.
	//
	// If DialTLSContext is set, the Dial and DialContext hooks are not used for HTTPS
	// requests and the TLSClientConfig and TLSHandshakeTimeout
	// are ignored. The returned net.Conn is assumed to already be
	// past the TLS handshake.
	DialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)

	// DialTLS specifies an optional dial function for creating
	// TLS connections for non-proxied HTTPS requests.
	//
	// Deprecated: Use DialTLSContext instead, which allows the transport
	// to cancel dials as soon as they are no longer needed.
	// If both are set, DialTLSContext takes priority.
	DialTLS func(network, addr string) (net.Conn, error)

	// TLSClientConfig specifies the TLS configuration to use with
	// tls.Client.
	// If nil, the default configuration is used.
	// If non-nil, HTTP/2 support may not be enabled by default.
	TLSClientConfig *tls.Config

	// TLSHandshakeTimeout specifies the maximum amount of time waiting to
	// wait for a TLS handshake. Zero means no timeout.
	TLSHandshakeTimeout time.Duration

	// DisableKeepAlives, if true, disables HTTP keep-alives and
	// will only use the connection to the server for a single
	// HTTP request.
	//
	// This is unrelated to the similarly named TCP keep-alives.
	DisableKeepAlives bool

	// DisableCompression, if true, prevents the Transport from
	// requesting compression with an "Accept-Encoding: gzip"
	// request header when the Request contains no existing
	// Accept-Encoding value. If the Transport requests gzip on
	// its own and gets a gzipped response, it's transparently
	// decoded in the Response.Body. However, if the user
	// explicitly requested gzip it is not automatically
	// uncompressed.
	DisableCompression bool

	// MaxIdleConns controls the maximum number of idle (keep-alive)
	// connections across all hosts. Zero means no limit.
	MaxIdleConns int

	// MaxIdleConnsPerHost, if non-zero, controls the maximum idle
	// (keep-alive) connections to keep per-host. If zero,
	// DefaultMaxIdleConnsPerHost is used.
	MaxIdleConnsPerHost int

	// MaxConnsPerHost optionally limits the total number of
	// connections per host, including connections in the dialing,
	// active, and idle states. On limit violation, dials will block.
	//
	// Zero means no limit.
	MaxConnsPerHost int

	// IdleConnTimeout is the maximum amount of time an idle
	// (keep-alive) connection will remain idle before closing
	// itself.
	// Zero means no limit.
	IdleConnTimeout time.Duration

	// ResponseHeaderTimeout, if non-zero, specifies the amount of
	// time to wait for a server's response headers after fully
	// writing the request (including its body, if any). This
	// time does not include the time to read the response body.
	ResponseHeaderTimeout time.Duration

	// ExpectContinueTimeout, if non-zero, specifies the amount of
	// time to wait for a server's first response headers after fully
	// writing the request headers if the request has an
	// "Expect: 100-continue" header. Zero means no timeout and
	// causes the body to be sent immediately, without
	// waiting for the server to approve.
	// This time does not include the time to send the request header.
	ExpectContinueTimeout time.Duration

	// TLSNextProto specifies how the Transport switches to an
	// alternate protocol (such as HTTP/2) after a TLS ALPN
	// protocol negotiation. If Transport dials an TLS connection
	// with a non-empty protocol name and TLSNextProto contains a
	// map entry for that key (such as "h2"), then the func is
	// called with the request's authority (such as "example.com"
	// or "example.com:1234") and the TLS connection. The function
	// must return a RoundTripper that then handles the request.
	// If TLSNextProto is not nil, HTTP/2 support is not enabled
	// automatically.
	TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper

	// ProxyConnectHeader optionally specifies headers to send to
	// proxies during CONNECT requests.
	// To set the header dynamically, see GetProxyConnectHeader.
	ProxyConnectHeader Header

	// GetProxyConnectHeader optionally specifies a func to return
	// headers to send to proxyURL during a CONNECT request to the
	// ip:port target.
	// If it returns an error, the Transport's RoundTrip fails with
	// that error. It can return (nil, nil) to not add headers.
	// If GetProxyConnectHeader is non-nil, ProxyConnectHeader is
	// ignored.
	GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error)

	// MaxResponseHeaderBytes specifies a limit on how many
	// response bytes are allowed in the server's response
	// header.
	//
	// Zero means to use a default limit.
	MaxResponseHeaderBytes int64

	// WriteBufferSize specifies the size of the write buffer used
	// when writing to the transport.
	// If zero, a default (currently 4KB) is used.
	WriteBufferSize int

	// ReadBufferSize specifies the size of the read buffer used
	// when reading from the transport.
	// If zero, a default (currently 4KB) is used.
	ReadBufferSize int

	// nextProtoOnce guards initialization of TLSNextProto and
	// h2transport (via onceSetNextProtoDefaults)
	nextProtoOnce      sync.Once
	h2transport        h2Transport // non-nil if http2 wired up
	tlsNextProtoWasNil bool        // whether TLSNextProto was nil when the Once fired

	// ForceAttemptHTTP2 controls whether HTTP/2 is enabled when a non-zero
	// Dial, DialTLS, or DialContext func or TLSClientConfig is provided.
	// By default, use of any those fields conservatively disables HTTP/2.
	// To use a custom dialer or TLS config and still attempt HTTP/2
	// upgrades, set this to true.
	ForceAttemptHTTP2 bool
}
```

Transport's RoundTrip method is implemented in src/net/http/roundtrip.go:

```go
func (t *Transport) RoundTrip(req *Request) (*Response, error) {
	return t.roundTrip(req)
}
```
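
roundTrip first sets up the HTTP/2 defaults once, then asks alternateRoundTripper whether a RoundTripper has been registered for the URL scheme. Only if there is none, or it returns ErrSkipAltProtocol, does it obtain a persistent connection and send the request over either HTTP/2 (pconn.alt) or HTTP/1.x (pconn.roundTrip):

```go
func (t *Transport) roundTrip(req *Request) (*Response, error) {
	t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
	// ...
	if altRT := t.alternateRoundTripper(req); altRT != nil {
		if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
			return resp, err
		}
		// ...
	}
	// ... (connection acquisition elided; pconn is the persistConn obtained there)
	if pconn.alt != nil {
		// HTTP/2 path.
		t.setReqCanceler(cancelKey, nil) // not cancelable with CancelRequest
		resp, err = pconn.alt.RoundTrip(req)
	} else {
		resp, err = pconn.roundTrip(treq)
	}
	// ...
}
```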

onceSetNextProtoDefaults, which runs once per Transport, reads the same altProto map:

```go
func (t *Transport) onceSetNextProtoDefaults() {
	// ...
	altProto, _ := t.altProto.Load().(map[string]RoundTripper)
	// ...
}
```

The HTTP/1.x path goes through persistConn.roundTrip:

```go
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
	// ... (body elided)
}
```

alternateRoundTripper looks the URL scheme up in altProto:

```go
func (t *Transport) alternateRoundTripper(req *Request) RoundTripper {
	if !t.useRegisteredProtocol(req) {
		return nil
	}
	altProto, _ := t.altProto.Load().(map[string]RoundTripper)
	return altProto[req.URL.Scheme]
}
```
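
RegisterProtocol adds entries to that map. It copies the old map into a new one and stores the copy atomically, so readers can Load altProto without taking altMu; registering the same scheme twice panics:

```go
func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
	t.altMu.Lock()
	defer t.altMu.Unlock()
	oldMap, _ := t.altProto.Load().(map[string]RoundTripper)
	if _, exists := oldMap[scheme]; exists {
		panic("protocol " + scheme + " already registered")
	}
	// Copy-on-write: never mutate the map that readers may hold.
	newMap := make(map[string]RoundTripper)
	for k, v := range oldMap {
		newMap[k] = v
	}
	newMap[scheme] = rt
	t.altProto.Store(newMap)
}
```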

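This is the hook HTTP/2 uses: in src/net/http/h2_bundle.go, the bundled HTTP/2 implementation registers itself for the "https" scheme: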

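```go
func http2registerHTTPSProtocol(t *Transport, rt http2noDialH2RoundTripper) (err error) {
	// Turn a "protocol already registered" panic into an error.
	defer func() {
		if e := recover(); e != nil {
			err = fmt.Errorf("%v", e)
		}
	}()
	t.RegisterProtocol("https", rt)
	return nil
}

func http2configureTransports(t1 *Transport) (*http2Transport, error) {
	// ...
	if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
		return nil, err
	}
	// ...
}
```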
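
Since RegisterProtocol is exported, the per-scheme hook HTTP/2 uses is also available to user code. A minimal sketch, assuming a made-up "mock" scheme and a hypothetical schemeEcho RoundTripper: requests for mock:// URLs are handed straight to the registered RoundTripper, and registering the same scheme a second time panics, as RegisterProtocol's code shows.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

// schemeEcho is a hypothetical RoundTripper that answers with the request URL.
type schemeEcho struct{}

func (schemeEcho) RoundTrip(req *http.Request) (*http.Response, error) {
	body := "handled " + req.URL.String()
	return &http.Response{
		StatusCode:    http.StatusOK,
		Header:        make(http.Header),
		Body:          io.NopCloser(strings.NewReader(body)),
		ContentLength: int64(len(body)),
		Request:       req,
	}, nil
}

// fetchMock registers schemeEcho for "mock" on a fresh Transport and
// fetches url through a client that uses that Transport.
func fetchMock(url string) (string, error) {
	t := &http.Transport{}
	t.RegisterProtocol("mock", schemeEcho{})
	client := &http.Client{Transport: t}
	resp, err := client.Get(url)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()
	b, err := io.ReadAll(resp.Body)
	return string(b), err
}

// registerTwicePanics reports whether registering a scheme twice panics.
func registerTwicePanics() (panicked bool) {
	defer func() { panicked = recover() != nil }()
	t := &http.Transport{}
	t.RegisterProtocol("mock", schemeEcho{})
	t.RegisterProtocol("mock", schemeEcho{})
	return false
}

func main() {
	fmt.Println(fetchMock("mock://example/path"))
	fmt.Println(registerTwicePanics())
}
```

httpmock does not use this hook; it replaces the whole RoundTripper instead, which also covers plain http and https URLs.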

With the http.Client flow covered, let's move on to the httpmock source.

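First, the activation function in github.com/jarcoal/httpmock@v1.3.0/transport.go. In essence it replaces http.DefaultTransport with httpmock's mock transport, after saving the original http.DefaultTransport in InitialTransport so it can be restored later: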

```go
func Activate() {
	if Disabled() {
		return
	}
	// make sure that if Activate is called multiple times it doesn't
	// overwrite the InitialTransport with a mock transport.
	if http.DefaultTransport != DefaultTransport {
		InitialTransport = http.DefaultTransport
	}

	http.DefaultTransport = DefaultTransport
}
```

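Mocking can be turned off through an environment variable: when GONOMOCKS is set to a non-empty value, Disabled returns true and Activate does nothing.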

```go
var envVarName = "GONOMOCKS"

// Disabled allows to test whether httpmock is enabled or not. It
// depends on GONOMOCKS environment variable.
func Disabled() bool {
	return os.Getenv(envVarName) != ""
}
```
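
The save-and-restore logic of Activate and Deactivate, together with the environment-variable switch, can be reproduced with the standard library alone. This is a simplified sketch, not httpmock's code: activate, deactivate, mockTransport, theMock and initial are hypothetical names, and EXAMPLE_NOMOCKS is a made-up variable standing in for GONOMOCKS.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"strings"
)

// envNoMocks is a hypothetical switch playing the role of GONOMOCKS.
const envNoMocks = "EXAMPLE_NOMOCKS"

// mockTransport is a hypothetical stand-in for httpmock's MockTransport.
type mockTransport struct{}

func (mockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	return &http.Response{
		StatusCode: http.StatusOK,
		Header:     make(http.Header),
		Body:       io.NopCloser(strings.NewReader("mocked")),
		Request:    req,
	}, nil
}

// theMock plays the role of httpmock.DefaultTransport.
var theMock http.RoundTripper = mockTransport{}

// initial plays the role of InitialTransport: the transport to restore.
var initial = http.DefaultTransport

func disabled() bool { return os.Getenv(envNoMocks) != "" }

// activate installs the mock unless disabled. The comparison keeps a second
// call from overwriting initial with the mock itself.
func activate() {
	if disabled() {
		return
	}
	if http.DefaultTransport != theMock {
		initial = http.DefaultTransport
	}
	http.DefaultTransport = theMock
}

// deactivate restores the saved transport.
func deactivate() { http.DefaultTransport = initial }

// scenario activates twice, checks that requests hit the mock, then deactivates.
func scenario() (body string, restored bool) {
	orig := http.DefaultTransport
	activate()
	activate()
	resp, err := http.Get("http://example.invalid/")
	if err == nil {
		b, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		body = string(b)
	}
	deactivate()
	return body, http.DefaultTransport == orig
}

// scenarioDisabled reports whether activate is a no-op when the switch is set.
func scenarioDisabled() bool {
	os.Setenv(envNoMocks, "1")
	defer os.Unsetenv(envNoMocks)
	activate()
	noop := http.DefaultTransport != theMock
	deactivate()
	return noop
}

func main() {
	fmt.Println(scenario())
	fmt.Println(scenarioDisabled())
}
```

Without the comparison in activate, the second call would save the mock itself as the transport to restore, and deactivate could never bring the real transport back.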

The mock relies on a few key package-level variables:

```go
var DefaultTransport = NewMockTransport()
var InitialTransport = http.DefaultTransport
var oldClients = map[*http.Client]http.RoundTripper{}
var oldClientsLock sync.Mutex

func NewMockTransport() *MockTransport {
	return &MockTransport{
		responders:    make(map[internal.RouteKey]matchResponders),
		callCountInfo: make(map[matchRouteKey]int),
	}
}

type MockTransport struct {
	// DontCheckMethod disables standard methods check. By default, if
	// a responder is registered using a lower-cased method among CONNECT,
	// DELETE, GET, HEAD, OPTIONS, POST, PUT and TRACE, a panic occurs
	// as it is probably a mistake.
	DontCheckMethod  bool
	mu               sync.RWMutex
	responders       map[internal.RouteKey]matchResponders
	regexpResponders []regexpResponder
	noResponder      Responder
	callCountInfo    map[matchRouteKey]int
	totalCallCount   int
}
```

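MockTransport's RoundTrip method (excerpt) calls findResponders in a loop to look up the responders registered for the request's method and URL, calls suggestResponder (apparently to propose a close route when nothing matches), increments the call counter of the matched route, and finally runs the responder with any regexp submatches attached to the request: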

```go
func (m *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// ... (surrounding control flow elided)
	for fromFindIdx := 0; ; {
		found, findIdx = m.findResponders(method, req.URL, fromFindIdx)
		// ...
	}
	// ...
	suggested = m.suggestResponder(method, req.URL)
	// ...
	m.callCountInfo[matchRouteKey{RouteKey: found.key, name: mr.matcher.name}]++
	// ...
	return runCancelable(responder, internal.SetSubmatches(req, found.submatches))
}

func (m *MockTransport) findResponders(method string, url *url.URL, fromIdx int) (
	found respondersFound,
	findForKeyIndex int,
) {
	// ...
}
```
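
Stripped of matchers, regexps and query-string handling, this lookup boils down to a locked map keyed by method and URL plus a call counter. The following is a simplified sketch under that assumption; miniMock, Register, Calls, stringResponder and errNoResponder are hypothetical names, not httpmock's API.

```go
package main

import (
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync"
)

// responder has the same shape as httpmock's Responder.
type responder func(*http.Request) (*http.Response, error)

// errNoResponder is returned for unregistered routes (hypothetical).
var errNoResponder = errors.New("no responder found")

// miniMock keys responders by "METHOD URL"; no matchers, regexps or query handling.
type miniMock struct {
	mu         sync.Mutex
	responders map[string]responder
	calls      map[string]int
}

func newMiniMock() *miniMock {
	return &miniMock{responders: map[string]responder{}, calls: map[string]int{}}
}

// Register stores the responder ahead of time and resets its call counter.
func (m *miniMock) Register(method, url string, r responder) {
	m.mu.Lock()
	defer m.mu.Unlock()
	key := method + " " + url
	m.responders[key] = r
	m.calls[key] = 0
}

// RoundTrip finds the responder, increments the counter and runs it.
func (m *miniMock) RoundTrip(req *http.Request) (*http.Response, error) {
	key := req.Method + " " + req.URL.String()
	m.mu.Lock()
	r, ok := m.responders[key]
	if ok {
		m.calls[key]++
	}
	m.mu.Unlock()
	if !ok {
		return nil, fmt.Errorf("%w for %s", errNoResponder, key)
	}
	return r(req)
}

// Calls reports how many times a route was hit.
func (m *miniMock) Calls(method, url string) int {
	m.mu.Lock()
	defer m.mu.Unlock()
	return m.calls[method+" "+url]
}

// stringResponder creates a fresh response and body on every call.
func stringResponder(status int, body string) responder {
	return func(req *http.Request) (*http.Response, error) {
		return &http.Response{
			StatusCode: status,
			Header:     make(http.Header),
			Body:       io.NopCloser(strings.NewReader(body)),
			Request:    req,
		}, nil
	}
}

// demo hits a registered route twice and an unregistered one once.
func demo() (calls int, body string, missingIsNoResponder bool) {
	m := newMiniMock()
	m.Register("GET", "http://api.example/users", stringResponder(200, "[]"))
	client := &http.Client{Transport: m}
	for i := 0; i < 2; i++ {
		resp, err := client.Get("http://api.example/users")
		if err != nil {
			return -1, err.Error(), false
		}
		b, _ := io.ReadAll(resp.Body)
		resp.Body.Close()
		body = string(b)
	}
	_, err := client.Get("http://api.example/missing")
	return m.Calls("GET", "http://api.example/users"), body, errors.Is(err, errNoResponder)
}

func main() {
	fmt.Println(demo())
}
```

The error for an unknown route reaches the caller wrapped in a *url.Error by http.Client, which is why the sketch checks it with errors.Is rather than ==.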

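When mocking is finished, the previous transport must be restored. Deactivate puts InitialTransport back into http.DefaultTransport and restores the original Transport of every client recorded in oldClients (clients switched to the mock with ActivateNonDefault); Reset clears the responders registered on DefaultTransport: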

```go
func DeactivateAndReset() {
	Deactivate()
	Reset()
}

func Deactivate() {
	// ...
	http.DefaultTransport = InitialTransport
	// ...
	for oldClient, oldTransport := range oldClients {
		oldClient.Transport = oldTransport
		delete(oldClients, oldClient) // deleting during range is safe in Go
	}
}

func Reset() {
	DefaultTransport.Reset()
}

func (m *MockTransport) Reset() {
	// ...
	m.responders = make(map[internal.RouteKey]matchResponders)
	// ...
}
```

Next, let's see how mock responses are built and registered, starting with github.com/jarcoal/httpmock@v1.3.0/response.go:

```go
func NewStringResponder(status int, body string) Responder {
	return ResponderFromResponse(NewStringResponse(status, body))
}

func ResponderFromResponse(resp *http.Response) Responder {
	return func(req *http.Request) (*http.Response, error) {
		res := *resp // shallow copy of the template response, made on every call
		// ... (elided)
	}
}

func NewStringResponse(status int, body string) *http.Response {
	// ... (builds a response with the given status code and body)
}

// Package-level helper: registers on the global DefaultTransport.
func RegisterResponder(method, url string, responder Responder) {
	DefaultTransport.RegisterResponder(method, url, responder)
}
```
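
The shallow copy in ResponderFromResponse needs care with the body: a copied *http.Response still shares the same io.ReadCloser, so a second request would read an already drained body. httpmock appears to handle this with an internal body type that it copies on each call. The sketch below illustrates the general pattern rather than httpmock's code: a hypothetical responderFromBytes keeps the body as bytes, copies a template response per call and attaches a new reader, and sharedBodyTwice shows what goes wrong without that.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
)

type responder func(*http.Request) (*http.Response, error)

// responderFromBytes keeps the body as bytes and, on every call, returns a
// shallow copy of the template with a brand-new reader over those bytes.
func responderFromBytes(status int, body []byte) responder {
	tmpl := http.Response{
		StatusCode:    status,
		Status:        fmt.Sprintf("%d %s", status, http.StatusText(status)),
		Header:        make(http.Header),
		ContentLength: int64(len(body)),
	}
	return func(req *http.Request) (*http.Response, error) {
		res := tmpl                      // shallow copy of the template
		res.Header = tmpl.Header.Clone() // do not share the header map between responses
		res.Body = io.NopCloser(bytes.NewReader(body))
		res.Request = req
		return &res, nil
	}
}

// readTwice calls the responder twice and returns both bodies.
func readTwice(r responder) (string, string) {
	read := func() string {
		resp, err := r(nil)
		if err != nil {
			return "error: " + err.Error()
		}
		defer resp.Body.Close()
		b, _ := io.ReadAll(resp.Body)
		return string(b)
	}
	return read(), read()
}

// sharedBodyTwice shows the bug being avoided: one reader reused across calls.
func sharedBodyTwice() (string, string) {
	shared := io.NopCloser(bytes.NewReader([]byte("hello")))
	r := func(req *http.Request) (*http.Response, error) {
		return &http.Response{StatusCode: 200, Body: shared}, nil
	}
	return readTwice(r)
}

func main() {
	a, b := readTwice(responderFromBytes(200, []byte("hello")))
	fmt.Printf("%q %q\n", a, b)
	c, d := sharedBodyTwice()
	fmt.Printf("%q %q\n", c, d)
}
```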

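RegisterResponder simply records the mapping from method and URL to responder ahead of time, so that it can be found when a request is made. URLs starting with "=~" are treated as regular expressions (url[2:] strips that prefix) and are kept in a separate regexpResponders list: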

```go
func (m *MockTransport) RegisterResponder(method, url string, responder Responder) {
	m.RegisterMatcherResponder(method, url, Matcher{}, responder)
}

func (m *MockTransport) RegisterMatcherResponder(method, url string, matcher Matcher, responder Responder) {
	// ...
	mr := matchResponder{
		matcher:   matcher,
		responder: responder,
	}
	// ...
	if isRegexpURL(url) {
		rr := regexpResponder{
			origRx:     url,
			method:     method,
			rx:         regexp.MustCompile(url[2:]),
			responders: matchResponders{mr},
		}
		m.registerRegexpResponder(rr)
		return
	}
	// ... (key is built from method and url)
	m.responders[key] = m.responders[key].add(mr)
	m.callCountInfo[matchRouteKey{RouteKey: key, name: matcher.name}] = 0
	// ...
}
```
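
When the responder is nil, registerRegexpResponder removes the matcher from the existing regexp route, and drops the route entirely once no matcher is left:

```go
func (m *MockTransport) registerRegexpResponder(rxResp regexpResponder) {
	// ...
	mr := rxResp.responders[0]
	// ...
	for {
		for i, rr := range m.regexpResponders {
			if rr.method == rxResp.method && rr.origRx == rxResp.origRx {
				if mr.responder == nil {
					rr.responders = rr.responders.remove(mr.matcher.name)
					if rr.responders == nil {
						// As quoted, this copies into s[:i]; the usual idiom for
						// removing element i is copy(s[i:], s[i+1:]).
						copy(m.regexpResponders[:i], m.regexpResponders[i+1:])
						m.regexpResponders[len(m.regexpResponders)-1] = regexpResponder{}
						m.regexpResponders = m.regexpResponders[:len(m.regexpResponders)-1]
					}
					// ...
				}
				// ...
			}
		}
		// ...
	}
}
```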
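
The exact-key map and the regexp list give two lookup paths. A minimal sketch of that split, using the "=~" prefix convention described above; router, add and match are hypothetical names, and trying exact keys before regexps (in registration order) is this sketch's own choice rather than a statement about httpmock's precedence rules.

```go
package main

import (
	"fmt"
	"regexp"
	"strings"
)

// regexpPrefix marks a URL pattern as a regular expression.
const regexpPrefix = "=~"

type rxRoute struct {
	method string
	rx     *regexp.Regexp
	name   string
}

// router keeps exact routes in a map and regexp routes in a slice.
type router struct {
	exact   map[string]string // "METHOD URL" -> route name
	regexps []rxRoute
}

func newRouter() *router { return &router{exact: map[string]string{}} }

// add registers a route; "=~"-prefixed URLs are compiled without the prefix.
func (r *router) add(method, url, name string) {
	if strings.HasPrefix(url, regexpPrefix) {
		r.regexps = append(r.regexps, rxRoute{
			method: method,
			rx:     regexp.MustCompile(url[len(regexpPrefix):]),
			name:   name,
		})
		return
	}
	r.exact[method+" "+url] = name
}

// match tries the exact key first, then each regexp in registration order,
// returning the route name and the capture groups (without the full match).
func (r *router) match(method, url string) (string, []string, bool) {
	if name, ok := r.exact[method+" "+url]; ok {
		return name, nil, true
	}
	for _, rt := range r.regexps {
		if rt.method != method {
			continue
		}
		if m := rt.rx.FindStringSubmatch(url); m != nil {
			return rt.name, m[1:], true
		}
	}
	return "", nil, false
}

func main() {
	r := newRouter()
	r.add("GET", "http://api.example/users/me", "me")
	r.add("GET", `=~^http://api\.example/users/(\d+)$`, "byID")
	fmt.Println(r.match("GET", "http://api.example/users/me"))
	fmt.Println(r.match("GET", "http://api.example/users/42"))
	fmt.Println(r.match("POST", "http://api.example/users/42"))
}
```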

The regexp submatches are attached to the request by SetSubmatches in github.com/jarcoal/httpmock@v1.3.0/internal/submatches.go:

```go
func SetSubmatches(req *http.Request, submatches []string) *http.Request {
	// ... (elided)
}
```
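
Since SetSubmatches takes a request and returns a request, a natural way to carry the submatches is the request's context. The sketch below assumes that approach (an unexported key type with context.WithValue and Request.WithContext); setSubmatches and getSubmatch are hypothetical helpers, and httpmock's internals may differ in detail.

```go
package main

import (
	"context"
	"fmt"
	"net/http"
)

// submatchesKey is an unexported key type, so other packages cannot collide with it.
type submatchesKey struct{}

// setSubmatches returns a shallow copy of req carrying the submatches in its context.
func setSubmatches(req *http.Request, submatches []string) *http.Request {
	if len(submatches) == 0 {
		return req
	}
	return req.WithContext(context.WithValue(req.Context(), submatchesKey{}, submatches))
}

// getSubmatch returns the n-th (1-based) submatch, if present.
func getSubmatch(req *http.Request, n int) (string, bool) {
	sm, _ := req.Context().Value(submatchesKey{}).([]string)
	if n < 1 || n > len(sm) {
		return "", false
	}
	return sm[n-1], true
}

// demo attaches two submatches, reads the second back, and checks that the
// original request was left untouched.
func demo() (string, bool, bool) {
	req, err := http.NewRequest(http.MethodGet, "http://api.example/users/42/posts", nil)
	if err != nil {
		return "", false, false
	}
	withSM := setSubmatches(req, []string{"42", "posts"})
	v, ok := getSubmatch(withSM, 2)
	_, origHas := getSubmatch(req, 1)
	return v, ok, origHas
}

func main() {
	fmt.Println(demo())
}
```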

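Matchers are evaluated in github.com/jarcoal/httpmock@v1.3.0/match.go. findMatchResponder tries each registered matcher in turn and returns the first one that accepts the request: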

```go
func (mrs matchResponders) findMatchResponder(req *http.Request) *matchResponder {
	// ...
	for _, mr := range mrs {
		copyBody.rearm() // presumably rewinds the buffered body so each matcher reads it from the start
		if mr.matcher.Check(req) {
			return &mr
		}
	}
	// ...
}

func (m Matcher) Check(req *http.Request) bool {
	return m.fn.Check(req)
}

type MatcherFunc func(req *http.Request) bool
```
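
Why the rearm step matters: a matcher that inspects the body consumes it, so the next matcher would see nothing. A minimal sketch of the idea, assuming hypothetical matcherFunc, bodyContains and firstMatch names and buffering the whole body in memory:

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
)

// matcherFunc has the same shape as httpmock's MatcherFunc.
type matcherFunc func(req *http.Request) bool

// bodyContains is a hypothetical matcher that consumes the body to inspect it.
func bodyContains(s string) matcherFunc {
	return func(req *http.Request) bool {
		if req.Body == nil {
			return false
		}
		b, err := io.ReadAll(req.Body)
		return err == nil && strings.Contains(string(b), s)
	}
}

// firstMatch buffers the body once and "rearms" it (a fresh reader over the
// same bytes) before each matcher; it returns the first accepting index or -1.
func firstMatch(req *http.Request, matchers []matcherFunc) int {
	hadBody := req.Body != nil
	var buf []byte
	if hadBody {
		buf, _ = io.ReadAll(req.Body)
		req.Body.Close()
	}
	for i, m := range matchers {
		if hadBody {
			req.Body = io.NopCloser(bytes.NewReader(buf))
		}
		if m(req) {
			return i
		}
	}
	return -1
}

// demo: the first matcher reads the whole body, the second still sees it.
func demo() int {
	req, err := http.NewRequest(http.MethodPost, "http://api.example/users",
		strings.NewReader(`{"name":"gopher"}`))
	if err != nil {
		return -2
	}
	return firstMatch(req, []matcherFunc{
		bodyContains(`"name":"alice"`),
		bodyContains(`"name":"gopher"`),
	})
}

func main() {
	fmt.Println(demo())
}
```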

Finally, for reference, here is the http.Response struct that responders return, from src/net/http/response.go:

```go
type Response struct {
	Status     string // e.g. "200 OK"
	StatusCode int    // e.g. 200
	Proto      string // e.g. "HTTP/1.0"
	ProtoMajor int    // e.g. 1
	ProtoMinor int    // e.g. 0

	// Header maps header keys to values. If the response had multiple
	// headers with the same key, they may be concatenated, with comma
	// delimiters.  (RFC 7230, section 3.2.2 requires that multiple headers
	// be semantically equivalent to a comma-delimited sequence.) When
	// Header values are duplicated by other fields in this struct (e.g.,
	// ContentLength, TransferEncoding, Trailer), the field values are
	// authoritative.
	//
	// Keys in the map are canonicalized (see CanonicalHeaderKey).
	Header Header

	// Body represents the response body.
	//
	// The response body is streamed on demand as the Body field
	// is read. If the network connection fails or the server
	// terminates the response, Body.Read calls return an error.
	//
	// The http Client and Transport guarantee that Body is always
	// non-nil, even on responses without a body or responses with
	// a zero-length body. It is the caller's responsibility to
	// close Body. The default HTTP client's Transport may not
	// reuse HTTP/1.x "keep-alive" TCP connections if the Body is
	// not read to completion and closed.
	//
	// The Body is automatically dechunked if the server replied
	// with a "chunked" Transfer-Encoding.
	//
	// As of Go 1.12, the Body will also implement io.Writer
	// on a successful "101 Switching Protocols" response,
	// as used by WebSockets and HTTP/2's "h2c" mode.
	Body io.ReadCloser

	// ContentLength records the length of the associated content. The
	// value -1 indicates that the length is unknown. Unless Request.Method
	// is "HEAD", values >= 0 indicate that the given number of bytes may
	// be read from Body.
	ContentLength int64

	// Contains transfer encodings from outer-most to inner-most. Value is
	// nil, means that "identity" encoding is used.
	TransferEncoding []string

	// Close records whether the header directed that the connection be
	// closed after reading Body. The value is advice for clients: neither
	// ReadResponse nor Response.Write ever closes a connection.
	Close bool

	// Uncompressed reports whether the response was sent compressed but
	// was decompressed by the http package. When true, reading from
	// Body yields the uncompressed content instead of the compressed
	// content actually set from the server, ContentLength is set to -1,
	// and the "Content-Length" and "Content-Encoding" fields are deleted
	// from the responseHeader. To get the original response from
	// the server, set Transport.DisableCompression to true.
	Uncompressed bool

	// Trailer maps trailer keys to values in the same
	// format as Header.
	//
	// The Trailer initially contains only nil values, one for
	// each key specified in the server's "Trailer" header
	// value. Those values are not added to Header.
	//
	// Trailer must not be accessed concurrently with Read calls
	// on the Body.
	//
	// After Body.Read has returned io.EOF, Trailer will contain
	// any trailer values sent by the server.
	Trailer Header

	// Request is the request that was sent to obtain this Response.
	// Request's Body is nil (having already been consumed).
	// This is only populated for Client requests.
	Request *Request

	// TLS contains information about the TLS connection on which the
	// response was received. It is nil for unencrypted responses.
	// The pointer is shared between responses and should not be
	// modified.
	TLS *tls.ConnectionState
}
```
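
A mock only has to fill the handful of these fields that callers actually read. The sketch below builds a string response in the spirit of NewStringResponse and serializes it with Response.Write to show the resulting status line and body; newStringResponse, wire, the HTTP/1.1 protocol version and the Content-Type value are assumptions for illustration, not httpmock's exact choices.

```go
package main

import (
	"bytes"
	"fmt"
	"io"
	"net/http"
	"strings"
)

// newStringResponse is a hypothetical builder in the spirit of NewStringResponse:
// it fills the fields clients commonly read and leaves Trailer, TLS, etc. zero.
func newStringResponse(status int, body string) *http.Response {
	return &http.Response{
		Status:        fmt.Sprintf("%d %s", status, http.StatusText(status)),
		StatusCode:    status,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Header:        http.Header{"Content-Type": {"text/plain; charset=utf-8"}},
		Body:          io.NopCloser(strings.NewReader(body)),
		ContentLength: int64(len(body)),
	}
}

// wire serializes the response in HTTP/1.x wire format via Response.Write.
func wire(resp *http.Response) (string, error) {
	var buf bytes.Buffer
	if err := resp.Write(&buf); err != nil {
		return "", err
	}
	return buf.String(), nil
}

func main() {
	out, err := wire(newStringResponse(http.StatusNotFound, "not here"))
	if err != nil {
		fmt.Println("error:", err)
		return
	}
	fmt.Print(out)
}
```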
