Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Synchronize parallel calls to Conn.Close and Conn.handshake #671

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 30 additions & 10 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,13 @@ type Conn struct {

handshakeCompletedSuccessfully atomic.Value
handshakeMutex sync.Mutex
handshakeDone chan struct{}

encryptedPackets []addrPkt

connectionClosedByUser bool
closeLock sync.Mutex
closed *closer.Closer
handshakeLoopsFinished sync.WaitGroup

readDeadline *deadline.Deadline
writeDeadline *deadline.Deadline
Expand Down Expand Up @@ -256,6 +256,12 @@ func (c *Conn) HandshakeContext(ctx context.Context) error {
return nil
}

handshakeDone := make(chan struct{})
defer close(handshakeDone)
c.closeLock.Lock()
c.handshakeDone = handshakeDone
c.closeLock.Unlock()

// rfc5246#section-7.4.3
// In addition, the hash and signature algorithms MUST be compatible
// with the key in the server's end-entity certificate.
Expand Down Expand Up @@ -405,7 +411,12 @@ func (c *Conn) Write(p []byte) (int, error) {
// Close closes the connection.
func (c *Conn) Close() error {
err := c.close(true) //nolint:contextcheck
c.handshakeLoopsFinished.Wait()
c.closeLock.Lock()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

poor c.closeLock gets locked and unlocked thrice in a call to c.Close (twice in close + the call here), I guess I could get the channel in c.close and return it? Doesn't fit semantically for close to return the channel, but it is an internal function, so it doesn't matter much.

handshakeDone := c.handshakeDone
c.closeLock.Unlock()
if handshakeDone != nil {
<-handshakeDone
}
return err
}

Expand Down Expand Up @@ -1026,7 +1037,6 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh

done := make(chan struct{})
ctxRead, cancelRead := context.WithCancel(context.Background())
c.cancelHandshakeReader = cancelRead
cfg.onFlightState = func(_ flightVal, s handshakeState) {
if s == handshakeFinished && !c.isHandshakeCompletedSuccessfully() {
c.setHandshakeCompletedSuccessfully()
Expand All @@ -1035,16 +1045,21 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
}

ctxHs, cancel := context.WithCancel(context.Background())

c.closeLock.Lock()
c.cancelHandshaker = cancel
c.cancelHandshakeReader = cancelRead
c.closeLock.Unlock()

firstErr := make(chan error, 1)

c.handshakeLoopsFinished.Add(2)
var handshakeLoopsFinished sync.WaitGroup
handshakeLoopsFinished.Add(2)

// Handshake routine should be live until close.
// The other party may request retransmission of the last flight to cope with packet drop.
go func() {
defer c.handshakeLoopsFinished.Done()
defer handshakeLoopsFinished.Done()
err := c.fsm.Run(ctxHs, c, initialState)
if !errors.Is(err, context.Canceled) {
select {
Expand All @@ -1064,7 +1079,7 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
// Force stop handshaker when the underlying connection is closed.
cancel()
}()
defer c.handshakeLoopsFinished.Done()
defer handshakeLoopsFinished.Done()
for {
if err := c.readAndBuffer(ctxRead); err != nil {
var e *alertError
Expand Down Expand Up @@ -1123,12 +1138,12 @@ func (c *Conn) handshake(ctx context.Context, cfg *handshakeConfig, initialFligh
case err := <-firstErr:
cancelRead()
cancel()
c.handshakeLoopsFinished.Wait()
handshakeLoopsFinished.Wait()
return c.translateHandshakeCtxError(err)
case <-ctx.Done():
cancelRead()
cancel()
c.handshakeLoopsFinished.Wait()
handshakeLoopsFinished.Wait()
return c.translateHandshakeCtxError(ctx.Err())
case <-done:
return nil
Expand All @@ -1146,8 +1161,13 @@ func (c *Conn) translateHandshakeCtxError(err error) error {
}

func (c *Conn) close(byUser bool) error {
c.cancelHandshaker()
c.cancelHandshakeReader()
c.closeLock.Lock()
cancelHandshaker := c.cancelHandshaker
cancelHandshakeReader := c.cancelHandshakeReader
c.closeLock.Unlock()

cancelHandshaker()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Sean-Der is it important to invoke these cancels at the beginning? if not I could move thes copying to line 1179 (after closeLock gets locked) and the invocations to 1187 (after it gets unlocked)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests succeed locally when I do it like that, so maybe the order of calls here isn't imporant, in which case I can do it like that and avoid multiple locking/unlocking of the lock.

cancelHandshakeReader()

if c.isHandshakeCompletedSuccessfully() && byUser {
// Discard error from notify() to return non-error on the first user call of Close()
Expand Down
51 changes: 51 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3662,3 +3662,54 @@ func TestMultiHandshake(t *testing.T) {
t.Fatal(err)
}
}

func TestCloseDuringHandshake(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 10).Stop()

serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}

for i := 0; i < 100; i++ {
_, cb := dpipe.Pipe()
server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
Certificates: []tls.Certificate{serverCert},
})
if err != nil {
t.Fatal(err)
}

waitChan := make(chan struct{})
go func() {
close(waitChan)
_ = server.Handshake()
}()

<-waitChan
if err = server.Close(); err != nil {
t.Fatal(err)
}
}
}

func TestCloseWithoutHandshake(t *testing.T) {
defer test.CheckRoutines(t)()
defer test.TimeOut(time.Second * 10).Stop()

serverCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}
_, cb := dpipe.Pipe()
server, err := Server(dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{
Certificates: []tls.Certificate{serverCert},
})
if err != nil {
t.Fatal(err)
}
if err = server.Close(); err != nil {
t.Fatal(err)
}
}
Loading