diff --git a/libfreerdp/core/rdstls.c b/libfreerdp/core/rdstls.c index 1d79158fe..247f63914 100644 --- a/libfreerdp/core/rdstls.c +++ b/libfreerdp/core/rdstls.c @@ -936,6 +936,8 @@ int rdstls_authenticate(rdpRdstls* rdstls) static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s) { + size_t pduLength = 0; + switch (dataType) { case RDSTLS_DATA_PASSWORD_CREDS: @@ -972,9 +974,7 @@ static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s return 0; Stream_Read_UINT16(s, passwordLength); - if (Stream_GetRemainingLength(s) < passwordLength) - return 0; - Stream_Seek(s, passwordLength); + pduLength = Stream_GetPosition(s) + passwordLength; } break; case RDSTLS_DATA_AUTORECONNECT_COOKIE: @@ -987,8 +987,8 @@ static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s if (Stream_GetRemainingLength(s) < 2) return 0; Stream_Read_UINT16(s, cookieLength); - if (!Stream_SafeSeek(s, cookieLength)) - return 0; + + pduLength = Stream_GetPosition(s) + cookieLength; } break; default: @@ -996,10 +996,9 @@ static SSIZE_T rdstls_parse_pdu_data_type(wLog* log, UINT16 dataType, wStream* s return -1; } - const size_t len = Stream_GetPosition(s); - if (len > SSIZE_MAX) + if (pduLength > SSIZE_MAX) return 0; - return (SSIZE_T)len; + return (SSIZE_T)pduLength; } SSIZE_T rdstls_parse_pdu(wLog* log, wStream* stream) diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index 144a6c55b..eae6046e5 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -1114,14 +1114,15 @@ static int transport_default_read_pdu(rdpTransport* transport, wStream* s) position = Stream_GetPosition(s); if (position > pduLength) return -1; - - status = transport_read_layer_bytes(transport, s, pduLength - Stream_GetPosition(s)); - - if (status != 1) + else if (position < pduLength) { - if ((status < INT32_MIN) || (status > INT32_MAX)) - return -1; - return (int)status; + status = transport_read_layer_bytes(transport, s, pduLength - position); + if (status != 1) + { + if ((status < INT32_MIN) || (status > INT32_MAX)) + return -1; + return (int)status; + } } if (Stream_GetPosition(s) >= pduLength)