Skip to content

Commit 19c2209

Browse files
committed
ServerOAuth2AuthorizedClientExchangeFilterFunction works with UnAuthenticatedServerOAuth2AuthorizedClientRepository
Fixes gh-7544
1 parent 18f48e4 commit 19c2209

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

oauth2/oauth2-client/src/main/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunction.java

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.springframework.security.core.context.ReactiveSecurityContextHolder;
2323
import org.springframework.security.core.context.SecurityContext;
2424
import org.springframework.security.oauth2.client.ClientCredentialsReactiveOAuth2AuthorizedClientProvider;
25+
import org.springframework.security.oauth2.client.OAuth2AuthorizationContext;
2526
import org.springframework.security.oauth2.client.OAuth2AuthorizeRequest;
2627
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
2728
import org.springframework.security.oauth2.client.ReactiveOAuth2AuthorizedClientManager;
@@ -35,6 +36,7 @@
3536
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository;
3637
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
3738
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
39+
import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
3840
import org.springframework.util.Assert;
3941
import org.springframework.web.reactive.function.client.ClientRequest;
4042
import org.springframework.web.reactive.function.client.ClientResponse;
@@ -124,6 +126,17 @@ private static ReactiveOAuth2AuthorizedClientManager createDefaultAuthorizedClie
124126
.clientCredentials()
125127
.password()
126128
.build();
129+
130+
// gh-7544
131+
if (authorizedClientRepository instanceof UnAuthenticatedServerOAuth2AuthorizedClientRepository) {
132+
UnAuthenticatedReactiveOAuth2AuthorizedClientManager unauthenticatedAuthorizedClientManager =
133+
new UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
134+
clientRegistrationRepository,
135+
(UnAuthenticatedServerOAuth2AuthorizedClientRepository) authorizedClientRepository);
136+
unauthenticatedAuthorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
137+
return unauthenticatedAuthorizedClientManager;
138+
}
139+
127140
DefaultReactiveOAuth2AuthorizedClientManager authorizedClientManager = new DefaultReactiveOAuth2AuthorizedClientManager(
128141
clientRegistrationRepository, authorizedClientRepository);
129142
authorizedClientManager.setAuthorizedClientProvider(authorizedClientProvider);
@@ -266,7 +279,11 @@ private void updateDefaultAuthorizedClientManager() {
266279
.clientCredentials(this::updateClientCredentialsProvider)
267280
.password(configurer -> configurer.clockSkew(this.accessTokenExpiresSkew))
268281
.build();
269-
((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
282+
if (this.authorizedClientManager instanceof UnAuthenticatedReactiveOAuth2AuthorizedClientManager) {
283+
((UnAuthenticatedReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
284+
} else {
285+
((DefaultReactiveOAuth2AuthorizedClientManager) this.authorizedClientManager).setAuthorizedClientProvider(authorizedClientProvider);
286+
}
270287
}
271288

272289
private void updateClientCredentialsProvider(ReactiveOAuth2AuthorizedClientProviderBuilder.ClientCredentialsGrantBuilder builder) {
@@ -376,4 +393,52 @@ private ClientRequest bearer(ClientRequest request, OAuth2AuthorizedClient autho
376393
.headers(headers -> headers.setBearerAuth(authorizedClient.getAccessToken().getTokenValue()))
377394
.build();
378395
}
396+
397+
private static class UnAuthenticatedReactiveOAuth2AuthorizedClientManager implements ReactiveOAuth2AuthorizedClientManager {
398+
private final ReactiveClientRegistrationRepository clientRegistrationRepository;
399+
private final UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository;
400+
private ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider;
401+
402+
private UnAuthenticatedReactiveOAuth2AuthorizedClientManager(
403+
ReactiveClientRegistrationRepository clientRegistrationRepository,
404+
UnAuthenticatedServerOAuth2AuthorizedClientRepository authorizedClientRepository) {
405+
this.clientRegistrationRepository = clientRegistrationRepository;
406+
this.authorizedClientRepository = authorizedClientRepository;
407+
}
408+
409+
@Override
410+
public Mono<OAuth2AuthorizedClient> authorize(OAuth2AuthorizeRequest authorizeRequest) {
411+
Assert.notNull(authorizeRequest, "authorizeRequest cannot be null");
412+
413+
String clientRegistrationId = authorizeRequest.getClientRegistrationId();
414+
Authentication principal = authorizeRequest.getPrincipal();
415+
416+
return Mono.justOrEmpty(authorizeRequest.getAuthorizedClient())
417+
.switchIfEmpty(Mono.defer(() -> this.authorizedClientRepository.loadAuthorizedClient(clientRegistrationId, principal, null)))
418+
.flatMap(authorizedClient -> {
419+
// Re-authorize
420+
return Mono.just(OAuth2AuthorizationContext.withAuthorizedClient(authorizedClient).principal(principal).build())
421+
.flatMap(this.authorizedClientProvider::authorize)
422+
.flatMap(reauthorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(reauthorizedClient, principal, null).thenReturn(reauthorizedClient))
423+
// Default to the existing authorizedClient if the client was not re-authorized
424+
.defaultIfEmpty(authorizeRequest.getAuthorizedClient() != null ?
425+
authorizeRequest.getAuthorizedClient() : authorizedClient);
426+
})
427+
.switchIfEmpty(Mono.deferWithContext(context ->
428+
// Authorize
429+
this.clientRegistrationRepository.findByRegistrationId(clientRegistrationId)
430+
.switchIfEmpty(Mono.error(() -> new IllegalArgumentException(
431+
"Could not find ClientRegistration with id '" + clientRegistrationId + "'")))
432+
.flatMap(clientRegistration -> Mono.just(OAuth2AuthorizationContext.withClientRegistration(clientRegistration).principal(principal).build()))
433+
.flatMap(this.authorizedClientProvider::authorize)
434+
.flatMap(authorizedClient -> this.authorizedClientRepository.saveAuthorizedClient(authorizedClient, principal, null).thenReturn(authorizedClient))
435+
.subscriberContext(context)
436+
));
437+
}
438+
439+
private void setAuthorizedClientProvider(ReactiveOAuth2AuthorizedClientProvider authorizedClientProvider) {
440+
Assert.notNull(authorizedClientProvider, "authorizedClientProvider cannot be null");
441+
this.authorizedClientProvider = authorizedClientProvider;
442+
}
443+
}
379444
}

oauth2/oauth2-client/src/test/java/org/springframework/security/oauth2/client/web/reactive/function/client/ServerOAuth2AuthorizedClientExchangeFilterFunctionTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
import org.springframework.security.oauth2.client.registration.TestClientRegistrations;
5858
import org.springframework.security.oauth2.client.web.DefaultReactiveOAuth2AuthorizedClientManager;
5959
import org.springframework.security.oauth2.client.web.server.ServerOAuth2AuthorizedClientRepository;
60+
import org.springframework.security.oauth2.client.web.server.UnAuthenticatedServerOAuth2AuthorizedClientRepository;
6061
import org.springframework.security.oauth2.core.OAuth2AccessToken;
6162
import org.springframework.security.oauth2.core.OAuth2RefreshToken;
6263
import org.springframework.security.oauth2.core.endpoint.OAuth2AccessTokenResponse;
@@ -587,6 +588,43 @@ public void filterWhenClientRegistrationIdAndServerWebExchangeFromContextThenSer
587588
verify(this.authorizedClientRepository).loadAuthorizedClient(eq(this.registration.getRegistrationId()), any(), eq(this.serverWebExchange));
588589
}
589590

591+
// gh-7544
592+
@Test
593+
public void filterWhenClientCredentialsClientNotAuthorizedAndOutsideRequestContextThenGetNewToken() {
594+
// Use UnAuthenticatedServerOAuth2AuthorizedClientRepository when operating outside of a request context
595+
ServerOAuth2AuthorizedClientRepository unauthenticatedAuthorizedClientRepository = spy(new UnAuthenticatedServerOAuth2AuthorizedClientRepository());
596+
this.function = new ServerOAuth2AuthorizedClientExchangeFilterFunction(
597+
this.clientRegistrationRepository, unauthenticatedAuthorizedClientRepository);
598+
this.function.setClientCredentialsTokenResponseClient(this.clientCredentialsTokenResponseClient);
599+
600+
OAuth2AccessTokenResponse accessTokenResponse = OAuth2AccessTokenResponse.withToken("new-token")
601+
.tokenType(OAuth2AccessToken.TokenType.BEARER)
602+
.expiresIn(360)
603+
.build();
604+
when(this.clientCredentialsTokenResponseClient.getTokenResponse(any())).thenReturn(Mono.just(accessTokenResponse));
605+
606+
ClientRegistration registration = TestClientRegistrations.clientCredentials().build();
607+
when(this.clientRegistrationRepository.findByRegistrationId(eq(registration.getRegistrationId()))).thenReturn(Mono.just(registration));
608+
609+
ClientRequest request = ClientRequest.create(GET, URI.create("https://example.com"))
610+
.attributes(clientRegistrationId(registration.getRegistrationId()))
611+
.build();
612+
613+
this.function.filter(request, this.exchange).block();
614+
615+
verify(unauthenticatedAuthorizedClientRepository).loadAuthorizedClient(any(), any(), any());
616+
verify(this.clientCredentialsTokenResponseClient).getTokenResponse(any());
617+
verify(unauthenticatedAuthorizedClientRepository).saveAuthorizedClient(any(), any(), any());
618+
619+
List<ClientRequest> requests = this.exchange.getRequests();
620+
assertThat(requests).hasSize(1);
621+
ClientRequest request1 = requests.get(0);
622+
assertThat(request1.headers().getFirst(HttpHeaders.AUTHORIZATION)).isEqualTo("Bearer new-token");
623+
assertThat(request1.url().toASCIIString()).isEqualTo("https://example.com");
624+
assertThat(request1.method()).isEqualTo(HttpMethod.GET);
625+
assertThat(getBody(request1)).isEmpty();
626+
}
627+
590628
private Context serverWebExchange() {
591629
return Context.of(ServerWebExchange.class, this.serverWebExchange);
592630
}

0 commit comments

Comments
 (0)