diff --git a/channels/ainput/client/ainput_main.c b/channels/ainput/client/ainput_main.c index c291bd727..554575360 100644 --- a/channels/ainput/client/ainput_main.c +++ b/channels/ainput/client/ainput_main.c @@ -45,6 +45,7 @@ struct AINPUT_PLUGIN_ AInputClientContext* context; UINT32 MajorVersion; UINT32 MinorVersion; + CRITICAL_SECTION lock; }; /** @@ -85,18 +86,15 @@ static UINT ainput_on_data_received(IWTSVirtualChannelCallback* pChannelCallback static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags, INT32 x, INT32 y) { - AINPUT_PLUGIN* ainput = NULL; - GENERIC_CHANNEL_CALLBACK* callback = NULL; BYTE buffer[32] = { 0 }; - UINT64 time = 0; wStream sbuffer = { 0 }; wStream* s = Stream_StaticInit(&sbuffer, buffer, sizeof(buffer)); WINPR_ASSERT(s); WINPR_ASSERT(context); - time = GetTickCount64(); - ainput = (AINPUT_PLUGIN*)context->handle; + const UINT64 time = GetTickCount64(); + AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)context->handle; WINPR_ASSERT(ainput); if (ainput->MajorVersion != AINPUT_VERSION_MAJOR) @@ -105,8 +103,6 @@ static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags, ainput->MajorVersion, ainput->MinorVersion); return CHANNEL_RC_UNSUPPORTED_VERSION; } - callback = ainput->base.listener_callback->channel_callback; - WINPR_ASSERT(callback); { char ebuffer[128] = { 0 }; @@ -125,10 +121,15 @@ static UINT ainput_send_input_event(AInputClientContext* context, UINT64 flags, Stream_SealLength(s); /* ainput back what we have received. AINPUT does not have any message IDs. */ + EnterCriticalSection(&ainput->lock); + GENERIC_CHANNEL_CALLBACK* callback = ainput->base.listener_callback->channel_callback; + WINPR_ASSERT(callback); WINPR_ASSERT(callback->channel); WINPR_ASSERT(callback->channel->Write); - return callback->channel->Write(callback->channel, (ULONG)Stream_Length(s), Stream_Buffer(s), - NULL); + const UINT rc = callback->channel->Write(callback->channel, (ULONG)Stream_Length(s), + Stream_Buffer(s), NULL); + LeaveCriticalSection(&ainput->lock); + return rc; } /** @@ -140,8 +141,16 @@ static UINT ainput_on_close(IWTSVirtualChannelCallback* pChannelCallback) { GENERIC_CHANNEL_CALLBACK* callback = (GENERIC_CHANNEL_CALLBACK*)pChannelCallback; - free(callback); + if (callback) + { + AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)callback->plugin; + WINPR_ASSERT(ainput); + /* Lock here to ensure that no ainput_send_input_event is in progress. */ + EnterCriticalSection(&ainput->lock); + free(callback); + LeaveCriticalSection(&ainput->lock); + } return CHANNEL_RC_OK; } @@ -156,14 +165,21 @@ static UINT init_plugin_cb(GENERIC_DYNVC_PLUGIN* base, WINPR_ATTR_UNUSED rdpCont context->handle = (void*)base; context->AInputSendInputEvent = ainput_send_input_event; + InitializeCriticalSection(&ainput->lock); + + EnterCriticalSection(&ainput->lock); ainput->context = context; ainput->base.iface.pInterface = context; + LeaveCriticalSection(&ainput->lock); return CHANNEL_RC_OK; } static void terminate_plugin_cb(GENERIC_DYNVC_PLUGIN* base) { AINPUT_PLUGIN* ainput = (AINPUT_PLUGIN*)base; + WINPR_ASSERT(ainput); + + DeleteCriticalSection(&ainput->lock); free(ainput->context); }