diff --git a/libfreerdp/core/connection.c b/libfreerdp/core/connection.c index 0cab2dc02..9674226ae 100644 --- a/libfreerdp/core/connection.c +++ b/libfreerdp/core/connection.c @@ -1139,7 +1139,7 @@ BOOL rdp_client_connect_mcs_channel_join_confirm(rdpRdp* rdp, wStream* s) return TRUE; } -BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length) +state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT16 length) { WINPR_ASSERT(rdp); WINPR_ASSERT(rdp->mcs); @@ -1147,36 +1147,40 @@ BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, UINT1 if (!rdp->mcs->messageChannelJoined) { WLog_Print(rdp->log, WLOG_WARN, "MCS message channel not joined!"); - return FALSE; + return STATE_RUN_FAILED; } const UINT16 messageChannelId = rdp->mcs->messageChannelId; if (messageChannelId == 0) { WLog_Print(rdp->log, WLOG_WARN, "MCS message channel id == 0"); - return FALSE; + return STATE_RUN_FAILED; } - if (channelId != messageChannelId) + if ((channelId != messageChannelId) && (channelId != MCS_GLOBAL_CHANNEL_ID)) { - WLog_Print(rdp->log, WLOG_WARN, "MCS message channel expected id=%" PRIu16 ", got %" PRIu16, - messageChannelId, channelId); - return FALSE; + WLog_Print(rdp->log, WLOG_WARN, + "MCS message channel expected id=[%" PRIu16 "|%d], got %" PRIu16, + messageChannelId, MCS_GLOBAL_CHANNEL_ID, channelId); + return STATE_RUN_FAILED; } UINT16 securityFlags = 0; if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) - return FALSE; + return STATE_RUN_FAILED; if (securityFlags & SEC_ENCRYPT) { if (!rdp_decrypt(rdp, s, &length, securityFlags)) - return FALSE; + return STATE_RUN_FAILED; } - if (rdp_recv_message_channel_pdu(rdp, s, securityFlags) != STATE_RUN_SUCCESS) - return FALSE; - - return tpkt_ensure_stream_consumed(s, length); + const state_run_t rc = rdp_recv_message_channel_pdu(rdp, s, securityFlags); + if (state_run_success(rc)) + { + if (!tpkt_ensure_stream_consumed(s, length)) + return STATE_RUN_FAILED; + } + return rc; } BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s) @@ -1196,8 +1200,8 @@ BOOL rdp_client_connect_auto_detect(rdpRdp* rdp, wStream* s) /* Process any MCS message channel PDUs. */ if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId)) { - if (rdp_handle_message_channel(rdp, s, channelId, length)) - return TRUE; + const state_run_t rc = rdp_handle_message_channel(rdp, s, channelId, length); + return state_run_success(rc); } else { @@ -1228,9 +1232,7 @@ state_run_t rdp_client_connect_license(rdpRdp* rdp, wStream* s) const UINT16 messageChannelId = rdp->mcs->messageChannelId; if (rdp->mcs->messageChannelJoined && (channelId == messageChannelId)) { - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (!rdp_read_security_header(rdp, s, &securityFlags, &length)) @@ -1306,9 +1308,7 @@ state_run_t rdp_client_connect_demand_active(rdpRdp* rdp, wStream* s) if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) { rdp->inPackets++; - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, NULL)) diff --git a/libfreerdp/core/connection.h b/libfreerdp/core/connection.h index a8276260b..7e6e6a57e 100644 --- a/libfreerdp/core/connection.h +++ b/libfreerdp/core/connection.h @@ -73,8 +73,8 @@ FREERDP_LOCAL const char* rdp_client_connection_state_string(UINT state); FREERDP_LOCAL BOOL rdp_channels_from_mcs(rdpSettings* settings, const rdpRdp* rdp); -FREERDP_LOCAL BOOL rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, - UINT16 length); +FREERDP_LOCAL state_run_t rdp_handle_message_channel(rdpRdp* rdp, wStream* s, UINT16 channelId, + UINT16 length); FREERDP_LOCAL BOOL rdp_handle_optional_rdp_decryption(rdpRdp* rdp, wStream* s, UINT16* length, UINT16* pSecurityFlags); diff --git a/libfreerdp/core/peer.c b/libfreerdp/core/peer.c index e8c681daf..36a84fe0c 100644 --- a/libfreerdp/core/peer.c +++ b/libfreerdp/core/peer.c @@ -458,9 +458,7 @@ static state_run_t peer_recv_tpkt_pdu(freerdp_peer* client, wStream* s) if (rdp_get_state(rdp) <= CONNECTION_STATE_LICENSING) { - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (!rdp_handle_optional_rdp_decryption(rdp, s, &length, &securityFlags)) diff --git a/libfreerdp/core/rdp.c b/libfreerdp/core/rdp.c index cd792cc64..b27a4cbfd 100644 --- a/libfreerdp/core/rdp.c +++ b/libfreerdp/core/rdp.c @@ -1635,9 +1635,7 @@ static state_run_t rdp_recv_tpkt_pdu(rdpRdp* rdp, wStream* s) if (rdp->mcs->messageChannelId && (channelId == rdp->mcs->messageChannelId)) { rdp->inPackets++; - if (!rdp_handle_message_channel(rdp, s, channelId, length)) - return STATE_RUN_FAILED; - return STATE_RUN_SUCCESS; + return rdp_handle_message_channel(rdp, s, channelId, length); } if (rdp->settings->UseRdpSecurityLayer)