mirror of
https://github.com/morgan9e/FreeRDP
synced 2026-04-14 00:14:11 +09:00
Merge pull request #11606 from akallabeth/GetAccessToken-split
[core,aad] Split GetAccessToken callback
This commit is contained in:
@@ -1014,7 +1014,8 @@ static BOOL hotplug_delete_foreach(ULONG_PTR key, void* element, void* data)
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
static UINT handle_hotplug(RdpdrClientContext* context, RdpdrHotplugEventType type)
|
||||
static UINT handle_hotplug(RdpdrClientContext* context,
|
||||
WINPR_ATTR_UNUSED RdpdrHotplugEventType type)
|
||||
{
|
||||
WINPR_ASSERT(context);
|
||||
rdpdrPlugin* rdpdr = context->handle;
|
||||
@@ -2341,7 +2342,8 @@ static UINT rdpdr_unregister_device(RdpdrClientContext* context, size_t count, c
|
||||
const uintptr_t id = ids[x];
|
||||
devman_unregister_device(rdpdr->devman, (void*)id);
|
||||
}
|
||||
return rdpdr_send_device_list_remove_request(rdpdr, count, ids);
|
||||
return rdpdr_send_device_list_remove_request(rdpdr, WINPR_ASSERTING_INT_CAST(uint32_t, count),
|
||||
ids);
|
||||
}
|
||||
|
||||
static UINT rdpdr_virtual_channel_event_initialized(rdpdrPlugin* rdpdr,
|
||||
|
||||
@@ -134,9 +134,35 @@ extern "C"
|
||||
ACCESS_TOKEN_TYPE_AVD /**!< oauth2 access token for Azure Virtual Desktop */
|
||||
} AccessTokenType;
|
||||
|
||||
/** @brief A function to be implemented by a client. It is called whenever the connection
|
||||
* requires an access token.
|
||||
* @param instance The instance the function is called for
|
||||
* @param tokenType The type of token requested
|
||||
* @param token A pointer that will hold the (allocated) token string
|
||||
* @param count The number of arguments following
|
||||
*
|
||||
* @return \b TRUE for success, \b FALSE otherwise
|
||||
* @since version 3.0.0
|
||||
*/
|
||||
typedef BOOL (*pGetAccessToken)(freerdp* instance, AccessTokenType tokenType, char** token,
|
||||
size_t count, ...);
|
||||
|
||||
/** @brief The function is called whenever the connection requires an access token.
|
||||
* It differs from \ref pGetAccessToken and is not meant to be implemented by a client
|
||||
* directly. The client-common library will use this to provide common means to retrieve a token
|
||||
* and only if that fails the instanc->GetAccessToken callback will be called.
|
||||
*
|
||||
* @param context The context the function is called for
|
||||
* @param tokenType The type of token requested
|
||||
* @param token A pointer that will hold the (allocated) token string
|
||||
* @param count The number of arguments following
|
||||
*
|
||||
* @return \b TRUE for success, \b FALSE otherwise
|
||||
* @since version 3.16.0
|
||||
*/
|
||||
typedef BOOL (*pGetCommonAccessToken)(rdpContext* context, AccessTokenType tokenType,
|
||||
char** token, size_t count, ...);
|
||||
|
||||
/** @brief Callback used to inform about a reconnection attempt
|
||||
*
|
||||
* @param instance The instance the information is for
|
||||
@@ -769,6 +795,25 @@ owned by rdpRdp */
|
||||
*/
|
||||
FREERDP_API BOOL freerdp_persist_credentials(rdpContext* context);
|
||||
|
||||
/** @brief set a new function to be called when an access token is requested.
|
||||
*
|
||||
* @param context The rdp context to set the function for. Must not be \b NULL
|
||||
* @param GetCommonAccessToken The function pointer to set, \b NULL to disable
|
||||
*
|
||||
* @return \b TRUE for success, \b FALSE otherwise
|
||||
* @since version 3.16.0
|
||||
*/
|
||||
FREERDP_API BOOL freerdp_set_common_access_token(rdpContext* context,
|
||||
pGetCommonAccessToken GetCommonAccessToken);
|
||||
|
||||
/** @brief get the current function pointer set as GetCommonAccessToken
|
||||
*
|
||||
* @param context The rdp context to set the function for. Must not be \b NULL
|
||||
* @return The current function pointer set or \b NULL
|
||||
* @since version 3.16.0
|
||||
*/
|
||||
FREERDP_API pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -48,6 +48,7 @@ struct rdp_aad
|
||||
char* hostname;
|
||||
char* scope;
|
||||
wLog* log;
|
||||
pGetCommonAccessToken GetCommonAccessToken;
|
||||
};
|
||||
|
||||
#ifdef WITH_AAD
|
||||
@@ -269,9 +270,6 @@ int aad_client_begin(rdpAad* aad)
|
||||
rdpSettings* settings = aad->rdpcontext->settings;
|
||||
WINPR_ASSERT(settings);
|
||||
|
||||
freerdp* instance = aad->rdpcontext->instance;
|
||||
WINPR_ASSERT(instance);
|
||||
|
||||
/* Get the host part of the hostname */
|
||||
const char* hostname = freerdp_settings_get_string(settings, FreeRDP_AadServerHostname);
|
||||
if (!hostname)
|
||||
@@ -303,17 +301,17 @@ int aad_client_begin(rdpAad* aad)
|
||||
return -1;
|
||||
|
||||
/* Obtain an oauth authorization code */
|
||||
if (!instance->GetAccessToken)
|
||||
if (!aad->GetCommonAccessToken)
|
||||
{
|
||||
WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL");
|
||||
WLog_Print(aad->log, WLOG_ERROR, "aad->rdpcontext->GetCommonAccessToken == NULL");
|
||||
return -1;
|
||||
}
|
||||
|
||||
if (!aad_fetch_wellknown(aad->log, aad->rdpcontext))
|
||||
return -1;
|
||||
|
||||
const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token,
|
||||
2, aad->scope, aad->kid);
|
||||
const BOOL arc = aad->GetCommonAccessToken(aad->rdpcontext, ACCESS_TOKEN_TYPE_AAD,
|
||||
&aad->access_token, 2, aad->scope, aad->kid);
|
||||
if (!arc)
|
||||
{
|
||||
WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token");
|
||||
@@ -788,7 +786,8 @@ static BOOL ensure_wellknown(WINPR_ATTR_UNUSED rdpContext* context)
|
||||
|
||||
#endif
|
||||
|
||||
rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
|
||||
rdpAad* aad_new(rdpContext* context, rdpTransport* transport,
|
||||
pGetCommonAccessToken GetCommonAccessToken)
|
||||
{
|
||||
WINPR_ASSERT(transport);
|
||||
WINPR_ASSERT(context);
|
||||
@@ -799,6 +798,7 @@ rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
|
||||
return NULL;
|
||||
|
||||
aad->log = WLog_Get(FREERDP_TAG("aad"));
|
||||
aad->GetCommonAccessToken = GetCommonAccessToken;
|
||||
aad->key = freerdp_key_new();
|
||||
if (!aad->key)
|
||||
goto fail;
|
||||
|
||||
@@ -42,6 +42,7 @@ FREERDP_LOCAL AAD_STATE aad_get_state(rdpAad* aad);
|
||||
FREERDP_LOCAL void aad_free(rdpAad* aad);
|
||||
|
||||
WINPR_ATTR_MALLOC(aad_free, 1)
|
||||
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport);
|
||||
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport,
|
||||
pGetCommonAccessToken GetCommonAccessToken);
|
||||
|
||||
#endif /* FREERDP_LIB_CORE_AAD_H */
|
||||
|
||||
@@ -789,6 +789,56 @@ BOOL freerdp_context_new(freerdp* instance)
|
||||
return freerdp_context_new_ex(instance, NULL);
|
||||
}
|
||||
|
||||
static BOOL freerdp_common_context(rdpContext* context, AccessTokenType tokenType, char** token,
|
||||
size_t count, ...)
|
||||
{
|
||||
BOOL rc = FALSE;
|
||||
|
||||
WINPR_ASSERT(context);
|
||||
if (!context->instance || !context->instance->GetAccessToken)
|
||||
return TRUE;
|
||||
|
||||
va_list ap;
|
||||
va_start(ap, count);
|
||||
switch (tokenType)
|
||||
{
|
||||
case ACCESS_TOKEN_TYPE_AAD:
|
||||
if (count != 2)
|
||||
{
|
||||
WLog_ERR(TAG,
|
||||
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
|
||||
", aborting",
|
||||
count);
|
||||
}
|
||||
else
|
||||
{
|
||||
const char* scope = va_arg(ap, const char*);
|
||||
const char* req_cnf = va_arg(ap, const char*);
|
||||
rc = context->instance->GetAccessToken(context->instance, tokenType, token, count,
|
||||
scope, req_cnf);
|
||||
}
|
||||
break;
|
||||
case ACCESS_TOKEN_TYPE_AVD:
|
||||
if (count != 0)
|
||||
{
|
||||
WLog_WARN(TAG,
|
||||
"ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz
|
||||
", ignoring",
|
||||
count);
|
||||
}
|
||||
else
|
||||
{
|
||||
rc = context->instance->GetAccessToken(context->instance, tokenType, token, count);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
va_end(ap);
|
||||
|
||||
return rc;
|
||||
}
|
||||
|
||||
BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings)
|
||||
{
|
||||
rdpRdp* rdp = NULL;
|
||||
@@ -869,10 +919,19 @@ BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings)
|
||||
if (!context->dump)
|
||||
goto fail;
|
||||
|
||||
/* Fallback:
|
||||
* Client common library might set a function pointer to handle this, but here we provide a
|
||||
* default implementation that simply calls instance->GetAccessToken.
|
||||
*/
|
||||
if (!freerdp_set_common_access_token(context, freerdp_common_context))
|
||||
goto fail;
|
||||
|
||||
IFCALLRET(instance->ContextNew, ret, instance, context);
|
||||
|
||||
if (ret)
|
||||
return TRUE;
|
||||
if (!ret)
|
||||
goto fail;
|
||||
|
||||
return TRUE;
|
||||
|
||||
fail:
|
||||
freerdp_context_free(instance);
|
||||
@@ -1507,3 +1566,19 @@ const char* freerdp_disconnect_reason_string(int reason)
|
||||
return "rn-unknown";
|
||||
}
|
||||
}
|
||||
|
||||
BOOL freerdp_set_common_access_token(rdpContext* context,
|
||||
pGetCommonAccessToken GetCommonAccessToken)
|
||||
{
|
||||
WINPR_ASSERT(context);
|
||||
WINPR_ASSERT(context->rdp);
|
||||
context->rdp->GetCommonAccessToken = GetCommonAccessToken;
|
||||
return TRUE;
|
||||
}
|
||||
|
||||
pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context)
|
||||
{
|
||||
WINPR_ASSERT(context);
|
||||
WINPR_ASSERT(context->rdp);
|
||||
return context->rdp->GetCommonAccessToken;
|
||||
}
|
||||
|
||||
@@ -194,9 +194,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
|
||||
WINPR_ASSERT(content_type);
|
||||
|
||||
WINPR_ASSERT(arm->context);
|
||||
|
||||
freerdp* instance = arm->context->instance;
|
||||
WINPR_ASSERT(instance);
|
||||
WINPR_ASSERT(arm->context->rdp);
|
||||
|
||||
uri = http_context_get_uri(arm->http);
|
||||
request = http_request_new();
|
||||
@@ -211,7 +209,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
|
||||
{
|
||||
char* token = NULL;
|
||||
|
||||
if (!instance->GetAccessToken)
|
||||
if (!arm->context->rdp->GetCommonAccessToken)
|
||||
{
|
||||
WLog_Print(arm->log, WLOG_ERROR, "No authorization token provided");
|
||||
goto out;
|
||||
@@ -220,7 +218,8 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
|
||||
if (!arm_fetch_wellknown(arm))
|
||||
goto out;
|
||||
|
||||
if (!instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AVD, &token, 0))
|
||||
if (!arm->context->rdp->GetCommonAccessToken(arm->context, ACCESS_TOKEN_TYPE_AVD, &token,
|
||||
0))
|
||||
{
|
||||
WLog_Print(arm->log, WLOG_ERROR, "Unable to obtain access token");
|
||||
goto out;
|
||||
|
||||
@@ -2318,7 +2318,7 @@ static bool rdp_new_common(rdpRdp* rdp)
|
||||
goto fail;
|
||||
}
|
||||
|
||||
rdp->aad = aad_new(rdp->context, rdp->transport);
|
||||
rdp->aad = aad_new(rdp->context, rdp->transport, rdp->GetCommonAccessToken);
|
||||
if (!rdp->aad)
|
||||
goto fail;
|
||||
|
||||
|
||||
@@ -209,6 +209,7 @@ struct rdp_rdp
|
||||
char log_context[64];
|
||||
WINPR_JSON* wellknown;
|
||||
FreeRDPTimer* timer;
|
||||
pGetCommonAccessToken GetCommonAccessToken;
|
||||
};
|
||||
|
||||
FREERDP_LOCAL BOOL rdp_read_security_header(rdpRdp* rdp, wStream* s, UINT16* flags, UINT16* length);
|
||||
|
||||
Reference in New Issue
Block a user