diff --git a/x/httpconnect/connect_client.go b/x/httpconnect/connect_client.go new file mode 100644 index 00000000..03990bc2 --- /dev/null +++ b/x/httpconnect/connect_client.go @@ -0,0 +1,131 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpconnect + +import ( + "context" + "errors" + "fmt" + "github.com/Jigsaw-Code/outline-sdk/transport" + "io" + "net" + "net/http" +) + +// connectClient is a [transport.StreamDialer] implementation that dials [proxyAddr] with the given [dialer] +// and sends a CONNECT request to the dialed proxy. +type connectClient struct { + dialer transport.StreamDialer + proxyAddr string + + headers http.Header +} + +var _ transport.StreamDialer = (*connectClient)(nil) + +type ClientOption func(c *connectClient) + +func NewConnectClient(dialer transport.StreamDialer, proxyAddr string, opts ...ClientOption) (transport.StreamDialer, error) { + if dialer == nil { + return nil, errors.New("dialer must not be nil") + } + _, _, err := net.SplitHostPort(proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse proxy address %s: %w", proxyAddr, err) + } + + cc := &connectClient{ + dialer: dialer, + proxyAddr: proxyAddr, + headers: make(http.Header), + } + + for _, opt := range opts { + opt(cc) + } + + return cc, nil +} + +// WithHeaders appends the given [headers] to the CONNECT request +func WithHeaders(headers http.Header) ClientOption { + return func(c *connectClient) { + c.headers = headers.Clone() + } +} + +// DialStream - connects to the proxy and sends a CONNECT request to it, closes the connection if the request fails +func (cc *connectClient) DialStream(ctx context.Context, remoteAddr string) (transport.StreamConn, error) { + innerConn, err := cc.dialer.DialStream(ctx, cc.proxyAddr) + if err != nil { + return nil, fmt.Errorf("failed to dial proxy %s: %w", cc.proxyAddr, err) + } + + conn, err := cc.doConnect(ctx, remoteAddr, innerConn) + if err != nil { + _ = innerConn.Close() + return nil, fmt.Errorf("doConnect %s: %w", remoteAddr, err) + } + + return conn, nil +} + +func (cc *connectClient) doConnect(ctx context.Context, remoteAddr string, conn transport.StreamConn) (transport.StreamConn, error) { + _, _, err := net.SplitHostPort(remoteAddr) + if err != nil { + return nil, fmt.Errorf("failed to parse remote address %s: %w", remoteAddr, err) + } + + pr, pw := io.Pipe() + + req, err := http.NewRequestWithContext(ctx, http.MethodConnect, "http://"+remoteAddr, pr) // TODO: HTTPS support + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.ContentLength = -1 // -1 means length unknown + mergeHeaders(req.Header, cc.headers) + + tr := &http.Transport{ + // TODO: HTTP/2 support with [http2.ConfigureTransport] + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return conn, nil + }, + } + + hc := http.Client{ + Transport: tr, + } + + resp, err := hc.Do(req) + if err != nil { + return nil, fmt.Errorf("do: %w", err) + } + if resp.StatusCode != http.StatusOK { + _ = resp.Body.Close() + return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + + return &pipeConn{ + reader: resp.Body, + writer: pw, + StreamConn: conn, + }, nil +} + +func mergeHeaders(dst http.Header, src http.Header) { + for k, v := range src { + dst[k] = append(dst[k], v...) + } +} diff --git a/x/httpconnect/connect_client_test.go b/x/httpconnect/connect_client_test.go new file mode 100644 index 00000000..0c28e03b --- /dev/null +++ b/x/httpconnect/connect_client_test.go @@ -0,0 +1,109 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpconnect + +import ( + "bufio" + "context" + "encoding/base64" + "github.com/Jigsaw-Code/outline-sdk/transport" + "github.com/Jigsaw-Code/outline-sdk/x/httpproxy" + "github.com/stretchr/testify/require" + "net" + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestConnectClientOk(t *testing.T) { + t.Parallel() + + creds := base64.StdEncoding.EncodeToString([]byte("username:password")) + + targetSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method, "Method") + w.WriteHeader(http.StatusOK) + _, err := w.Write([]byte("HTTP/1.1 200 OK\r\n")) + require.NoError(t, err) + })) + defer targetSrv.Close() + + targetURL, err := url.Parse(targetSrv.URL) + require.NoError(t, err) + + tcpDialer := &transport.TCPDialer{Dialer: net.Dialer{}} + connectHandler := httpproxy.NewConnectHandler(tcpDialer) + proxySrv := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + require.Equal(t, "Basic "+creds, request.Header.Get("Proxy-Authorization")) + connectHandler.ServeHTTP(writer, request) + })) + defer proxySrv.Close() + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + connClient, err := NewConnectClient( + tcpDialer, + proxyURL.Host, + WithHeaders(http.Header{"Proxy-Authorization": []string{"Basic " + creds}}), + ) + require.NoError(t, err, "NewConnectClient") + + streamConn, err := connClient.DialStream(context.Background(), targetURL.Host) + require.NoError(t, err, "DialStream") + require.NotNil(t, streamConn, "StreamConn") + + req, err := http.NewRequest(http.MethodGet, targetSrv.URL, nil) + require.NoError(t, err, "NewRequest") + + err = req.Write(streamConn) + require.NoError(t, err, "Write") + + resp, err := http.ReadResponse(bufio.NewReader(streamConn), req) + require.NoError(t, err, "ReadResponse") + + require.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestConnectClientFail(t *testing.T) { + t.Parallel() + + targetURL := "somehost:1234" + + proxySrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodConnect, r.Method, "Method") + require.Equal(t, targetURL, r.Host, "Host") + + w.WriteHeader(http.StatusBadRequest) + _, err := w.Write([]byte("HTTP/1.1 400 Bad request\r\n\r\n")) + require.NoError(t, err, "Write") + })) + defer proxySrv.Close() + + proxyURL, err := url.Parse(proxySrv.URL) + require.NoError(t, err, "Parse") + + connClient, err := NewConnectClient( + &transport.TCPDialer{ + Dialer: net.Dialer{}, + }, + proxyURL.Host, + ) + require.NoError(t, err, "NewConnectClient") + + _, err = connClient.DialStream(context.Background(), targetURL) + require.Error(t, err, "unexpected status code: 400") +} diff --git a/x/httpconnect/doc.go b/x/httpconnect/doc.go new file mode 100644 index 00000000..54418797 --- /dev/null +++ b/x/httpconnect/doc.go @@ -0,0 +1,16 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package httpconnect contains an HTTP CONNECT client implementation. +package httpconnect diff --git a/x/httpconnect/pipe_conn.go b/x/httpconnect/pipe_conn.go new file mode 100644 index 00000000..50f174a3 --- /dev/null +++ b/x/httpconnect/pipe_conn.go @@ -0,0 +1,50 @@ +// Copyright 2025 The Outline Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package httpconnect + +import ( + "errors" + "github.com/Jigsaw-Code/outline-sdk/transport" + "io" +) + +var _ transport.StreamConn = (*pipeConn)(nil) + +// pipeConn is a [transport.StreamConn] that overrides [Read], [Write] (and corresponding [Close]) functions with the given [reader] and [writer] +type pipeConn struct { + reader io.ReadCloser + writer io.WriteCloser + transport.StreamConn +} + +func (p *pipeConn) Read(b []byte) (n int, err error) { + return p.reader.Read(b) +} + +func (p *pipeConn) Write(b []byte) (n int, err error) { + return p.writer.Write(b) +} + +func (p *pipeConn) CloseRead() error { + return errors.Join(p.reader.Close(), p.StreamConn.CloseRead()) +} + +func (p *pipeConn) CloseWrite() error { + return errors.Join(p.writer.Close(), p.StreamConn.CloseWrite()) +} + +func (p *pipeConn) Close() error { + return errors.Join(p.reader.Close(), p.writer.Close(), p.StreamConn.Close()) +}