Merge pull request #11606 from akallabeth/GetAccessToken-split

[core,aad] Split GetAccessToken callback
This commit is contained in:
akallabeth
2025-05-21 11:13:30 +02:00
committed by GitHub
8 changed files with 142 additions and 19 deletions

View File

@@ -1014,7 +1014,8 @@ static BOOL hotplug_delete_foreach(ULONG_PTR key, void* element, void* data)
return TRUE;
}
static UINT handle_hotplug(RdpdrClientContext* context, RdpdrHotplugEventType type)
static UINT handle_hotplug(RdpdrClientContext* context,
WINPR_ATTR_UNUSED RdpdrHotplugEventType type)
{
WINPR_ASSERT(context);
rdpdrPlugin* rdpdr = context->handle;
@@ -2341,7 +2342,8 @@ static UINT rdpdr_unregister_device(RdpdrClientContext* context, size_t count, c
const uintptr_t id = ids[x];
devman_unregister_device(rdpdr->devman, (void*)id);
}
return rdpdr_send_device_list_remove_request(rdpdr, count, ids);
return rdpdr_send_device_list_remove_request(rdpdr, WINPR_ASSERTING_INT_CAST(uint32_t, count),
ids);
}
static UINT rdpdr_virtual_channel_event_initialized(rdpdrPlugin* rdpdr,

View File

@@ -134,9 +134,35 @@ extern "C"
ACCESS_TOKEN_TYPE_AVD /**!< oauth2 access token for Azure Virtual Desktop */
} AccessTokenType;
/** @brief A function to be implemented by a client. It is called whenever the connection
* requires an access token.
* @param instance The instance the function is called for
* @param tokenType The type of token requested
* @param token A pointer that will hold the (allocated) token string
* @param count The number of arguments following
*
* @return \b TRUE for success, \b FALSE otherwise
* @since version 3.0.0
*/
typedef BOOL (*pGetAccessToken)(freerdp* instance, AccessTokenType tokenType, char** token,
size_t count, ...);
/** @brief The function is called whenever the connection requires an access token.
* It differs from \ref pGetAccessToken and is not meant to be implemented by a client
* directly. The client-common library will use this to provide common means to retrieve a token
* and only if that fails the instanc->GetAccessToken callback will be called.
*
* @param context The context the function is called for
* @param tokenType The type of token requested
* @param token A pointer that will hold the (allocated) token string
* @param count The number of arguments following
*
* @return \b TRUE for success, \b FALSE otherwise
* @since version 3.16.0
*/
typedef BOOL (*pGetCommonAccessToken)(rdpContext* context, AccessTokenType tokenType,
char** token, size_t count, ...);
/** @brief Callback used to inform about a reconnection attempt
*
* @param instance The instance the information is for
@@ -769,6 +795,25 @@ owned by rdpRdp */
*/
FREERDP_API BOOL freerdp_persist_credentials(rdpContext* context);
/** @brief set a new function to be called when an access token is requested.
*
* @param context The rdp context to set the function for. Must not be \b NULL
* @param GetCommonAccessToken The function pointer to set, \b NULL to disable
*
* @return \b TRUE for success, \b FALSE otherwise
* @since version 3.16.0
*/
FREERDP_API BOOL freerdp_set_common_access_token(rdpContext* context,
pGetCommonAccessToken GetCommonAccessToken);
/** @brief get the current function pointer set as GetCommonAccessToken
*
* @param context The rdp context to set the function for. Must not be \b NULL
* @return The current function pointer set or \b NULL
* @since version 3.16.0
*/
FREERDP_API pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context);
#ifdef __cplusplus
}
#endif

View File

@@ -48,6 +48,7 @@ struct rdp_aad
char* hostname;
char* scope;
wLog* log;
pGetCommonAccessToken GetCommonAccessToken;
};
#ifdef WITH_AAD
@@ -269,9 +270,6 @@ int aad_client_begin(rdpAad* aad)
rdpSettings* settings = aad->rdpcontext->settings;
WINPR_ASSERT(settings);
freerdp* instance = aad->rdpcontext->instance;
WINPR_ASSERT(instance);
/* Get the host part of the hostname */
const char* hostname = freerdp_settings_get_string(settings, FreeRDP_AadServerHostname);
if (!hostname)
@@ -303,17 +301,17 @@ int aad_client_begin(rdpAad* aad)
return -1;
/* Obtain an oauth authorization code */
if (!instance->GetAccessToken)
if (!aad->GetCommonAccessToken)
{
WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL");
WLog_Print(aad->log, WLOG_ERROR, "aad->rdpcontext->GetCommonAccessToken == NULL");
return -1;
}
if (!aad_fetch_wellknown(aad->log, aad->rdpcontext))
return -1;
const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token,
2, aad->scope, aad->kid);
const BOOL arc = aad->GetCommonAccessToken(aad->rdpcontext, ACCESS_TOKEN_TYPE_AAD,
&aad->access_token, 2, aad->scope, aad->kid);
if (!arc)
{
WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token");
@@ -788,7 +786,8 @@ static BOOL ensure_wellknown(WINPR_ATTR_UNUSED rdpContext* context)
#endif
rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
rdpAad* aad_new(rdpContext* context, rdpTransport* transport,
pGetCommonAccessToken GetCommonAccessToken)
{
WINPR_ASSERT(transport);
WINPR_ASSERT(context);
@@ -799,6 +798,7 @@ rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
return NULL;
aad->log = WLog_Get(FREERDP_TAG("aad"));
aad->GetCommonAccessToken = GetCommonAccessToken;
aad->key = freerdp_key_new();
if (!aad->key)
goto fail;

View File

@@ -42,6 +42,7 @@ FREERDP_LOCAL AAD_STATE aad_get_state(rdpAad* aad);
FREERDP_LOCAL void aad_free(rdpAad* aad);
WINPR_ATTR_MALLOC(aad_free, 1)
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport);
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport,
pGetCommonAccessToken GetCommonAccessToken);
#endif /* FREERDP_LIB_CORE_AAD_H */

