[core,server] Improve WTS API locking

* Use InitOnceExecuteOnce to initialize global hash table for server
  handles
* Lock full function calls when manipulating global data (Hash table and
  session ID count)
This commit is contained in:
Armin Novak
2026-03-03 15:21:40 +01:00
parent 32b2bd22aa
commit d079f59e55

View File

@@ -65,8 +65,9 @@ typedef struct
static const DWORD g_err_oom = WINPR_CXX_COMPAT_CAST(DWORD, E_OUTOFMEMORY);
static DWORD g_SessionId = 1;
static DWORD g_SessionId = 0;
static wHashTable* g_ServerHandles = nullptr;
static INIT_ONCE g_HandleInitializer = INIT_ONCE_STATIC_INIT;
static rdpPeerChannel* wts_get_dvc_channel_by_id(WTSVirtualChannelManager* vcm, UINT32 ChannelId)
{
@@ -1001,20 +1002,57 @@ static UINT32 channelId_Hash(const void* key)
return *v;
}
static void clearHandles(void)
{
HashTable_Free(g_ServerHandles);
g_ServerHandles = nullptr;
}
static BOOL CALLBACK initializeHandles(WINPR_ATTR_UNUSED PINIT_ONCE once,
WINPR_ATTR_UNUSED PVOID param,
WINPR_ATTR_UNUSED PVOID* context)
{
WINPR_ASSERT(g_ServerHandles == nullptr);
g_ServerHandles = HashTable_New(TRUE);
g_SessionId = 1;
(void)atexit(clearHandles);
return g_ServerHandles != nullptr;
}
static void wtsCloseVCM(WTSVirtualChannelManager* vcm, bool closeDrdynvc)
{
WINPR_ASSERT(vcm);
HashTable_Lock(g_ServerHandles);
if (vcm && (vcm != INVALID_HANDLE_VALUE))
{
HashTable_Remove(g_ServerHandles, (void*)(UINT_PTR)vcm->SessionId);
HashTable_Free(vcm->dynamicVirtualChannels);
if (vcm->drdynvc_channel)
{
if (closeDrdynvc)
(void)WTSVirtualChannelClose(vcm->drdynvc_channel);
vcm->drdynvc_channel = nullptr;
}
MessageQueue_Free(vcm->queue);
free(vcm);
}
HashTable_Unlock(g_ServerHandles);
}
HANDLE WINAPI FreeRDP_WTSOpenServerA(LPSTR pServerName)
{
rdpContext* context = nullptr;
freerdp_peer* client = nullptr;
WTSVirtualChannelManager* vcm = nullptr;
HANDLE hServer = INVALID_HANDLE_VALUE;
wObject queueCallbacks = WINPR_C_ARRAY_INIT;
context = (rdpContext*)pServerName;
rdpContext* context = (rdpContext*)pServerName;
if (!context)
return INVALID_HANDLE_VALUE;
client = context->peer;
freerdp_peer* client = context->peer;
if (!client)
{
@@ -1022,40 +1060,32 @@ HANDLE WINAPI FreeRDP_WTSOpenServerA(LPSTR pServerName)
return INVALID_HANDLE_VALUE;
}
vcm = (WTSVirtualChannelManager*)calloc(1, sizeof(WTSVirtualChannelManager));
if (!InitOnceExecuteOnce(&g_HandleInitializer, initializeHandles, nullptr, nullptr))
return INVALID_HANDLE_VALUE;
WTSVirtualChannelManager* vcm =
(WTSVirtualChannelManager*)calloc(1, sizeof(WTSVirtualChannelManager));
if (!vcm)
goto error_vcm_alloc;
goto fail;
vcm->client = client;
vcm->rdp = context->rdp;
vcm->SessionId = g_SessionId++;
if (!g_ServerHandles)
{
g_ServerHandles = HashTable_New(TRUE);
if (!g_ServerHandles)
goto error_free;
}
if (!HashTable_Insert(g_ServerHandles, (void*)(UINT_PTR)vcm->SessionId, (void*)vcm))
goto error_free;
queueCallbacks.fnObjectFree = wts_virtual_channel_manager_free_message;
vcm->queue = MessageQueue_New(&queueCallbacks);
if (!vcm->queue)
goto error_queue;
goto fail;
vcm->dvc_channel_id_seq = 0;
vcm->dynamicVirtualChannels = HashTable_New(TRUE);
if (!vcm->dynamicVirtualChannels)
goto error_dynamicVirtualChannels;
goto fail;
if (!HashTable_SetHashFunction(vcm->dynamicVirtualChannels, channelId_Hash))
goto error_hashFunction;
goto fail;
{
wObject* obj = HashTable_ValueObject(vcm->dynamicVirtualChannels);
@@ -1066,18 +1096,21 @@ HANDLE WINAPI FreeRDP_WTSOpenServerA(LPSTR pServerName)
obj->fnObjectEquals = dynChannelMatch;
}
client->ReceiveChannelData = WTSReceiveChannelData;
hServer = (HANDLE)vcm;
{
HashTable_Lock(g_ServerHandles);
vcm->SessionId = g_SessionId++;
const BOOL rc =
HashTable_Insert(g_ServerHandles, (void*)(UINT_PTR)vcm->SessionId, (void*)vcm);
HashTable_Unlock(g_ServerHandles);
if (!rc)
goto fail;
}
HANDLE hServer = (HANDLE)vcm;
return hServer;
error_hashFunction:
HashTable_Free(vcm->dynamicVirtualChannels);
error_dynamicVirtualChannels:
MessageQueue_Free(vcm->queue);
error_queue:
HashTable_Remove(g_ServerHandles, (void*)(UINT_PTR)vcm->SessionId);
error_free:
free(vcm);
error_vcm_alloc:
fail:
wtsCloseVCM(vcm, false);
SetLastError(ERROR_NOT_ENOUGH_MEMORY);
return INVALID_HANDLE_VALUE;
}
@@ -1095,24 +1128,8 @@ HANDLE WINAPI FreeRDP_WTSOpenServerExA(LPSTR pServerName)
VOID WINAPI FreeRDP_WTSCloseServer(HANDLE hServer)
{
WTSVirtualChannelManager* vcm = nullptr;
vcm = (WTSVirtualChannelManager*)hServer;
if (vcm && (vcm != INVALID_HANDLE_VALUE))
{
HashTable_Remove(g_ServerHandles, (void*)(UINT_PTR)vcm->SessionId);
HashTable_Free(vcm->dynamicVirtualChannels);
if (vcm->drdynvc_channel)
{
(void)WTSVirtualChannelClose(vcm->drdynvc_channel);
vcm->drdynvc_channel = nullptr;
}
MessageQueue_Free(vcm->queue);
free(vcm);
}
WTSVirtualChannelManager* vcm = (WTSVirtualChannelManager*)hServer;
wtsCloseVCM(vcm, true);
}
BOOL WINAPI FreeRDP_WTSEnumerateSessionsW(WINPR_ATTR_UNUSED HANDLE hServer,
@@ -1438,29 +1455,33 @@ fail:
HANDLE WINAPI FreeRDP_WTSVirtualChannelOpenEx(DWORD SessionId, LPSTR pVirtualName, DWORD flags)
{
wStream* s = nullptr;
rdpMcs* mcs = nullptr;
BOOL joined = FALSE;
freerdp_peer* client = nullptr;
rdpPeerChannel* channel = nullptr;
BOOL joined = FALSE;
ULONG written = 0;
WTSVirtualChannelManager* vcm = nullptr;
if (SessionId == WTS_CURRENT_SESSION)
return nullptr;
vcm = (WTSVirtualChannelManager*)HashTable_GetItemValue(g_ServerHandles,
(void*)(UINT_PTR)SessionId);
HashTable_Lock(g_ServerHandles);
WTSVirtualChannelManager* vcm = (WTSVirtualChannelManager*)HashTable_GetItemValue(
g_ServerHandles, (void*)(UINT_PTR)SessionId);
if (!vcm)
return nullptr;
goto end;
if (!(flags & WTS_CHANNEL_OPTION_DYNAMIC))
{
HashTable_Unlock(g_ServerHandles);
return FreeRDP_WTSVirtualChannelOpen((HANDLE)vcm, SessionId, pVirtualName);
}
client = vcm->client;
mcs = client->context->rdp->mcs;
freerdp_peer* client = vcm->client;
WINPR_ASSERT(client);
WINPR_ASSERT(client->context);
WINPR_ASSERT(client->context->rdp);
rdpMcs* mcs = client->context->rdp->mcs;
WINPR_ASSERT(mcs);
for (UINT32 index = 0; index < mcs->channelCount; index++)
{
@@ -1476,13 +1497,13 @@ HANDLE WINAPI FreeRDP_WTSVirtualChannelOpenEx(DWORD SessionId, LPSTR pVirtualNam
if (!joined)
{
SetLastError(ERROR_NOT_FOUND);
return nullptr;
goto end;
}
if (!vcm->drdynvc_channel || (vcm->drdynvc_state != DRDYNVC_STATE_READY))
{
SetLastError(ERROR_NOT_READY);
return nullptr;
goto end;
}
WINPR_ASSERT(client);
@@ -1496,7 +1517,7 @@ HANDLE WINAPI FreeRDP_WTSVirtualChannelOpenEx(DWORD SessionId, LPSTR pVirtualNam
if (!channel)
{
SetLastError(ERROR_NOT_ENOUGH_MEMORY);
return nullptr;
goto end;
}
const LONG hdl = InterlockedIncrement(&vcm->dvc_channel_id_seq);
@@ -1524,12 +1545,16 @@ HANDLE WINAPI FreeRDP_WTSVirtualChannelOpenEx(DWORD SessionId, LPSTR pVirtualNam
goto fail;
}
end:
Stream_Free(s, TRUE);
HashTable_Unlock(g_ServerHandles);
return channel;
fail:
Stream_Free(s, TRUE);
if (channel)
HashTable_Remove(vcm->dynamicVirtualChannels, &channel->channelId);
HashTable_Unlock(g_ServerHandles);
SetLastError(ERROR_NOT_ENOUGH_MEMORY);
return nullptr;