Merge pull request #11262 from akallabeth/license-fix

Redirection && StreamPool usage fixes
This commit is contained in:
akallabeth
2025-03-04 09:13:56 +01:00
committed by GitHub
9 changed files with 67 additions and 61 deletions

View File

@@ -1139,7 +1139,7 @@ BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s)
return TRUE;
}
BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length)
state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length)
{
WINPR_ASSERT(rdp);
WINPR_ASSERT(rdp->mcs);
@@ -1147,36 +1147,40 @@ BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT1
if (!rdp->mcs->messageChannelJoined)
{
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel not joined!");
return FALSE;
return STATE_RUN_FAILED;
}
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
if (messageChannelId == 0)
{
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel id == 0");
return FALSE;
return STATE_RUN_FAILED;
}
if (channelId != messageChannelId)
if ((channelId != messageChannelId) && (channelId != MCS_GLOBAL_CHANNEL_ID))
{
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel expected id=%" PRIu16 ", got %" PRIu16,
messageChannelId, channelId);
return FALSE;
WLog_Print(rdp->log, WLOG_WARN,
"MCS message channel expected id=[%" PRIu16 "|%d], got %" PRIu16,
messageChannelId, MCS_GLOBAL_CHANNEL_ID, channelId);
return STATE_RUN_FAILED;
}
UINT16 securityFlags = 0;
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
return FALSE;
return STATE_RUN_FAILED;
if (securityFlags & SEC_ENCRYPT)
{
if (!rdp_decrypt(rdp, s, &length, securityFlags))
return FALSE;
return STATE_RUN_FAILED;
}
if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) != STATE_RUN_SUCCESS)
return FALSE;
return tpkt_ensure_stream_consumed(s, length);
const state_run_t rc = rdp_recv_message_channel_pdu(rdp, s, securityFlags);
if (state_run_success(rc))
{
if (!tpkt_ensure_stream_consumed(s, length))
return STATE_RUN_FAILED;
}
return rc;
}
BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
@@ -1196,8 +1200,8 @@ BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
/* Process any MCS message channel PDUs. */
if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId))
{
if (rdp_handle_message_channel(rdp, s, channelId, length))
return TRUE;
const state_run_t rc = rdp_handle_message_channel(rdp, s, channelId, length);
return state_run_success(rc);
}
else
{
@@ -1228,9 +1232,7 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s)
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId))
{
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
return rdp_handle_message_channel(rdp, s, channelId, length);
}
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
@@ -1306,9 +1308,7 @@ state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s)
if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
{
rdp->inPackets++;
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
return rdp_handle_message_channel(rdp, s, channelId, length);
}
if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, NULL))

View File

@@ -73,8 +73,8 @@ FREERDP_LOCAL const char* rdp_client_connection_state_string(UINT state);
FREERDP_LOCAL BOOL rdp_channels_from_mcs(rdpSettings* settings, const rdpRdp* rdp);
FREERDP_LOCAL BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId,
UINT16 length);
FREERDP_LOCAL state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId,
UINT16 length);
FREERDP_LOCAL BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length,
UINT16* pSecurityFlags);

View File

@@ -317,8 +317,13 @@ BOOL freerdp_abort_connect_context(rdpContext* context)
/* Try to send a [MS-RDPBCGR] 1.3.1.4.1 User-Initiated on Client PDU, we don't care about
* success */
if (context->rdp && context->rdp->mcs)
(void)mcs_send_disconnect_provider_ultimatum(context->rdp->mcs,
Disconnect_Ultimatum_user_requested);
{
if (!context->ServerMode)
{
(void)mcs_send_disconnect_provider_ultimatum(context->rdp->mcs,
Disconnect_Ultimatum_user_requested);
}
}
return utils_abort_connect(context->rdp);
}

View File

@@ -458,9 +458,7 @@ static state_run_t peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s)
if (rdp_get_state(rdp) <= CONNECTION_STATE_LICENSING)
{
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
return rdp_handle_message_channel(rdp, s, channelId, length);
}
if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, &securityFlags))

View File