View File

@@ -789,6 +789,56 @@ BOOL freerdp_context_new(freerdp* instance)
return freerdp_context_new_ex(instance, NULL);
}
static BOOL freerdp_common_context(rdpContext* context, AccessTokenType tokenType, char** token,
size_t count, ...)
{
BOOL rc = FALSE;
WINPR_ASSERT(context);
if (!context->instance || !context->instance->GetAccessToken)
return TRUE;
va_list ap;
va_start(ap, count);
switch (tokenType)
{
case ACCESS_TOKEN_TYPE_AAD:
if (count != 2)
{
WLog_ERR(TAG,
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
", aborting",
count);
}
else
{
const char* scope = va_arg(ap, const char*);
const char* req_cnf = va_arg(ap, const char*);
rc = context->instance->GetAccessToken(context->instance, tokenType, token, count,
scope, req_cnf);
}
break;
case ACCESS_TOKEN_TYPE_AVD:
if (count != 0)
{
WLog_WARN(TAG,
"ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz
", ignoring",
count);
}
else
{
rc = context->instance->GetAccessToken(context->instance, tokenType, token, count);
}
break;
default:
break;
}
va_end(ap);
return rc;
}
BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings)
{
rdpRdp* rdp = NULL;
@@ -869,10 +919,19 @@ BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings)
if (!context->dump)
goto fail;
/* Fallback:
* Client common library might set a function pointer to handle this, but here we provide a
* default implementation that simply calls instance->GetAccessToken.
*/
if (!freerdp_set_common_access_token(context, freerdp_common_context))
goto fail;
IFCALLRET(instance->ContextNew, ret, instance, context);
if (ret)
return TRUE;
if (!ret)
goto fail;
return TRUE;
fail:
freerdp_context_free(instance);
@@ -1507,3 +1566,19 @@ const char* freerdp_disconnect_reason_string(int reason)
return "rn-unknown";
}
}
BOOL freerdp_set_common_access_token(rdpContext* context,
pGetCommonAccessToken GetCommonAccessToken)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->rdp);
context->rdp->GetCommonAccessToken = GetCommonAccessToken;
return TRUE;
}
pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->rdp);
return context->rdp->GetCommonAccessToken;
}

View File

@@ -194,9 +194,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
WINPR_ASSERT(content_type);
WINPR_ASSERT(arm->context);
freerdp* instance = arm->context->instance;
WINPR_ASSERT(instance);
WINPR_ASSERT(arm->context->rdp);
uri = http_context_get_uri(arm->http);
request = http_request_new();
@@ -211,7 +209,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
{
char* token = NULL;
if (!instance->GetAccessToken)
if (!arm->context->rdp->GetCommonAccessToken)
{
WLog_Print(arm->log, WLOG_ERROR, "No authorization token provided");
goto out;
@@ -220,7 +218,8 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
if (!arm_fetch_wellknown(arm))
goto out;
if (!instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AVD, &token, 0))
if (!arm->context->rdp->GetCommonAccessToken(arm->context, ACCESS_TOKEN_TYPE_AVD, &token,
0))
{
WLog_Print(arm->log, WLOG_ERROR, "Unable to obtain access token");
goto out;

View File

@@ -2318,7 +2318,7 @@ static bool rdp_new_common(rdpRdp* rdp)
goto fail;
}
rdp->aad = aad_new(rdp->context, rdp->transport);
rdp->aad = aad_new(rdp->context, rdp->transport, rdp->GetCommonAccessToken);
if (!rdp->aad)
goto fail;

View File

@@ -209,6 +209,7 @@ struct rdp_rdp
char log_context[64];
WINPR_JSON* wellknown;
FreeRDPTimer* timer;
pGetCommonAccessToken GetCommonAccessToken;
};
FREERDP_LOCAL BOOL rdp_read_security_header(rdpRdp* rdp, wStream* s, UINT16* flags, UINT16* length);