[core,aad] Split GetAccessToken callback

To allow client-common library to override the GetAccessToken callback
introduce a new GetCommonAccessToken callback.
This callback defaults to call the existing GetAccessToken callback, but
client-common library can override if desired, so that a common token
retrieval method is executed before a client UI is invoked.
This commit is contained in:
Armin Novak
2025-05-21 10:01:37 +02:00
parent 8fe6450eef
commit 1882cebbce
7 changed files with 138 additions and 11 deletions

View File

@@ -134,9 +134,35 @@ extern "C"
ACCESS_TOKEN_TYPE_AVD /**!< oauth2 access token for Azure Virtual Desktop */
} AccessTokenType;
/** @brief A function to be implemented by a client. It is called whenever the connection
* requires an access token.
* @param instance The instance the function is called for
* @param tokenType The type of token requested
* @param token A pointer that will hold the (allocated) token string
* @param count The number of arguments following
*
* @return \b TRUE for success, \b FALSE otherwise
* @since version 3.0.0
*/
typedef BOOL (*pGetAccessToken)(freerdp* instance, AccessTokenType tokenType, char** token,
size_t count, ...);
/** @brief The function is called whenever the connection requires an access token.
* It differs from \ref pGetAccessToken and is not meant to be implemented by a client
* directly. The client-common library will use this to provide common means to retrieve a token
* and only if that fails the instanc->GetAccessToken callback will be called.
*
* @param context The context the function is called for
* @param tokenType The type of token requested
* @param token A pointer that will hold the (allocated) token string
* @param count The number of arguments following
*
* @return \b TRUE for success, \b FALSE otherwise
* @since version 3.16.0
*/
typedef BOOL (*pGetCommonAccessToken)(rdpContext* context, AccessTokenType tokenType,
char** token, size_t count, ...);
/** @brief Callback used to inform about a reconnection attempt
*
* @param instance The instance the information is for
@@ -769,6 +795,25 @@ owned by rdpRdp */
*/
FREERDP_API BOOL freerdp_persist_credentials(rdpContext* context);
/** @brief set a new function to be called when an access token is requested.
*
* @param context The rdp context to set the function for. Must not be \b NULL
* @param GetCommonAccessToken The function pointer to set, \b NULL to disable
*
* @return \b TRUE for success, \b FALSE otherwise
* @since version 3.16.0
*/
FREERDP_API BOOL freerdp_set_common_access_token(rdpContext* context,
pGetCommonAccessToken GetCommonAccessToken);
/** @brief get the current function pointer set as GetCommonAccessToken
*
* @param context The rdp context to set the function for. Must not be \b NULL
* @return The current function pointer set or \b NULL
* @since version 3.16.0
*/
FREERDP_API pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context);
#ifdef __cplusplus
}
#endif

View File

