Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix segfault in State::serialize method #664

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -411,10 +411,14 @@ func (c *Conn) Close() error {

// ConnectionState returns basic DTLS details about the connection.
// Note that this replaced the `Export` function of v1.
func (c *Conn) ConnectionState() State {
func (c *Conn) ConnectionState() (State, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
return *c.state.clone()
stateClone, err := c.state.clone()
if err != nil {
return State{}, false
}
return *stateClone, true
}

// SelectedSRTPProtectionProfile returns the selected SRTPProtectionProfile
Expand Down
115 changes: 101 additions & 14 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,28 +497,40 @@ func TestExportKeyingMaterial(t *testing.T) {
c.setLocalEpoch(0)
c.setRemoteEpoch(0)

state := c.ConnectionState()
state, ok := c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
_, err := state.ExportKeyingMaterial(exportLabel, nil, 0)
if !errors.Is(err, errHandshakeInProgress) {
t.Errorf("ExportKeyingMaterial when epoch == 0: expected '%s' actual '%s'", errHandshakeInProgress, err)
}

c.setLocalEpoch(1)
state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
_, err = state.ExportKeyingMaterial(exportLabel, []byte{0x00}, 0)
if !errors.Is(err, errContextUnsupported) {
t.Errorf("ExportKeyingMaterial with context: expected '%s' actual '%s'", errContextUnsupported, err)
}

for k := range invalidKeyingLabels() {
state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
_, err = state.ExportKeyingMaterial(k, nil, 0)
if !errors.Is(err, errReservedExportKeyingMaterial) {
t.Errorf("ExportKeyingMaterial reserved label: expected '%s' actual '%s'", errReservedExportKeyingMaterial, err)
}
}

state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
keyingMaterial, err := state.ExportKeyingMaterial(exportLabel, nil, 10)
if err != nil {
t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
Expand All @@ -527,7 +539,10 @@ func TestExportKeyingMaterial(t *testing.T) {
}

c.state.isClient = true
state = c.ConnectionState()
state, ok = c.ConnectionState()
if !ok {
t.Fatal("ConnectionState failed")
}
keyingMaterial, err = state.ExportKeyingMaterial(exportLabel, nil, 10)
if err != nil {
t.Errorf("ExportKeyingMaterial as server: unexpected error '%s'", err)
Expand Down Expand Up @@ -669,7 +684,11 @@ func TestPSK(t *testing.T) {
t.Fatalf("TestPSK: Server failed(%v)", err)
}

actualPSKIdentityHint := server.ConnectionState().IdentityHint
state, ok := server.ConnectionState()
if !ok {
t.Fatalf("TestPSK: Server ConnectionState failed")
}
actualPSKIdentityHint := state.IdentityHint
if !bytes.Equal(actualPSKIdentityHint, test.ClientIdentity) {
t.Errorf("TestPSK: Server ClientPSKIdentity Mismatch '%s': expected(%v) actual(%v)", test.Name, test.ClientIdentity, actualPSKIdentityHint)
}
Expand Down Expand Up @@ -1194,7 +1213,11 @@ func TestClientCertificate(t *testing.T) {
t.Errorf("Client failed(%v)", res.err)
}

actualClientCert := server.ConnectionState().PeerCertificates
state, ok := server.ConnectionState()
if !ok {
t.Error("Server connection state not available")
}
actualClientCert := state.PeerCertificates
if tt.serverCfg.ClientAuth == RequireAnyClientCert || tt.serverCfg.ClientAuth == RequireAndVerifyClientCert {
if actualClientCert == nil {
t.Errorf("Client did not provide a certificate")
Expand All @@ -1221,7 +1244,11 @@ func TestClientCertificate(t *testing.T) {
}
}

actualServerCert := res.c.ConnectionState().PeerCertificates
clientState, ok := res.c.ConnectionState()
if !ok {
t.Error("Client connection state not available")
}
actualServerCert := clientState.PeerCertificates
if actualServerCert == nil {
t.Errorf("Server did not provide a certificate")
}
Expand Down Expand Up @@ -2889,8 +2916,12 @@ func TestSessionResume(t *testing.T) {
t.Fatalf("TestSessionResume: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
state, ok := server.ConnectionState()
if !ok {
t.Fatal("TestSessionResume: ConnectionState failed")
}
actualSessionID := state.SessionID
actualMasterSecret := state.masterSecret
if !bytes.Equal(actualSessionID, id) {
t.Errorf("TestSessionResumetion: SessionID Mismatch: expected(%v) actual(%v)", id, actualSessionID)
}
Expand Down Expand Up @@ -2940,8 +2971,12 @@ func TestSessionResume(t *testing.T) {
t.Fatalf("TestSessionResumetion: Server failed(%v)", err)
}

actualSessionID := server.ConnectionState().SessionID
actualMasterSecret := server.ConnectionState().masterSecret
state, ok := server.ConnectionState()
if !ok {
t.Fatal("TestSessionResumetion: ConnectionState failed")
}
actualSessionID := state.SessionID
actualMasterSecret := state.masterSecret
ss, _ := s2.Get(actualSessionID)
if !bytes.Equal(actualMasterSecret, ss.Secret) {
t.Errorf("TestSessionResumetion: masterSecret Mismatch: expected(%v) actual(%v)", ss.Secret, actualMasterSecret)
Expand Down Expand Up @@ -3071,8 +3106,8 @@ func TestCipherSuiteMatchesCertificateType(t *testing.T) {
t.Fatal(err)
} else if err := c.Close(); err != nil {
t.Fatal(err)
} else if c.ConnectionState().cipherSuite.ID() != test.expectedCipher {
t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, c.ConnectionState().cipherSuite.ID())
} else if state, ok := c.ConnectionState(); !ok || state.cipherSuite.ID() != test.expectedCipher {
t.Fatalf("Expected(%s) and Actual(%s) CipherSuite do not match", test.expectedCipher, state.cipherSuite.ID())
}
})
}
Expand Down Expand Up @@ -3527,3 +3562,55 @@ func TestFragmentBuffer_Retransmission(t *testing.T) {
t.Fatal("fragment should be retransmission")
}
}

func TestConnectionState(t *testing.T) {
ca, cb := dpipe.Pipe()

// Setup client
clientCfg := &Config{}
clientCert, err := selfsign.GenerateSelfSigned()
if err != nil {
t.Fatal(err)
}
clientCfg.Certificates = []tls.Certificate{clientCert}
clientCfg.InsecureSkipVerify = true
client, err := Client(dtlsnet.PacketConnFromConn(ca), ca.RemoteAddr(), clientCfg)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = client.Close()
}()

_, ok := client.ConnectionState()
if ok {
t.Fatal("ConnectionState should be nil")
}

ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
c := make(chan error)
go func() {
errC := client.HandshakeContext(ctx)
c <- errC
}()

// Setup server
server, err := testServer(ctx, dtlsnet.PacketConnFromConn(cb), cb.RemoteAddr(), &Config{}, true)
if err != nil {
t.Fatal(err)
}
defer func() {
_ = server.Close()
}()

err = <-c
if err != nil {
t.Fatal(err)
}

_, ok = client.ConnectionState()
if !ok {
t.Fatal("ConnectionState should not be nil")
}
}
12 changes: 10 additions & 2 deletions flight4handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,11 @@

if state.cipherSuite.AuthenticationType() == CipherSuiteAuthenticationTypeAnonymous {
if cfg.verifyConnection != nil {
if err := cfg.verifyConnection(state.clone()); err != nil {
stateClone, err := state.clone()
if err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err

Check warning on line 188 in flight4handler.go

View check run for this annotation

Codecov / codecov/patch

flight4handler.go#L186-L188

Added lines #L186 - L188 were not covered by tests
}
if err := cfg.verifyConnection(stateClone); err != nil {

Check warning on line 190 in flight4handler.go

View check run for this annotation

Codecov / codecov/patch

flight4handler.go#L190

Added line #L190 was not covered by tests
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
Expand All @@ -210,7 +214,11 @@
// go to flight6
}
if cfg.verifyConnection != nil {
if err := cfg.verifyConnection(state.clone()); err != nil {
stateClone, err := state.clone()
if err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, err

Check warning on line 219 in flight4handler.go

View check run for this annotation

Codecov / codecov/patch

flight4handler.go#L219

Added line #L219 was not covered by tests
}
if err := cfg.verifyConnection(stateClone); err != nil {
return 0, &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
}
}
Expand Down
8 changes: 6 additions & 2 deletions flight5handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,12 @@
}
}
if cfg.verifyConnection != nil {
if err = cfg.verifyConnection(state.clone()); err != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, err
stateClone, errC := state.clone()
if errC != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.InternalError}, errC

Check warning on line 349 in flight5handler.go

View check run for this annotation

Codecov / codecov/patch

flight5handler.go#L349

Added line #L349 was not covered by tests
}
if errC = cfg.verifyConnection(stateClone); errC != nil {
return &alert.Alert{Level: alert.Fatal, Description: alert.BadCertificate}, errC
}
}

Expand Down
10 changes: 8 additions & 2 deletions resume_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ import (
"github.com/pion/transport/v3/test"
)

var errMessageMissmatch = errors.New("messages missmatch")
var (
errMessageMissmatch = errors.New("messages missmatch")
errInvalidConnectionState = errors.New("failed to get connection state")
)

func TestResumeClient(t *testing.T) {
DoTestResume(t, Client, Server)
Expand Down Expand Up @@ -120,7 +123,10 @@ func DoTestResume(t *testing.T, newLocal, newRemote func(net.PacketConn, net.Add
}

// Serialize and deserialize state
state := local.ConnectionState()
state, ok := local.ConnectionState()
if !ok {
fatal(t, errChan, errInvalidConnectionState)
}
var b []byte
b, err = state.MarshalBinary()
if err != nil {
Expand Down
28 changes: 21 additions & 7 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import (
"bytes"
"encoding/gob"
"errors"
"sync/atomic"

"github.com/pion/dtls/v3/pkg/crypto/elliptic"
Expand Down Expand Up @@ -87,15 +88,25 @@
NegotiatedProtocol string
}

func (s *State) clone() *State {
serialized := s.serialize()
var errCipherSuiteNotSet = &InternalError{Err: errors.New("cipher suite not set")} //nolint:goerr113

func (s *State) clone() (*State, error) {
serialized, err := s.serialize()
if err != nil {
return nil, err
}
state := &State{}
state.deserialize(*serialized)

return state
return state, err
}

func (s *State) serialize() *serializedState {
func (s *State) serialize() (*serializedState, error) {
if s.cipherSuite == nil {
return nil, errCipherSuiteNotSet
}
cipherSuiteID := uint16(s.cipherSuite.ID())

// Marshal random values
localRnd := s.localRandom.MarshalFixed()
remoteRnd := s.remoteRandom.MarshalFixed()
Expand All @@ -104,7 +115,7 @@
return &serializedState{
LocalEpoch: s.getLocalEpoch(),
RemoteEpoch: s.getRemoteEpoch(),
CipherSuiteID: uint16(s.cipherSuite.ID()),
CipherSuiteID: cipherSuiteID,
MasterSecret: s.masterSecret,
SequenceNumber: atomic.LoadUint64(&s.localSequenceNumber[epoch]),
LocalRandom: localRnd,
Expand All @@ -117,7 +128,7 @@
RemoteConnectionID: s.remoteConnectionID,
IsClient: s.isClient,
NegotiatedProtocol: s.NegotiatedProtocol,
}
}, nil
}

func (s *State) deserialize(serialized serializedState) {
Expand Down Expand Up @@ -187,7 +198,10 @@

// MarshalBinary is a binary.BinaryMarshaler.MarshalBinary implementation
func (s *State) MarshalBinary() ([]byte, error) {
serialized := s.serialize()
serialized, err := s.serialize()
if err != nil {
return nil, err

Check warning on line 203 in state.go

View check run for this annotation

Codecov / codecov/patch

state.go#L203

Added line #L203 was not covered by tests
}

var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
Expand Down
Loading