Skip to content

Commit

Permalink
Merge pull request #302 from coinbase/sgupta/do-not-override-httpclient
Browse files Browse the repository at this point in the history
Do not override the provided HTTP client in the fetcher
  • Loading branch information
swapna gupta authored Mar 3, 2021
2 parents 2522314 + cd1f20a commit 7c2169d
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 20 deletions.
46 changes: 26 additions & 20 deletions fetcher/fetcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,7 @@ func New(
serverAddress string,
options ...Option,
) *Fetcher {
// Create default fetcher
clientCfg := client.NewConfiguration(
serverAddress,
DefaultUserAgent,
&http.Client{
Timeout: DefaultHTTPTimeout,
})
client := client.NewAPIClient(clientCfg)

f := &Fetcher{
rosettaClient: client,
maxConnections: DefaultMaxConnections,
maxRetries: DefaultRetries,
retryElapsedTime: DefaultElapsedTime,
Expand All @@ -106,18 +96,34 @@ func New(
opt(f)
}

// Override transport idle connection settings
//
// See this conversation around why `.Clone()` is used here:
// https://github.com/golang/go/issues/26013
customTransport := http.DefaultTransport.(*http.Transport).Clone()
customTransport.IdleConnTimeout = DefaultIdleConnTimeout
customTransport.MaxIdleConns = f.maxConnections
customTransport.MaxIdleConnsPerHost = f.maxConnections
if f.rosettaClient == nil {
// Override transport idle connection settings
//
// See this conversation around why `.Clone()` is used here:
// https://github.com/golang/go/issues/26013
defaultTransport := http.DefaultTransport.(*http.Transport).Clone()
defaultTransport.IdleConnTimeout = DefaultIdleConnTimeout
defaultTransport.MaxIdleConns = f.maxConnections
defaultTransport.MaxIdleConnsPerHost = DefaultMaxConnections
defaultHTTPClient := &http.Client{
Timeout: DefaultHTTPTimeout,
Transport: defaultTransport,
}

// Create default fetcher
clientCfg := client.NewConfiguration(
serverAddress,
DefaultUserAgent,
defaultHTTPClient,
)
f.rosettaClient = client.NewAPIClient(clientCfg)
}

if f.insecureTLS {
customTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402
if transport, ok := f.rosettaClient.GetConfig().HTTPClient.Transport.(*http.Transport); ok {
transport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} // #nosec G402
}
}
f.rosettaClient.GetConfig().HTTPClient.Transport = customTransport

// Initialize the connection semaphore
f.connectionSemaphore = semaphore.NewWeighted(int64(f.maxConnections))
Expand Down
19 changes: 19 additions & 0 deletions fetcher/fetcher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"github.com/stretchr/testify/assert"

"github.com/coinbase/rosetta-sdk-go/asserter"
"github.com/coinbase/rosetta-sdk-go/client"
"github.com/coinbase/rosetta-sdk-go/types"
)

Expand Down Expand Up @@ -195,3 +196,21 @@ func TestInitializeAsserter(t *testing.T) {
})
}
}

func TestNewWithHTTPCLient(t *testing.T) {
// Callers can pass an http.Client to
// the fetcher via WithClient.
// Ensure that the fetcher does not
// override it.
httpClient := &http.Client{}
apiClient := client.NewAPIClient(
client.NewConfiguration(
"https://serveraddress",
DefaultUserAgent,
httpClient,
),
)
fetcher := New("https://serveraddress", WithClient(apiClient))
var assert = assert.New(t)
assert.Same(httpClient, fetcher.rosettaClient.GetConfig().HTTPClient)
}

0 comments on commit 7c2169d

Please sign in to comment.