[channels,ainput] lock context when updating listener

This commit is contained in:
akallabeth
2026-01-26 12:08:48 +01:00
parent d676518809
commit d9ca272dce

View File

@@ -45,6 +45,7 @@ struct AINPUT_PLUGIN_
AInputClientContext* context;
UINT32 MajorVersion;
UINT32 MinorVersion;
CRITICAL_SECTION lock;
};
/**
@@ -85,18 +86,15 @@ static UINT ainput_on_data_received(IWTSVirtualChannelCallback* pChannelCallback
static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags, INT32 x, INT32 y)
{
AINPUT_PLUGIN* ainput = NULL;
GENERIC_CHANNEL_CALLBACK* callback = NULL;
BYTE buffer[32] = { 0 };
UINT64 time = 0;
wStream sbuffer = { 0 };
wStream* s = Stream_StaticInit(&sbuffer, buffer, sizeof(buffer));
WINPR_ASSERT(s);
WINPR_ASSERT(context);
time = GetTickCount64();
ainput = (AINPUT_PLUGIN*)context->handle;
const UINT64 time = GetTickCount64();
AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)context->handle;
WINPR_ASSERT(ainput);
if (ainput->MajorVersion != AINPUT_VERSION_MAJOR)
@@ -105,8 +103,6 @@ static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags,
ainput->MajorVersion, ainput->MinorVersion);
return CHANNEL_RC_UNSUPPORTED_VERSION;
}
callback = ainput->base.listener_callback->channel_callback;
WINPR_ASSERT(callback);
{
char ebuffer[128] = { 0 };
@@ -125,10 +121,15 @@ static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags,
Stream_SealLength(s);
/* ainput back what we have received. AINPUT does not have any message IDs. */
EnterCriticalSection(&ainput->lock);
GENERIC_CHANNEL_CALLBACK* callback = ainput->base.listener_callback->channel_callback;
WINPR_ASSERT(callback);
WINPR_ASSERT(callback->channel);
WINPR_ASSERT(callback->channel->Write);
return callback->channel->Write(callback->channel, (ULONG)Stream_Length(s), Stream_Buffer(s),
NULL);
const UINT rc = callback->channel->Write(callback->channel, (ULONG)Stream_Length(s),
Stream_Buffer(s), NULL);
LeaveCriticalSection(&ainput->lock);
return rc;
}
/**
@@ -140,8 +141,16 @@ static UINT ainput_on_close(IWTSVirtualChannelCallback* pChannelCallback)
{
GENERIC_CHANNEL_CALLBACK* callback = (GENERIC_CHANNEL_CALLBACK*)pChannelCallback;
free(callback);
if (callback)
{
AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)callback->plugin;
WINPR_ASSERT(ainput);
/* Lock here to ensure that no ainput_send_input_event is in progress. */
EnterCriticalSection(&ainput->lock);
free(callback);
LeaveCriticalSection(&ainput->lock);
}
return CHANNEL_RC_OK;
}
@@ -156,14 +165,21 @@ static UINT init_plugin_cb(GENERIC_DYNVC_PLUGIN* base, WINPR_ATTR_UNUSED rdpCont
context->handle = (void*)base;
context->AInputSendInputEvent = ainput_send_input_event;
InitializeCriticalSection(&ainput->lock);
EnterCriticalSection(&ainput->lock);
ainput->context = context;
ainput->base.iface.pInterface = context;
LeaveCriticalSection(&ainput->lock);
return CHANNEL_RC_OK;
}
static void terminate_plugin_cb(GENERIC_DYNVC_PLUGIN* base)
{
AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)base;
WINPR_ASSERT(ainput);
DeleteCriticalSection(&ainput->lock);
free(ainput->context);
}