From 103f13575c30ed953baf6d1a3d7d1954f7b49e73 Mon Sep 17 00:00:00 2001 From: David Fort Date: Thu, 25 Aug 2022 08:47:00 +0200 Subject: [PATCH] drdynvc: code cleanup of the client dynamic channel (#8148) This patch does various cleanups in the client dynamic channel. The main goal of the cleanup was to add the sending of Close messages to the server when a channel is locally closed. The refcounter is there to ensures that the DVC_CHANNEL is not freed while some pieces of code are still holding a reference on it. I did some tests by using a custom server-side echo channel at https://github.com/hardening/echoChannel, it allows to send a given amount of packets and then close (to test server-side initiated closes). It compiles with mingw (so under linux) and so it can be easily deployed (no deps). --- channels/drdynvc/client/drdynvc_main.c | 450 ++++++++++++++----------- channels/drdynvc/client/drdynvc_main.h | 10 +- 2 files changed, 257 insertions(+), 203 deletions(-) diff --git a/channels/drdynvc/client/drdynvc_main.c b/channels/drdynvc/client/drdynvc_main.c index 2ee6255cb..2bfe5a14b 100644 --- a/channels/drdynvc/client/drdynvc_main.c +++ b/channels/drdynvc/client/drdynvc_main.c @@ -23,6 +23,7 @@ #include #include +#include #include @@ -30,10 +31,7 @@ #define TAG CHANNELS_TAG("drdynvc.client") -static UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 ChannelId, - BOOL bSendClosePDU); static void dvcman_free(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr); -static void dvcman_channel_free(void* channel); static UINT drdynvc_write_data(drdynvcPlugin* drdynvc, UINT32 ChannelId, const BYTE* data, UINT32 dataSize, BOOL* close); static UINT drdynvc_send(drdynvcPlugin* drdynvc, wStream* s); @@ -199,20 +197,32 @@ static const char* dvcman_get_channel_name(IWTSVirtualChannel* channel) return dvc->channel_name; } -static IWTSVirtualChannel* dvcman_find_channel_by_id(IWTSVirtualChannelManager* pChannelMgr, - UINT32 ChannelId) +static DVCMAN_CHANNEL* dvcman_get_channel_by_id(IWTSVirtualChannelManager* pChannelMgr, + UINT32 ChannelId, BOOL doRef) { - IWTSVirtualChannel* channel = NULL; DVCMAN* dvcman = (DVCMAN*)pChannelMgr; DVCMAN_CHANNEL* dvcChannel; HashTable_Lock(dvcman->channelsById); dvcChannel = HashTable_GetItemValue(dvcman->channelsById, &ChannelId); if (dvcChannel) - channel = &dvcChannel->iface; + { + if (doRef) + InterlockedIncrement(&dvcChannel->refCounter); + } HashTable_Unlock(dvcman->channelsById); - return channel; + return dvcChannel; +} + +static IWTSVirtualChannel* dvcman_find_channel_by_id(IWTSVirtualChannelManager* pChannelMgr, + UINT32 ChannelId) +{ + DVCMAN_CHANNEL* channel = dvcman_get_channel_by_id(pChannelMgr, ChannelId, FALSE); + if (!channel) + return NULL; + + return &channel->iface; } static void dvcman_plugin_terminate(void* plugin) @@ -264,9 +274,6 @@ static IWTSVirtualChannelManager* dvcman_new(drdynvcPlugin* plugin) obj = HashTable_KeyObject(dvcman->channelsById); obj->fnObjectEquals = channelIdMatch; - obj = HashTable_ValueObject(dvcman->channelsById); - obj->fnObjectFree = dvcman_channel_free; - dvcman->pool = StreamPool_New(TRUE, 10); if (!dvcman->pool) goto fail; @@ -336,28 +343,116 @@ static UINT dvcman_load_addin(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* return ERROR_INVALID_FUNCTION; } +static void dvcman_channel_free(DVCMAN_CHANNEL* channel) +{ + WINPR_ASSERT(channel); + + if (channel->dvc_data) + Stream_Release(channel->dvc_data); + + DeleteCriticalSection(&(channel->lock)); + free(channel->channel_name); + free(channel); +} + +static void dvcman_channel_unref(DVCMAN_CHANNEL* channel) +{ + DVCMAN* dvcman; + + if (InterlockedDecrement(&channel->refCounter)) + return; + + dvcman = channel->dvcman; + HashTable_Remove(dvcman->channelsById, &channel->channel_id); + + dvcman_channel_free(channel); +} + +static UINT dvcchannel_send_close(DVCMAN_CHANNEL* channel) +{ + DVCMAN* dvcman = channel->dvcman; + drdynvcPlugin* drdynvc = dvcman->drdynvc; + wStream* s = StreamPool_Take(dvcman->pool, 5); + + if (!s) + { + WLog_Print(drdynvc->log, WLOG_ERROR, "StreamPool_Take failed!"); + return CHANNEL_RC_NO_MEMORY; + } + + Stream_Write_UINT8(s, (CLOSE_REQUEST_PDU << 4) | 0x02); + Stream_Write_UINT32(s, channel->channel_id); + return drdynvc_send(drdynvc, s); +} + +static UINT dvcman_channel_close(DVCMAN_CHANNEL* channel, bool perRequest) +{ + UINT error = CHANNEL_RC_OK; + drdynvcPlugin* drdynvc; + DrdynvcClientContext* context; + + switch (channel->state) + { + case DVC_CHANNEL_INIT: + break; + case DVC_CHANNEL_RUNNING: + drdynvc = channel->dvcman->drdynvc; + context = drdynvc->context; + if (perRequest) + WLog_Print(drdynvc->log, WLOG_DEBUG, "sending close confirm for '%s'", + channel->channel_name); + + error = dvcchannel_send_close(channel); + if (error != CHANNEL_RC_OK) + { + const char* msg = "error when sending close confirm for '%s'"; + if (perRequest) + msg = "error when sending closeRequest for '%s'"; + + WLog_Print(drdynvc->log, WLOG_DEBUG, msg, channel->channel_name); + } + + channel->state = DVC_CHANNEL_CLOSED; + + if (channel->channel_callback) + { + IFCALL(channel->channel_callback->OnClose, channel->channel_callback); + channel->channel_callback = NULL; + } + + if (channel->dvcman && channel->dvcman->drdynvc) + { + if (context) + { + IFCALLRET(context->OnChannelDisconnected, error, context, channel->channel_name, + channel->pInterface); + } + } + + dvcman_channel_unref(channel); + break; + case DVC_CHANNEL_CLOSED: + break; + } + + return error; +} + static DVCMAN_CHANNEL* dvcman_channel_new(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr, UINT32 ChannelId, const char* ChannelName) { DVCMAN_CHANNEL* channel; - if (dvcman_find_channel_by_id(pChannelMgr, ChannelId)) - { - WLog_Print(drdynvc->log, WLOG_ERROR, - "Protocol error: Duplicated ChannelId %" PRIu32 " (%s)!", ChannelId, - ChannelName); - return NULL; - } - channel = (DVCMAN_CHANNEL*)calloc(1, sizeof(DVCMAN_CHANNEL)); if (!channel) - goto fail; + return NULL; channel->dvcman = (DVCMAN*)pChannelMgr; channel->channel_id = ChannelId; - channel->channel_name = _strdup(ChannelName); + channel->refCounter = 1; + channel->state = DVC_CHANNEL_INIT, channel->channel_name = _strdup(ChannelName); if (!channel->channel_name) goto fail; @@ -371,50 +466,6 @@ fail: return NULL; } -static void dvcman_channel_free(void* arg) -{ - DVCMAN_CHANNEL* channel = (DVCMAN_CHANNEL*)arg; - UINT error = CHANNEL_RC_OK; - - if (channel) - { - if (channel->channel_callback) - { - IFCALL(channel->channel_callback->OnClose, channel->channel_callback); - channel->channel_callback = NULL; - } - - if (channel->status == CHANNEL_RC_OK) - { - IWTSVirtualChannel* ichannel = (IWTSVirtualChannel*)channel; - - if (channel->dvcman && channel->dvcman->drdynvc) - { - DrdynvcClientContext* context = channel->dvcman->drdynvc->context; - - if (context) - { - IFCALLRET(context->OnChannelDisconnected, error, context, channel->channel_name, - channel->pInterface); - } - } - - error = IFCALLRESULT(CHANNEL_RC_OK, ichannel->Close, ichannel); - - if (error != CHANNEL_RC_OK) - WLog_ERR(TAG, "Close failed with error %" PRIu32 "!", error); - } - - if (channel->dvc_data) - Stream_Release(channel->dvc_data); - - DeleteCriticalSection(&(channel->lock)); - free(channel->channel_name); - } - - free(channel); -} - static void dvcman_clear(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr) { DVCMAN* dvcman = (DVCMAN*)pChannelMgr; @@ -493,7 +544,8 @@ static UINT dvcman_write_channel(IWTSVirtualChannel* pChannel, ULONG cbSize, con LeaveCriticalSection(&(channel->lock)); /* Close delayed, it removes the channel struct */ if (close) - dvcman_close_channel(channel->dvcman->drdynvc->channel_mgr, channel->channel_id, TRUE); + dvcman_channel_close(channel, FALSE); + return status; } @@ -510,7 +562,7 @@ static UINT dvcman_close_channel_iface(IWTSVirtualChannel* pChannel) return CHANNEL_RC_BAD_CHANNEL; WLog_DBG(TAG, "close_channel_iface: id=%" PRIu32 "", channel->channel_id); - return CHANNEL_RC_OK; + return dvcman_channel_close(channel, FALSE); } /** @@ -518,37 +570,64 @@ static UINT dvcman_close_channel_iface(IWTSVirtualChannel* pChannel) * * @return 0 on success, otherwise a Win32 error code */ -static UINT dvcman_create_channel(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr, - UINT32 ChannelId, const char* ChannelName) +static DVCMAN_CHANNEL* dvcman_create_channel(drdynvcPlugin* drdynvc, + IWTSVirtualChannelManager* pChannelMgr, + UINT32 ChannelId, const char* ChannelName, UINT* res) { BOOL bAccept; - DVCMAN_CHANNEL* channel; + DVCMAN_CHANNEL* channel = NULL; DrdynvcClientContext* context; DVCMAN* dvcman = (DVCMAN*)pChannelMgr; DVCMAN_LISTENER* listener; IWTSVirtualChannelCallback* pCallback = NULL; - UINT error; + + WINPR_ASSERT(res); HashTable_Lock(dvcman->listeners); listener = (DVCMAN_LISTENER*)HashTable_GetItemValue(dvcman->listeners, ChannelName); if (!listener) { - error = ERROR_NOT_FOUND; + *res = ERROR_NOT_FOUND; goto out; } - if (!(channel = dvcman_channel_new(drdynvc, pChannelMgr, ChannelId, ChannelName))) + channel = dvcman_get_channel_by_id(pChannelMgr, ChannelId, FALSE); + if (channel) { - WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_channel_new failed!"); - error = CHANNEL_RC_NO_MEMORY; - goto out; + switch (channel->state) + { + case DVC_CHANNEL_RUNNING: + WLog_Print(drdynvc->log, WLOG_ERROR, + "Protocol error: Duplicated ChannelId %" PRIu32 " (%s)!", ChannelId, + ChannelName); + *res = CHANNEL_RC_ALREADY_OPEN; + goto out; + + case DVC_CHANNEL_CLOSED: + case DVC_CHANNEL_INIT: + default: + WLog_Print(drdynvc->log, WLOG_ERROR, "not expecting a createChannel from state %d", + channel->state); + *res = CHANNEL_RC_INITIALIZATION_ERROR; + goto out; + } + } + else + { + if (!(channel = dvcman_channel_new(drdynvc, pChannelMgr, ChannelId, ChannelName))) + { + WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_channel_new failed!"); + *res = CHANNEL_RC_NO_MEMORY; + goto out; + } } - channel->status = ERROR_NOT_CONNECTED; if (!HashTable_Insert(dvcman->channelsById, &channel->channel_id, channel)) { WLog_Print(drdynvc->log, WLOG_ERROR, "unable to register channel in our channel list"); - error = ERROR_INTERNAL_ERROR; + *res = ERROR_INTERNAL_ERROR; + dvcman_channel_free(channel); + channel = NULL; goto out; } @@ -556,42 +635,45 @@ static UINT dvcman_create_channel(drdynvcPlugin* drdynvc, IWTSVirtualChannelMana channel->iface.Close = dvcman_close_channel_iface; bAccept = TRUE; - error = listener->listener_callback->OnNewChannelConnection( + *res = listener->listener_callback->OnNewChannelConnection( listener->listener_callback, &channel->iface, NULL, &bAccept, &pCallback); - if (error != CHANNEL_RC_OK) + if (*res != CHANNEL_RC_OK) { WLog_Print(drdynvc->log, WLOG_ERROR, - "OnNewChannelConnection failed with error %" PRIu32 "!", error); - error = ERROR_INTERNAL_ERROR; + "OnNewChannelConnection failed with error %" PRIu32 "!", *res); + *res = ERROR_INTERNAL_ERROR; + dvcman_channel_unref(channel); goto out; } if (!bAccept) { WLog_Print(drdynvc->log, WLOG_ERROR, "OnNewChannelConnection returned with bAccept FALSE!"); - error = ERROR_INTERNAL_ERROR; + *res = ERROR_INTERNAL_ERROR; + dvcman_channel_unref(channel); + channel = NULL; goto out; } WLog_Print(drdynvc->log, WLOG_DEBUG, "listener %s created new channel %" PRIu32 "", listener->channel_name, channel->channel_id); - channel->status = CHANNEL_RC_OK; + channel->state = DVC_CHANNEL_RUNNING; channel->channel_callback = pCallback; channel->pInterface = listener->iface.pInterface; context = dvcman->drdynvc->context; - IFCALLRET(context->OnChannelConnected, error, context, ChannelName, listener->iface.pInterface); - if (error != CHANNEL_RC_OK) + IFCALLRET(context->OnChannelConnected, *res, context, ChannelName, listener->iface.pInterface); + if (*res != CHANNEL_RC_OK) { WLog_Print(drdynvc->log, WLOG_ERROR, - "context.OnChannelConnected failed with error %" PRIu32 "", error); + "context.OnChannelConnected failed with error %" PRIu32 "", *res); } out: HashTable_Unlock(dvcman->listeners); - return error; + return channel; } /** @@ -599,21 +681,12 @@ out: * * @return 0 on success, otherwise a Win32 error code */ -static UINT dvcman_open_channel(drdynvcPlugin* drdynvc, IWTSVirtualChannelManager* pChannelMgr, - UINT32 ChannelId) +static UINT dvcman_open_channel(drdynvcPlugin* drdynvc, DVCMAN_CHANNEL* channel) { - DVCMAN_CHANNEL* channel; IWTSVirtualChannelCallback* pCallback; - UINT error; - channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId); + UINT error = CHANNEL_RC_OK; - if (!channel) - { - WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %" PRIu32 " not found!", ChannelId); - return ERROR_INTERNAL_ERROR; - } - - if (channel->status == CHANNEL_RC_OK) + if (channel->state == DVC_CHANNEL_RUNNING) { pCallback = channel->channel_callback; @@ -624,57 +697,15 @@ static UINT dvcman_open_channel(drdynvcPlugin* drdynvc, IWTSVirtualChannelManage { WLog_Print(drdynvc->log, WLOG_ERROR, "OnOpen failed with error %" PRIu32 "!", error); - return error; + goto out; } } - WLog_Print(drdynvc->log, WLOG_DEBUG, "open_channel: ChannelId %" PRIu32 "", ChannelId); + WLog_Print(drdynvc->log, WLOG_DEBUG, "open_channel: ChannelId %" PRIu32 "", + channel->channel_id); } - return CHANNEL_RC_OK; -} - -/** - * Function description - * - * @return 0 on success, otherwise a Win32 error code - */ -UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 ChannelId, - BOOL bSendClosePDU) -{ - DVCMAN_CHANNEL* channel; - UINT error = CHANNEL_RC_OK; - DVCMAN* dvcman = (DVCMAN*)pChannelMgr; - drdynvcPlugin* drdynvc = dvcman->drdynvc; - - channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId); - if (!channel) - { - // WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %"PRIu32" not found!", ChannelId); - /** - * Windows 8 / Windows Server 2012 send close requests for channels that failed to be - * created. Do not warn, simply return success here. - */ - return CHANNEL_RC_OK; - } - - if (drdynvc && bSendClosePDU) - { - wStream* s = StreamPool_Take(dvcman->pool, 5); - if (!s) - { - WLog_Print(drdynvc->log, WLOG_ERROR, "StreamPool_Take failed!"); - error = CHANNEL_RC_NO_MEMORY; - } - else - { - Stream_Write_UINT8(s, (CLOSE_REQUEST_PDU << 4) | 0x02); - Stream_Write_UINT32(s, ChannelId); - error = drdynvc_send(drdynvc, s); - } - } - - HashTable_Remove(dvcman->channelsById, &ChannelId); +out: return error; } @@ -683,24 +714,8 @@ UINT dvcman_close_channel(IWTSVirtualChannelManager* pChannelMgr, UINT32 Channel * * @return 0 on success, otherwise a Win32 error code */ -static UINT dvcman_receive_channel_data_first(drdynvcPlugin* drdynvc, - IWTSVirtualChannelManager* pChannelMgr, - UINT32 ChannelId, UINT32 length) +static UINT dvcman_receive_channel_data_first(DVCMAN_CHANNEL* channel, UINT32 length) { - DVCMAN_CHANNEL* channel; - channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId); - - if (!channel) - { - /** - * Windows Server 2012 R2 can send some messages over - * Microsoft::Windows::RDS::Geometry::v08.01 even if the dynamic virtual channel wasn't - * registered on our side. Ignoring it works. - */ - WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %" PRIu32 " not found!", ChannelId); - return CHANNEL_RC_OK; - } - if (channel->dvc_data) Stream_Release(channel->dvc_data); @@ -708,6 +723,7 @@ static UINT dvcman_receive_channel_data_first(drdynvcPlugin* drdynvc, if (!channel->dvc_data) { + drdynvcPlugin* drdynvc = channel->dvcman->drdynvc; WLog_Print(drdynvc->log, WLOG_ERROR, "StreamPool_Take failed!"); return CHANNEL_RC_NO_MEMORY; } @@ -721,32 +737,24 @@ static UINT dvcman_receive_channel_data_first(drdynvcPlugin* drdynvc, * * @return 0 on success, otherwise a Win32 error code */ -static UINT dvcman_receive_channel_data(drdynvcPlugin* drdynvc, - IWTSVirtualChannelManager* pChannelMgr, UINT32 ChannelId, - wStream* data, UINT32 ThreadingFlags) +static UINT dvcman_receive_channel_data(DVCMAN_CHANNEL* channel, wStream* data, + UINT32 ThreadingFlags) { UINT status = CHANNEL_RC_OK; - DVCMAN_CHANNEL* channel; size_t dataSize = Stream_GetRemainingLength(data); - channel = (DVCMAN_CHANNEL*)dvcman_find_channel_by_id(pChannelMgr, ChannelId); - - if (!channel) - { - /* Windows 8.1 tries to open channels not created. - * Ignore cases like this. */ - WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %" PRIu32 " not found!", ChannelId); - return CHANNEL_RC_OK; - } if (channel->dvc_data) { + drdynvcPlugin* drdynvc = channel->dvcman->drdynvc; + /* Fragmented data */ if (Stream_GetPosition(channel->dvc_data) + dataSize > channel->dvc_data_length) { WLog_Print(drdynvc->log, WLOG_ERROR, "data exceeding declared length!"); Stream_Release(channel->dvc_data); channel->dvc_data = NULL; - return ERROR_INVALID_DATA; + status = ERROR_INVALID_DATA; + goto out; } Stream_Copy(data, channel->dvc_data, dataSize); @@ -766,6 +774,7 @@ static UINT dvcman_receive_channel_data(drdynvcPlugin* drdynvc, status = channel->channel_callback->OnDataReceived(channel->channel_callback, data); } +out: return status; } @@ -871,6 +880,7 @@ static UINT drdynvc_write_data(drdynvcPlugin* drdynvc, UINT32 ChannelId, const B if (dataSize == 0) { + /* TODO: shall treat that case with write(0) that do a close */ *close = TRUE; Stream_Release(data_out); } @@ -1062,6 +1072,7 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c char* name; size_t length; DVCMAN* dvcman; + DVCMAN_CHANNEL* channel; UINT32 retStatus; WINPR_UNUSED(Sp); @@ -1100,7 +1111,6 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c WLog_Print(drdynvc->log, WLOG_DEBUG, "process_create_request: ChannelId=%" PRIu32 " ChannelName=%s", ChannelId, name); - channel_status = dvcman_create_channel(drdynvc, drdynvc->channel_mgr, ChannelId, name); data_out = StreamPool_Take(dvcman->pool, pos + 4); if (!data_out) @@ -1113,6 +1123,8 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c Stream_SetPosition(s, 1); Stream_Copy(s, data_out, pos - 1); + channel = + dvcman_create_channel(drdynvc, drdynvc->channel_mgr, ChannelId, name, &channel_status); switch (channel_status) { case CHANNEL_RC_OK: @@ -1135,29 +1147,23 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c Stream_Write_UINT32(data_out, retStatus); status = drdynvc_send(drdynvc, data_out); - if (status != CHANNEL_RC_OK) { WLog_Print(drdynvc->log, WLOG_ERROR, "VirtualChannelWriteEx failed with %s [%08" PRIX32 "]", WTSErrorToString(status), status); + dvcman_channel_unref(channel); return status; } if (channel_status == CHANNEL_RC_OK) { - if ((status = dvcman_open_channel(drdynvc, drdynvc->channel_mgr, ChannelId))) + if ((status = dvcman_open_channel(drdynvc, channel))) { WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_open_channel failed with error %" PRIu32 "!", status); return status; } } - else - { - if ((status = dvcman_close_channel(drdynvc->channel_mgr, ChannelId, FALSE))) - WLog_Print(drdynvc->log, WLOG_ERROR, - "dvcman_close_channel failed with error %" PRIu32 "!", status); - } return status; } @@ -1170,9 +1176,10 @@ static UINT drdynvc_process_create_request(drdynvcPlugin* drdynvc, int Sp, int c static UINT drdynvc_process_data_first(drdynvcPlugin* drdynvc, int Sp, int cbChId, wStream* s, UINT32 ThreadingFlags) { - UINT status; + UINT status = CHANNEL_RC_OK; UINT32 Length; UINT32 ChannelId; + DVCMAN_CHANNEL* channel; if (!Stream_CheckAndLogRequiredLength( TAG, s, drdynvc_cblen_to_bytes(cbChId) + drdynvc_cblen_to_bytes(Sp))) @@ -1183,15 +1190,32 @@ static UINT drdynvc_process_data_first(drdynvcPlugin* drdynvc, int Sp, int cbChI WLog_Print(drdynvc->log, WLOG_TRACE, "process_data_first: Sp=%d cbChId=%d, ChannelId=%" PRIu32 " Length=%" PRIu32 "", Sp, cbChId, ChannelId, Length); - status = dvcman_receive_channel_data_first(drdynvc, drdynvc->channel_mgr, ChannelId, Length); + + channel = dvcman_get_channel_by_id(drdynvc->channel_mgr, ChannelId, TRUE); + if (!channel) + { + /** + * Windows Server 2012 R2 can send some messages over + * Microsoft::Windows::RDS::Geometry::v08.01 even if the dynamic virtual channel wasn't + * registered on our side. Ignoring it works. + */ + WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %" PRIu32 " not found!", ChannelId); + return CHANNEL_RC_OK; + } + + if (channel->state != DVC_CHANNEL_RUNNING) + goto out; + + status = dvcman_receive_channel_data_first(channel, Length); if (status == CHANNEL_RC_OK) - status = dvcman_receive_channel_data(drdynvc, drdynvc->channel_mgr, ChannelId, s, - ThreadingFlags); + status = dvcman_receive_channel_data(channel, s, ThreadingFlags); if (status != CHANNEL_RC_OK) - status = dvcman_close_channel(drdynvc->channel_mgr, ChannelId, TRUE); + status = dvcman_channel_close(channel, FALSE); +out: + dvcman_channel_unref(channel); return status; } @@ -1204,7 +1228,8 @@ static UINT drdynvc_process_data(drdynvcPlugin* drdynvc, int Sp, int cbChId, wSt UINT32 ThreadingFlags) { UINT32 ChannelId; - UINT status; + DVCMAN_CHANNEL* channel; + UINT status = CHANNEL_RC_OK; if (!Stream_CheckAndLogRequiredLength(TAG, s, drdynvc_cblen_to_bytes(cbChId))) return ERROR_INVALID_DATA; @@ -1212,12 +1237,28 @@ static UINT drdynvc_process_data(drdynvcPlugin* drdynvc, int Sp, int cbChId, wSt ChannelId = drdynvc_read_variable_uint(s, cbChId); WLog_Print(drdynvc->log, WLOG_TRACE, "process_data: Sp=%d cbChId=%d, ChannelId=%" PRIu32 "", Sp, cbChId, ChannelId); - status = - dvcman_receive_channel_data(drdynvc, drdynvc->channel_mgr, ChannelId, s, ThreadingFlags); + channel = dvcman_get_channel_by_id(drdynvc->channel_mgr, ChannelId, TRUE); + if (!channel) + { + /** + * Windows Server 2012 R2 can send some messages over + * Microsoft::Windows::RDS::Geometry::v08.01 even if the dynamic virtual channel wasn't + * registered on our side. Ignoring it works. + */ + WLog_Print(drdynvc->log, WLOG_ERROR, "ChannelId %" PRIu32 " not found!", ChannelId); + return CHANNEL_RC_OK; + } + + if (channel->state != DVC_CHANNEL_RUNNING) + goto out; + + status = dvcman_receive_channel_data(channel, s, ThreadingFlags); if (status != CHANNEL_RC_OK) - status = dvcman_close_channel(drdynvc->channel_mgr, ChannelId, TRUE); + status = dvcman_channel_close(channel, FALSE); +out: + dvcman_channel_unref(channel); return status; } @@ -1228,8 +1269,8 @@ static UINT drdynvc_process_data(drdynvcPlugin* drdynvc, int Sp, int cbChId, wSt */ static UINT drdynvc_process_close_request(drdynvcPlugin* drdynvc, int Sp, int cbChId, wStream* s) { - UINT error; UINT32 ChannelId; + DVCMAN_CHANNEL* channel; if (!Stream_CheckAndLogRequiredLength(TAG, s, drdynvc_cblen_to_bytes(cbChId))) return ERROR_INVALID_DATA; @@ -1239,11 +1280,17 @@ static UINT drdynvc_process_close_request(drdynvcPlugin* drdynvc, int Sp, int cb "process_close_request: Sp=%d cbChId=%d, ChannelId=%" PRIu32 "", Sp, cbChId, ChannelId); - if ((error = dvcman_close_channel(drdynvc->channel_mgr, ChannelId, TRUE))) - WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_close_channel failed with error %" PRIu32 "!", - error); + channel = (DVCMAN_CHANNEL*)dvcman_get_channel_by_id(drdynvc->channel_mgr, ChannelId, TRUE); + if (!channel) + { + WLog_Print(drdynvc->log, WLOG_ERROR, "dvcman_close_request channel %" PRIu32 " not present", + ChannelId); + return CHANNEL_RC_OK; + } - return error; + dvcman_channel_close(channel, TRUE); + dvcman_channel_unref(channel); + return CHANNEL_RC_OK; } /** @@ -1413,10 +1460,9 @@ static void VCAPITYPE drdynvc_virtual_channel_open_event_ex(LPVOID lpUserParam, static BOOL channelByIdCleanerFn(const void* key, void* value, void* arg) { - drdynvcPlugin* drdynvc = (drdynvcPlugin*)arg; DVCMAN_CHANNEL* channel = (DVCMAN_CHANNEL*)value; - dvcman_close_channel(drdynvc->channel_mgr, channel->channel_id, FALSE); + dvcman_channel_close(channel, FALSE); return TRUE; } diff --git a/channels/drdynvc/client/drdynvc_main.h b/channels/drdynvc/client/drdynvc_main.h index b0785b969..c282cdd61 100644 --- a/channels/drdynvc/client/drdynvc_main.h +++ b/channels/drdynvc/client/drdynvc_main.h @@ -70,11 +70,19 @@ typedef struct rdpContext* context; } DVCMAN_ENTRY_POINTS; +typedef enum +{ + DVC_CHANNEL_INIT, + DVC_CHANNEL_RUNNING, + DVC_CHANNEL_CLOSED +} DVC_CHANNEL_STATE; + typedef struct { IWTSVirtualChannel iface; - int status; + volatile LONG refCounter; + DVC_CHANNEL_STATE state; DVCMAN* dvcman; void* pInterface; UINT32 channel_id;