@@ -48,6 +48,7 @@ struct rdp_aad
char* hostname;
char* scope;
wLog* log;
pGetCommonAccessToken GetCommonAccessToken;
};
#ifdef WITH_AAD
@@ -303,17 +304,17 @@ int aad_client_begin(rdpAad* aad)
return -1;
/* Obtain an oauth authorization code */
if (!instance->GetAccessToken)
if (!aad->GetCommonAccessToken)
{
WLog_Print(aad->log, WLOG_ERROR, "instance->GetAccessToken == NULL");
WLog_Print(aad->log, WLOG_ERROR, "aad->rdpcontext->GetCommonAccessToken == NULL");
return -1;
}
if (!aad_fetch_wellknown(aad->log, aad->rdpcontext))
return -1;
const BOOL arc = instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AAD, &aad->access_token,
2, aad->scope, aad->kid);
const BOOL arc = aad->GetCommonAccessToken(aad->rdpcontext, ACCESS_TOKEN_TYPE_AAD,
&aad->access_token, 2, aad->scope, aad->kid);
if (!arc)
{
WLog_Print(aad->log, WLOG_ERROR, "Unable to obtain access token");
@@ -788,7 +789,8 @@ static BOOL ensure_wellknown(WINPR_ATTR_UNUSED rdpContext* context)
#endif
rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
rdpAad* aad_new(rdpContext* context, rdpTransport* transport,
pGetCommonAccessToken GetCommonAccessToken)
{
WINPR_ASSERT(transport);
WINPR_ASSERT(context);
@@ -799,6 +801,7 @@ rdpAad* aad_new(rdpContext* context, rdpTransport* transport)
return NULL;
aad->log = WLog_Get(FREERDP_TAG("aad"));
aad->GetCommonAccessToken = GetCommonAccessToken;
aad->key = freerdp_key_new();
if (!aad->key)
goto fail;

View File

@@ -42,6 +42,7 @@ FREERDP_LOCAL AAD_STATE aad_get_state(rdpAad* aad);
FREERDP_LOCAL void aad_free(rdpAad* aad);
WINPR_ATTR_MALLOC(aad_free, 1)
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport);
FREERDP_LOCAL rdpAad* aad_new(rdpContext* context, rdpTransport* transport,
pGetCommonAccessToken GetCommonAccessToken);
#endif /* FREERDP_LIB_CORE_AAD_H */

View File

@@ -789,6 +789,56 @@ BOOL freerdp_context_new(freerdp* instance)
return freerdp_context_new_ex(instance, NULL);
}
static BOOL freerdp_common_context(rdpContext* context, AccessTokenType tokenType, char** token,
size_t count, ...)
{
BOOL rc = FALSE;
WINPR_ASSERT(context);
if (!context->instance || !context->instance->GetAccessToken)
return TRUE;
va_list ap;
va_start(ap, count);
switch (tokenType)
{
case ACCESS_TOKEN_TYPE_AAD:
if (count != 2)
{
WLog_ERR(TAG,
"ACCESS_TOKEN_TYPE_AAD expected 2 additional arguments, but got %" PRIuz
", aborting",
count);
}
else
{
const char* scope = va_arg(ap, const char*);
const char* req_cnf = va_arg(ap, const char*);
rc = context->instance->GetAccessToken(context->instance, tokenType, token, count,
scope, req_cnf);
}
break;
case ACCESS_TOKEN_TYPE_AVD:
if (count != 0)
{
WLog_WARN(TAG,
"ACCESS_TOKEN_TYPE_AVD expected 0 additional arguments, but got %" PRIuz
", ignoring",
count);
}
else
{
rc = context->instance->GetAccessToken(context->instance, tokenType, token, count);
}
break;
default:
break;
}
va_end(ap);
return rc;
}
BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings)
{
rdpRdp* rdp = NULL;
@@ -869,10 +919,19 @@ BOOL freerdp_context_new_ex(freerdp* instance, rdpSettings* settings)
if (!context->dump)
goto fail;
/* Fallback:
* Client common library might set a function pointer to handle this, but here we provide a
* default implementation that simply calls instance->GetAccessToken.
*/
if (!freerdp_set_common_access_token(context, freerdp_common_context))
goto fail;
IFCALLRET(instance->ContextNew, ret, instance, context);
if (ret)
return TRUE;
if (!ret)
goto fail;
return TRUE;
fail:
freerdp_context_free(instance);
@@ -1507,3 +1566,19 @@ const char* freerdp_disconnect_reason_string(int reason)
return "rn-unknown";
}
}
BOOL freerdp_set_common_access_token(rdpContext* context,
pGetCommonAccessToken GetCommonAccessToken)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->rdp);
context->rdp->GetCommonAccessToken = GetCommonAccessToken;
return TRUE;
}
pGetCommonAccessToken freerdp_get_common_access_token(rdpContext* context)
{
WINPR_ASSERT(context);
WINPR_ASSERT(context->rdp);
return context->rdp->GetCommonAccessToken;
}

View File

@@ -194,6 +194,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
WINPR_ASSERT(content_type);
WINPR_ASSERT(arm->context);
WINPR_ASSERT(arm->context->rdp);
freerdp* instance = arm->context->instance;
WINPR_ASSERT(instance);
@@ -211,7 +212,7 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
{
char* token = NULL;
if (!instance->GetAccessToken)
if (!arm->context->rdp->GetCommonAccessToken)
{
WLog_Print(arm->log, WLOG_ERROR, "No authorization token provided");
goto out;
@@ -220,7 +221,8 @@ static wStream* arm_build_http_request(rdpArm* arm, const char* method,
if (!arm_fetch_wellknown(arm))
goto out;
if (!instance->GetAccessToken(instance, ACCESS_TOKEN_TYPE_AVD, &token, 0))
if (!arm->context->rdp->GetCommonAccessToken(arm->context, ACCESS_TOKEN_TYPE_AVD, &token,
0))
{
WLog_Print(arm->log, WLOG_ERROR, "Unable to obtain access token");
goto out;

View File

@@ -2318,7 +2318,7 @@ static bool rdp_new_common(rdpRdp* rdp)
goto fail;
}
rdp->aad = aad_new(rdp->context, rdp->transport);
rdp->aad = aad_new(rdp->context, rdp->transport, rdp->GetCommonAccessToken);
if (!rdp->aad)
goto fail;

View File

@@ -209,6 +209,7 @@ struct rdp_rdp
char log_context[64];
WINPR_JSON* wellknown;
FreeRDPTimer* timer;
pGetCommonAccessToken GetCommonAccessToken;
};
FREERDP_LOCAL BOOL rdp_read_security_header(rdpRdp* rdp, wStream* s, UINT16* flags, UINT16* length);