mirror of
https://github.com/morgan9e/FreeRDP
synced 2026-04-14 00:14:11 +09:00
Merge pull request #11262 from akallabeth/license-fix
Redirection && StreamPool usage fixes
This commit is contained in:
@@ -1139,7 +1139,7 @@ BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s)
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length)
|
||||
state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length)
|
||||
{
|
||||
WINPR_ASSERT(rdp);
|
||||
WINPR_ASSERT(rdp->mcs);
|
||||
@@ -1147,36 +1147,40 @@ BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT1
|
||||
if (!rdp->mcs->messageChannelJoined)
|
||||
{
|
||||
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel not joined!");
|
||||
return FALSE;
|
||||
return STATE_RUN_FAILED;
|
||||
}
|
||||
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
|
||||
if (messageChannelId == 0)
|
||||
{
|
||||
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel id == 0");
|
||||
return FALSE;
|
||||
return STATE_RUN_FAILED;
|
||||
}
|
||||
|
||||
if (channelId != messageChannelId)
|
||||
if ((channelId != messageChannelId) && (channelId != MCS_GLOBAL_CHANNEL_ID))
|
||||
{
|
||||
WLog_Print(rdp->log, WLOG_WARN, "MCS message channel expected id=%" PRIu16 ", got %" PRIu16,
|
||||
messageChannelId, channelId);
|
||||
return FALSE;
|
||||
WLog_Print(rdp->log, WLOG_WARN,
|
||||
"MCS message channel expected id=[%" PRIu16 "|%d], got %" PRIu16,
|
||||
messageChannelId, MCS_GLOBAL_CHANNEL_ID, channelId);
|
||||
return STATE_RUN_FAILED;
|
||||
}
|
||||
|
||||
UINT16 securityFlags = 0;
|
||||
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
|
||||
return FALSE;
|
||||
return STATE_RUN_FAILED;
|
||||
|
||||
if (securityFlags & SEC_ENCRYPT)
|
||||
{
|
||||
if (!rdp_decrypt(rdp, s, &length, securityFlags))
|
||||
return FALSE;
|
||||
return STATE_RUN_FAILED;
|
||||
}
|
||||
|
||||
if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) != STATE_RUN_SUCCESS)
|
||||
return FALSE;
|
||||
|
||||
return tpkt_ensure_stream_consumed(s, length);
|
||||
const state_run_t rc = rdp_recv_message_channel_pdu(rdp, s, securityFlags);
|
||||
if (state_run_success(rc))
|
||||
{
|
||||
if (!tpkt_ensure_stream_consumed(s, length))
|
||||
return STATE_RUN_FAILED;
|
||||
}
|
||||
return rc;
|
||||
}
|
||||
|
||||
BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
|
||||
@@ -1196,8 +1200,8 @@ BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s)
|
||||
/* Process any MCS message channel PDUs. */
|
||||
if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId))
|
||||
{
|
||||
if (rdp_handle_message_channel(rdp, s, channelId, length))
|
||||
return TRUE;
|
||||
const state_run_t rc = rdp_handle_message_channel(rdp, s, channelId, length);
|
||||
return state_run_success(rc);
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -1228,9 +1232,7 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s)
|
||||
const UINT16 messageChannelId = rdp->mcs->messageChannelId;
|
||||
if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId))
|
||||
{
|
||||
if (!rdp_handle_message_channel(rdp, s, channelId, length))
|
||||
return STATE_RUN_FAILED;
|
||||
return STATE_RUN_SUCCESS;
|
||||
return rdp_handle_message_channel(rdp, s, channelId, length);
|
||||
}
|
||||
|
||||
if (!rdp_read_security_header(rdp, s, &securityFlags, &length))
|
||||
@@ -1306,9 +1308,7 @@ state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s)
|
||||
if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
|
||||
{
|
||||
rdp->inPackets++;
|
||||
if (!rdp_handle_message_channel(rdp, s, channelId, length))
|
||||
return STATE_RUN_FAILED;
|
||||
return STATE_RUN_SUCCESS;
|
||||
return rdp_handle_message_channel(rdp, s, channelId, length);
|
||||
}
|
||||
|
||||
if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, NULL))
|
||||
|
||||
@@ -73,8 +73,8 @@ FREERDP_LOCAL const char* rdp_client_connection_state_string(UINT state);
|
||||
|
||||
FREERDP_LOCAL BOOL rdp_channels_from_mcs(rdpSettings* settings, const rdpRdp* rdp);
|
||||
|
||||
FREERDP_LOCAL BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId,
|
||||
UINT16 length);
|
||||
FREERDP_LOCAL state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId,
|
||||
UINT16 length);
|
||||
FREERDP_LOCAL BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length,
|
||||
UINT16* pSecurityFlags);
|
||||
|
||||
|
||||
@@ -317,8 +317,13 @@ BOOL freerdp_abort_connect_context(rdpContext* context)
|
||||
/* Try to send a [MS-RDPBCGR] 1.3.1.4.1 User-Initiated on Client PDU, we don't care about
|
||||
* success */
|
||||
if (context->rdp && context->rdp->mcs)
|
||||
(void)mcs_send_disconnect_provider_ultimatum(context->rdp->mcs,
|
||||
Disconnect_Ultimatum_user_requested);
|
||||
{
|
||||
if (!context->ServerMode)
|
||||
{
|
||||
(void)mcs_send_disconnect_provider_ultimatum(context->rdp->mcs,
|
||||
Disconnect_Ultimatum_user_requested);
|
||||
}
|
||||
}
|
||||
return utils_abort_connect(context->rdp);
|
||||
}
|
||||
|
||||
|
||||
@@ -458,9 +458,7 @@ static state_run_t peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s)
|
||||
|
||||
if (rdp_get_state(rdp) <= CONNECTION_STATE_LICENSING)
|
||||
{
|
||||
if (!rdp_handle_message_channel(rdp, s, channelId, length))
|
||||
return STATE_RUN_FAILED;
|
||||
return STATE_RUN_SUCCESS;
|
||||
return rdp_handle_message_channel(rdp, s, channelId, length);
|
||||
}
|
||||
|
||||
if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, &securityFlags))
|
||||
|
||||
@@ -1634,9 +1634,7 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s)
|
||||
if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId))
|
||||
{
|
||||
rdp->inPackets++;
|
||||
if (!rdp_handle_message_channel(rdp, s, channelId, length))
|
||||
return STATE_RUN_FAILED;
|
||||
return STATE_RUN_SUCCESS;
|
||||
return rdp_handle_message_channel(rdp, s, channelId, length);
|
||||
}
|
||||
|
||||
if (rdp->settings->UseRdpSecurityLayer)
|
||||
|
||||
@@ -147,22 +147,11 @@ static void transport_ssl_cb(const SSL* ssl, int where, int ret)
|
||||
}
|
||||
}
|
||||
|
||||
wStream* transport_send_stream_init(rdpTransport* transport, size_t size)
|
||||
wStream* transport_send_stream_init(WINPR_ATTR_UNUSED rdpTransport* transport, size_t size)
|
||||
{
|
||||
WINPR_ASSERT(transport);
|
||||
|
||||
wStream* s = StreamPool_Take(transport->ReceivePool, size);
|
||||
if (!s)
|
||||
return NULL;
|
||||
|
||||
if (!Stream_EnsureCapacity(s, size))
|
||||
{
|
||||
Stream_Release(s);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
Stream_SetPosition(s, 0);
|
||||
return s;
|
||||
return Stream_New(NULL, size);
|
||||
}
|
||||
|
||||
BOOL transport_attach(rdpTransport* transport, int sockfd)
|
||||
@@ -1447,9 +1436,12 @@ int transport_check_fds(rdpTransport* transport)
|
||||
}
|
||||
|
||||
received = transport->ReceiveBuffer;
|
||||
|
||||
if (!(transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0)))
|
||||
transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0);
|
||||
if (!transport->ReceiveBuffer)
|
||||
{
|
||||
Stream_Release(received);
|
||||
return -1;
|
||||
}
|
||||
|
||||
/**
|
||||
* status:
|
||||
|
||||
@@ -1368,7 +1368,20 @@ extern "C"
|
||||
|
||||
WINPR_API void StreamPool_Return(wStreamPool* pool, wStream* s);
|
||||
|
||||
/** @brief increment reference count of stream
|
||||
*
|
||||
* @param s The stream to reference
|
||||
* @bug versions < 3.13.0 did only handle streams returned by StreamPool_Take
|
||||
*/
|
||||
WINPR_API void Stream_AddRef(wStream* s);
|
||||
|
||||
/** @brief Release a reference to a stream.
|
||||
* If the reference count reaches \b 0 it is returned to the StreamPool it was taken from or \b
|
||||
* Stream_Free is called.
|
||||
*
|
||||
* @param s The stream to release
|
||||
* @bug versions < 3.13.0 did only handle streams returned by StreamPool_Take
|
||||
*/
|
||||
WINPR_API void Stream_Release(wStream* s);
|
||||
|
||||
WINPR_ATTR_MALLOC(Stream_Release, 1)
|
||||
|
||||
@@ -285,10 +285,7 @@ static void StreamPool_Remove(wStreamPool* pool, wStream* s)
|
||||
static void StreamPool_ReleaseOrReturn(wStreamPool* pool, wStream* s)
|
||||
{
|
||||
StreamPool_Lock(pool);
|
||||
if (s->count > 0)
|
||||
s->count--;
|
||||
if (s->count == 0)
|
||||
StreamPool_Remove(pool, s);
|
||||
StreamPool_Remove(pool, s);
|
||||
StreamPool_Unlock(pool);
|
||||
}
|
||||
|
||||
@@ -310,12 +307,7 @@ void StreamPool_Return(wStreamPool* pool, wStream* s)
|
||||
void Stream_AddRef(wStream* s)
|
||||
{
|
||||
WINPR_ASSERT(s);
|
||||
if (s->pool)
|
||||
{
|
||||
StreamPool_Lock(s->pool);
|
||||
s->count++;
|
||||
StreamPool_Unlock(s->pool);
|
||||
}
|
||||
s->count++;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -325,8 +317,16 @@ void Stream_AddRef(wStream* s)
|
||||
void Stream_Release(wStream* s)
|
||||
{
|
||||
WINPR_ASSERT(s);
|
||||
if (s->pool)
|
||||
StreamPool_ReleaseOrReturn(s->pool, s);
|
||||
|
||||
if (s->count > 0)
|
||||
s->count--;
|
||||
if (s->count == 0)
|
||||
{
|
||||
if (s->pool)
|
||||
StreamPool_ReleaseOrReturn(s->pool, s);
|
||||
else
|
||||
Stream_Free(s, TRUE);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -98,7 +98,7 @@ wStream* Stream_New(BYTE* buffer, size_t size)
|
||||
if (!buffer && !size)
|
||||
return NULL;
|
||||
|
||||
s = malloc(sizeof(wStream));
|
||||
s = calloc(1, sizeof(wStream));
|
||||
if (!s)
|
||||
return NULL;
|
||||
|
||||
@@ -118,7 +118,7 @@ wStream* Stream_New(BYTE* buffer, size_t size)
|
||||
s->length = size;
|
||||
|
||||
s->pool = NULL;
|
||||
s->count = 0;
|
||||
s->count = 1;
|
||||
s->isAllocatedStream = TRUE;
|
||||
s->isOwner = TRUE;
|
||||
return s;
|
||||
@@ -147,7 +147,7 @@ wStream* Stream_StaticInit(wStream* s, BYTE* buffer, size_t size)
|
||||
s->buffer = s->pointer = buffer;
|
||||
s->capacity = s->length = size;
|
||||
s->pool = NULL;
|
||||
s->count = 0;
|
||||
s->count = 1;
|
||||
s->isAllocatedStream = FALSE;
|
||||
s->isOwner = FALSE;
|
||||
return s;
|
||||
|
||||
Reference in New Issue
Block a user