diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index 0cab2dc02..9674226ae 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -1139,7 +1139,7 @@ BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s) return TRUE; } -BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length) +state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length) { WINPR_ASSERT(rdp); WINPR_ASSERT(rdp->mcs); @@ -1147,36 +1147,40 @@ BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT1 if (!rdp->mcs->messageChannelJoined) { WLog_Print(rdp->log, WLOG_WARN, "MCS message channel not joined!"); - return FALSE; + return STATE_RUN_FAILED; } const UINT16 messageChannelId = rdp->mcs->messageChannelId; if (messageChannelId == 0) { WLog_Print(rdp->log, WLOG_WARN, "MCS message channel id == 0"); - return FALSE; + return STATE_RUN_FAILED; } - if (channelId != messageChannelId) + if ((channelId != messageChannelId) && (channelId != MCS_GLOBAL_CHANNEL_ID)) { - WLog_Print(rdp->log, WLOG_WARN, "MCS message channel expected id=%" PRIu16 ", got %" PRIu16, - messageChannelId, channelId); - return FALSE; + WLog_Print(rdp->log, WLOG_WARN, + "MCS message channel expected id=[%" PRIu16 "|%d], got %" PRIu16, + messageChannelId, MCS_GLOBAL_CHANNEL_ID, channelId); + return STATE_RUN_FAILED; } UINT16 securityFlags = 0; if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) - return FALSE; + return STATE_RUN_FAILED; if (securityFlags & SEC_ENCRYPT) { if (!rdp_decrypt(rdp, s, &length, securityFlags)) - return FALSE; + return STATE_RUN_FAILED; } - if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) != STATE_RUN_SUCCESS) - return FALSE; - - return tpkt_ensure_stream_consumed(s, length); + const state_run_t rc = rdp_recv_message_channel_pdu(rdp, s, securityFlags); + if (state_run_success(rc)) + { + if (!tpkt_ensure_stream_consumed(s, length)) + return STATE_RUN_FAILED; + } + return rc; } BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s) @@ -1196,8 +1200,8 @@ BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s) /* Process any MCS message channel PDUs. */ if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId)) { - if (rdp_handle_message_channel(rdp, s, channelId, length)) - return TRUE; + const state_run_t rc = rdp_handle_message_channel(rdp, s, channelId, length); + return state_run_success(rc); } else { @@ -1228,9 +1232,7 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s) const UINT16 messageChannelId = rdp->mcs->messageChannelId; if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId)) { - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) @@ -1306,9 +1308,7 @@ state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s) if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) { rdp->inPackets++; - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, NULL)) diff --git a/libfreerdp/core/connection.h b/libfreerdp/core/connection.h index a8276260b..7e6e6a57e 100644 --- a/libfreerdp/core/connection.h +++ b/libfreerdp/core/connection.h @@ -73,8 +73,8 @@ FREERDP_LOCAL const char* rdp_client_connection_state_string(UINT state); FREERDP_LOCAL BOOL rdp_channels_from_mcs(rdpSettings* settings, const rdpRdp* rdp); -FREERDP_LOCAL BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, - UINT16 length); +FREERDP_LOCAL state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, + UINT16 length); FREERDP_LOCAL BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* pSecurityFlags); diff --git a/libfreerdp/core/freerdp.c b/libfreerdp/core/freerdp.c index 30536a757..58852cf4f 100644 --- a/libfreerdp/core/freerdp.c +++ b/libfreerdp/core/freerdp.c @@ -317,8 +317,13 @@ BOOL freerdp_abort_connect_context(rdpContext* context) /* Try to send a [MS-RDPBCGR] 1.3.1.4.1 User-Initiated on Client PDU, we don't care about * success */ if (context->rdp && context->rdp->mcs) - (void)mcs_send_disconnect_provider_ultimatum(context->rdp->mcs, - Disconnect_Ultimatum_user_requested); + { + if (!context->ServerMode) + { + (void)mcs_send_disconnect_provider_ultimatum(context->rdp->mcs, + Disconnect_Ultimatum_user_requested); + } + } return utils_abort_connect(context->rdp); } diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index e8c681daf..36a84fe0c 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -458,9 +458,7 @@ static state_run_t peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s) if (rdp_get_state(rdp) <= CONNECTION_STATE_LICENSING) { - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, &securityFlags)) diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index 0706bca1a..c51d53043 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -1634,9 +1634,7 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s) if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) { rdp->inPackets++; - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (rdp->settings->UseRdpSecurityLayer) diff --git a/libfreerdp/core/transport.c b/libfreerdp/core/transport.c index 07b669d23..e0407b8f8 100644 --- a/libfreerdp/core/transport.c +++ b/libfreerdp/core/transport.c @@ -147,22 +147,11 @@ static void transport_ssl_cb(const SSL* ssl, int where, int ret) } } -wStream* transport_send_stream_init(rdpTransport* transport, size_t size) +wStream* transport_send_stream_init(WINPR_ATTR_UNUSED rdpTransport* transport, size_t size) { WINPR_ASSERT(transport); - wStream* s = StreamPool_Take(transport->ReceivePool, size); - if (!s) - return NULL; - - if (!Stream_EnsureCapacity(s, size)) - { - Stream_Release(s); - return NULL; - } - - Stream_SetPosition(s, 0); - return s; + return Stream_New(NULL, size); } BOOL transport_attach(rdpTransport* transport, int sockfd) @@ -1447,9 +1436,12 @@ int transport_check_fds(rdpTransport* transport) } received = transport->ReceiveBuffer; - - if (!(transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0))) + transport->ReceiveBuffer = StreamPool_Take(transport->ReceivePool, 0); + if (!transport->ReceiveBuffer) + { + Stream_Release(received); return -1; + } /** * status: diff --git a/winpr/include/winpr/stream.h b/winpr/include/winpr/stream.h index 4ae7c5b78..8ef97b2ee 100644 --- a/winpr/include/winpr/stream.h +++ b/winpr/include/winpr/stream.h @@ -1368,7 +1368,20 @@ extern "C" WINPR_API void StreamPool_Return(wStreamPool* pool, wStream* s); + /** @brief increment reference count of stream + * + * @param s The stream to reference + * @bug versions < 3.13.0 did only handle streams returned by StreamPool_Take + */ WINPR_API void Stream_AddRef(wStream* s); + + /** @brief Release a reference to a stream. + * If the reference count reaches \b 0 it is returned to the StreamPool it was taken from or \b + * Stream_Free is called. + * + * @param s The stream to release + * @bug versions < 3.13.0 did only handle streams returned by StreamPool_Take + */ WINPR_API void Stream_Release(wStream* s); WINPR_ATTR_MALLOC(Stream_Release, 1) diff --git a/winpr/libwinpr/utils/collections/StreamPool.c b/winpr/libwinpr/utils/collections/StreamPool.c index b550c9bad..8870f4d2f 100644 --- a/winpr/libwinpr/utils/collections/StreamPool.c +++ b/winpr/libwinpr/utils/collections/StreamPool.c @@ -285,10 +285,7 @@ static void StreamPool_Remove(wStreamPool* pool, wStream* s) static void StreamPool_ReleaseOrReturn(wStreamPool* pool, wStream* s) { StreamPool_Lock(pool); - if (s->count > 0) - s->count--; - if (s->count == 0) - StreamPool_Remove(pool, s); + StreamPool_Remove(pool, s); StreamPool_Unlock(pool); } @@ -310,12 +307,7 @@ void StreamPool_Return(wStreamPool* pool, wStream* s) void Stream_AddRef(wStream* s) { WINPR_ASSERT(s); - if (s->pool) - { - StreamPool_Lock(s->pool); - s->count++; - StreamPool_Unlock(s->pool); - } + s->count++; } /** @@ -325,8 +317,16 @@ void Stream_AddRef(wStream* s) void Stream_Release(wStream* s) { WINPR_ASSERT(s); - if (s->pool) - StreamPool_ReleaseOrReturn(s->pool, s); + + if (s->count > 0) + s->count--; + if (s->count == 0) + { + if (s->pool) + StreamPool_ReleaseOrReturn(s->pool, s); + else + Stream_Free(s, TRUE); + } } /** diff --git a/winpr/libwinpr/utils/stream.c b/winpr/libwinpr/utils/stream.c index 145b54f5e..d1ed7b718 100644 --- a/winpr/libwinpr/utils/stream.c +++ b/winpr/libwinpr/utils/stream.c @@ -98,7 +98,7 @@ wStream* Stream_New(BYTE* buffer, size_t size) if (!buffer && !size) return NULL; - s = malloc(sizeof(wStream)); + s = calloc(1, sizeof(wStream)); if (!s) return NULL; @@ -118,7 +118,7 @@ wStream* Stream_New(BYTE* buffer, size_t size) s->length = size; s->pool = NULL; - s->count = 0; + s->count = 1; s->isAllocatedStream = TRUE; s->isOwner = TRUE; return s; @@ -147,7 +147,7 @@ wStream* Stream_StaticInit(wStream* s, BYTE* buffer, size_t size) s->buffer = s->pointer = buffer; s->capacity = s->length = size; s->pool = NULL; - s->count = 0; + s->count = 1; s->isAllocatedStream = FALSE; s->isOwner = FALSE; return s;