@@ -1634,9 +1634,7 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s)
if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
{
rdp->inPackets++;
if (!rdp_handle_message_channel(rdp, s, channelId, length))
return STATE_RUN_FAILED;
return STATE_RUN_SUCCESS;
return rdp_handle_message_channel(rdp, s, channelId, length);
}
if (rdp->settings->UseRdpSecurityLayer)

View File

@@ -147,22 +147,11 @@ static void transport_ssl_cb(const SSL* ssl, int where, int ret)
}
}
wStream* transport_send_stream_init(rdpTransport* transport, size_t size)
wStream* transport_send_stream_init(WINPR_ATTR_UNUSED rdpTransport* transport, size_t size)
{
WINPR_ASSERT(transport);
wStream* s = StreamPool_Take(transport->ReceivePool, size);
if (!s)
return NULL;
if (!Stream_EnsureCapacity(s, size))
{
Stream_Release(s);
return NULL;
}
Stream_SetPosition(s, 0);
return s;
return Stream_New(NULL, size);
}
BOOL transport_attach(rdpTransport* transport, int sockfd)
@@ -1447,9 +1436,12 @@ int transport_check_fds(rdpTransport* transport)
}
received = transport->ReceiveBuffer;
if (!(transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0)))
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
if (!transport->ReceiveBuffer)
{
Stream_Release(received);
return -1;
}
/**
* status:

View File

@@ -1368,7 +1368,20 @@ extern "C"
WINPR_API void StreamPool_Return(wStreamPool* pool, wStream* s);
/** @brief increment reference count of stream
*
* @param s The stream to reference
* @bug versions < 3.13.0 did only handle streams returned by StreamPool_Take
*/
WINPR_API void Stream_AddRef(wStream* s);
/** @brief Release a reference to a stream.
* If the reference count reaches \b 0 it is returned to the StreamPool it was taken from or \b
* Stream_Free is called.
*
* @param s The stream to release
* @bug versions < 3.13.0 did only handle streams returned by StreamPool_Take
*/
WINPR_API void Stream_Release(wStream* s);
WINPR_ATTR_MALLOC(Stream_Release, 1)

View File

@@ -285,10 +285,7 @@ static void StreamPool_Remove(wStreamPool* pool, wStream* s)
static void StreamPool_ReleaseOrReturn(wStreamPool* pool, wStream* s)
{
StreamPool_Lock(pool);
if (s->count > 0)
s->count--;
if (s->count == 0)
StreamPool_Remove(pool, s);
StreamPool_Remove(pool, s);
StreamPool_Unlock(pool);
}
@@ -310,12 +307,7 @@ void StreamPool_Return(wStreamPool* pool, wStream* s)
void Stream_AddRef(wStream* s)
{
WINPR_ASSERT(s);
if (s->pool)
{
StreamPool_Lock(s->pool);
s->count++;
StreamPool_Unlock(s->pool);
}
s->count++;
}
/**
@@ -325,8 +317,16 @@ void Stream_AddRef(wStream* s)
void Stream_Release(wStream* s)
{
WINPR_ASSERT(s);
if (s->pool)
StreamPool_ReleaseOrReturn(s->pool, s);
if (s->count > 0)
s->count--;
if (s->count == 0)
{
if (s->pool)
StreamPool_ReleaseOrReturn(s->pool, s);
else
Stream_Free(s, TRUE);
}
}
/**

View File

@@ -98,7 +98,7 @@ wStream* Stream_New(BYTE* buffer, size_t size)
if (!buffer && !size)
return NULL;
s = malloc(sizeof(wStream));
s = calloc(1, sizeof(wStream));
if (!s)
return NULL;
@@ -118,7 +118,7 @@ wStream* Stream_New(BYTE* buffer, size_t size)
s->length = size;
s->pool = NULL;
s->count = 0;
s->count = 1;
s->isAllocatedStream = TRUE;
s->isOwner = TRUE;
return s;
@@ -147,7 +147,7 @@ wStream* Stream_StaticInit(wStream* s, BYTE* buffer, size_t size)
s->buffer = s->pointer = buffer;
s->capacity = s->length = size;
s->pool = NULL;
s->count = 0;
s->count = 1;
s->isAllocatedStream = FALSE;
s->isOwner = FALSE;
return s;