diff --git a/packet/https.go b/packet/https.go index 49abeb8..7174417 100644 --- a/packet/https.go +++ b/packet/https.go @@ -2,12 +2,14 @@ package packet import ( "encoding/binary" + "fmt" "io" ) type TLSMessageType byte const ( + TLSMaxPayloadLen uint16 = 16384 // 16 KB TLSHeaderLen = 5 TLSInvalid TLSMessageType = 0x0 TLSChangeCipherSpec TLSMessageType = 0x14 @@ -42,7 +44,10 @@ func ReadTLSMessage(r io.Reader) (*TLSMessage, error) { ProtoVersion: binary.BigEndian.Uint16(rawHeader[1:3]), PayloadLen: binary.BigEndian.Uint16(rawHeader[3:5]), } - + if header.PayloadLen > TLSMaxPayloadLen { + // Corrupted header? Check integer overflow + return nil, fmt.Errorf("invalid TLS header. Type: %x, ProtoVersion: %x, PayloadLen: %x", header.Type, header.ProtoVersion, header.PayloadLen) + } raw := make([]byte, header.PayloadLen+TLSHeaderLen) copy(raw[0:TLSHeaderLen], rawHeader[:]) _, err = io.ReadFull(r, raw[TLSHeaderLen:]) @@ -62,5 +67,7 @@ func ReadTLSMessage(r io.Reader) (*TLSMessage, error) { func (m *TLSMessage) IsClientHello() bool { // According to RFC 8446 section 4. // first byte (Raw[5]) of handshake message should be 0x1 - means client_hello - return m.Header.Type == TLSHandshake && m.Raw[5] == 0x01 + return len(m.Raw) > TLSHeaderLen && + m.Header.Type == TLSHandshake && + m.Raw[5] == 0x01 }