Skip to content

Commit a36baff

Browse files
committed
Polish OpenSamlAuthenticationRequestFactory
- Refactored to use SAMLMetadataSignatureSigningParametersResolver Issue gh-7758
1 parent 2ee455b commit a36baff

File tree

2 files changed

+117
-51
lines changed

2 files changed

+117
-51
lines changed

saml2/saml2-service-provider/src/main/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactory.java

Lines changed: 79 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
import java.security.cert.X509Certificate;
2222
import java.time.Clock;
2323
import java.time.Instant;
24-
import java.util.Collection;
24+
import java.util.ArrayList;
25+
import java.util.Collections;
2526
import java.util.LinkedHashMap;
27+
import java.util.List;
2628
import java.util.Map;
2729
import java.util.UUID;
2830

31+
import net.shibboleth.utilities.java.support.resolver.CriteriaSet;
2932
import net.shibboleth.utilities.java.support.xml.SerializeSupport;
3033
import org.joda.time.DateTime;
3134
import org.opensaml.core.config.ConfigurationService;
@@ -37,15 +40,18 @@
3740
import org.opensaml.saml.saml2.core.impl.AuthnRequestBuilder;
3841
import org.opensaml.saml.saml2.core.impl.AuthnRequestMarshaller;
3942
import org.opensaml.saml.saml2.core.impl.IssuerBuilder;
43+
import org.opensaml.saml.security.impl.SAMLMetadataSignatureSigningParametersResolver;
4044
import org.opensaml.security.SecurityException;
4145
import org.opensaml.security.credential.BasicCredential;
4246
import org.opensaml.security.credential.Credential;
4347
import org.opensaml.security.credential.CredentialSupport;
4448
import org.opensaml.security.credential.UsageType;
4549
import org.opensaml.xmlsec.SignatureSigningParameters;
50+
import org.opensaml.xmlsec.SignatureSigningParametersResolver;
51+
import org.opensaml.xmlsec.criterion.SignatureSigningConfigurationCriterion;
4652
import org.opensaml.xmlsec.crypto.XMLSigningUtil;
53+
import org.opensaml.xmlsec.impl.BasicSignatureSigningConfiguration;
4754
import org.opensaml.xmlsec.signature.support.SignatureConstants;
48-
import org.opensaml.xmlsec.signature.support.SignatureException;
4955
import org.opensaml.xmlsec.signature.support.SignatureSupport;
5056
import org.w3c.dom.Element;
5157

@@ -58,6 +64,7 @@
5864
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
5965
import org.springframework.util.Assert;
6066
import org.springframework.util.StringUtils;
67+
import org.springframework.web.util.UriComponentsBuilder;
6168
import org.springframework.web.util.UriUtils;
6269

6370
/**
@@ -105,9 +112,17 @@ public String createAuthenticationRequest(Saml2AuthenticationRequest request) {
105112
request.getAssertionConsumerServiceUrl(), this.protocolBindingResolver.convert(null));
106113
for (org.springframework.security.saml2.credentials.Saml2X509Credential credential : request.getCredentials()) {
107114
if (credential.isSigningCredential()) {
108-
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
109-
request.getIssuer());
110-
return serialize(sign(authnRequest, cred));
115+
X509Certificate certificate = credential.getCertificate();
116+
PrivateKey privateKey = credential.getPrivateKey();
117+
BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
118+
cred.setEntityId(request.getIssuer());
119+
cred.setUsageType(UsageType.SIGNING);
120+
SignatureSigningParameters parameters = new SignatureSigningParameters();
121+
parameters.setSigningCredential(cred);
122+
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
123+
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
124+
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
125+
return serialize(sign(authnRequest, parameters));
111126
}
112127
}
113128
throw new IllegalArgumentException("No signing credential provided");
@@ -132,16 +147,13 @@ public Saml2RedirectAuthenticationRequest createRedirectAuthenticationRequest(
132147
String deflatedAndEncoded = Saml2Utils.samlEncode(Saml2Utils.samlDeflate(xml));
133148
result.samlRequest(deflatedAndEncoded).relayState(context.getRelayState());
134149
if (context.getRelyingPartyRegistration().getAssertingPartyDetails().getWantAuthnRequestsSigned()) {
135-
Collection<Saml2X509Credential> signingCredentials = context.getRelyingPartyRegistration()
136-
.getSigningX509Credentials();
137-
for (Saml2X509Credential credential : signingCredentials) {
138-
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(), "");
139-
Map<String, String> signedParams = signQueryParameters(cred, deflatedAndEncoded,
140-
context.getRelayState());
141-
return result.samlRequest(signedParams.get("SAMLRequest")).relayState(signedParams.get("RelayState"))
142-
.sigAlg(signedParams.get("SigAlg")).signature(signedParams.get("Signature")).build();
150+
Map<String, String> parameters = new LinkedHashMap<>();
151+
parameters.put("SAMLRequest", deflatedAndEncoded);
152+
if (StringUtils.hasText(context.getRelayState())) {
153+
parameters.put("RelayState", context.getRelayState());
143154
}
144-
throw new Saml2Exception("No signing credential provided");
155+
sign(parameters, context.getRelyingPartyRegistration());
156+
return result.sigAlg(parameters.get("SigAlg")).signature(parameters.get("Signature")).build();
145157
}
146158
return result.build();
147159
}
@@ -211,59 +223,39 @@ public void setProtocolBinding(String protocolBinding) {
211223
}
212224

213225
private AuthnRequest sign(AuthnRequest authnRequest, RelyingPartyRegistration relyingPartyRegistration) {
214-
for (Saml2X509Credential credential : relyingPartyRegistration.getSigningX509Credentials()) {
215-
Credential cred = getSigningCredential(credential.getCertificate(), credential.getPrivateKey(),
216-
relyingPartyRegistration.getEntityId());
217-
return sign(authnRequest, cred);
218-
}
219-
throw new IllegalArgumentException("No signing credential provided");
226+
SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
227+
return sign(authnRequest, parameters);
220228
}
221229

222-
private AuthnRequest sign(AuthnRequest authnRequest, Credential credential) {
223-
SignatureSigningParameters parameters = new SignatureSigningParameters();
224-
parameters.setSigningCredential(credential);
225-
parameters.setSignatureAlgorithm(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
226-
parameters.setSignatureReferenceDigestMethod(SignatureConstants.ALGO_ID_DIGEST_SHA256);
227-
parameters.setSignatureCanonicalizationAlgorithm(SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS);
230+
private AuthnRequest sign(AuthnRequest authnRequest, SignatureSigningParameters parameters) {
228231
try {
229232
SignatureSupport.signObject(authnRequest, parameters);
230233
return authnRequest;
231234
}
232-
catch (MarshallingException | SignatureException | SecurityException ex) {
235+
catch (Exception ex) {
233236
throw new Saml2Exception(ex);
234237
}
235238
}
236239

237-
private Credential getSigningCredential(X509Certificate certificate, PrivateKey privateKey, String entityId) {
238-
BasicCredential cred = CredentialSupport.getSimpleCredential(certificate, privateKey);
239-
cred.setEntityId(entityId);
240-
cred.setUsageType(UsageType.SIGNING);
241-
return cred;
240+
private void sign(Map<String, String> components, RelyingPartyRegistration relyingPartyRegistration) {
241+
SignatureSigningParameters parameters = resolveSigningParameters(relyingPartyRegistration);
242+
sign(components, parameters);
242243
}
243244

244-
private Map<String, String> signQueryParameters(Credential credential, String samlRequest, String relayState) {
245-
Assert.notNull(samlRequest, "samlRequest cannot be null");
246-
String algorithmUri = SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256;
247-
StringBuilder queryString = new StringBuilder();
248-
queryString.append("SAMLRequest").append("=").append(UriUtils.encode(samlRequest, StandardCharsets.ISO_8859_1))
249-
.append("&");
250-
if (StringUtils.hasText(relayState)) {
251-
queryString.append("RelayState").append("=")
252-
.append(UriUtils.encode(relayState, StandardCharsets.ISO_8859_1)).append("&");
245+
private void sign(Map<String, String> components, SignatureSigningParameters parameters) {
246+
Credential credential = parameters.getSigningCredential();
247+
String algorithmUri = parameters.getSignatureAlgorithm();
248+
components.put("SigAlg", algorithmUri);
249+
UriComponentsBuilder builder = UriComponentsBuilder.newInstance();
250+
for (Map.Entry<String, String> component : components.entrySet()) {
251+
builder.queryParam(component.getKey(), UriUtils.encode(component.getValue(), StandardCharsets.ISO_8859_1));
253252
}
254-
queryString.append("SigAlg").append("=").append(UriUtils.encode(algorithmUri, StandardCharsets.ISO_8859_1));
253+
String queryString = builder.build(true).toString().substring(1);
255254
try {
256255
byte[] rawSignature = XMLSigningUtil.signWithURI(credential, algorithmUri,
257-
queryString.toString().getBytes(StandardCharsets.UTF_8));
256+
queryString.getBytes(StandardCharsets.UTF_8));
258257
String b64Signature = Saml2Utils.samlEncode(rawSignature);
259-
Map<String, String> result = new LinkedHashMap<>();
260-
result.put("SAMLRequest", samlRequest);
261-
if (StringUtils.hasText(relayState)) {
262-
result.put("RelayState", relayState);
263-
}
264-
result.put("SigAlg", algorithmUri);
265-
result.put("Signature", b64Signature);
266-
return result;
258+
components.put("Signature", b64Signature);
267259
}
268260
catch (SecurityException ex) {
269261
throw new Saml2Exception(ex);
@@ -280,4 +272,40 @@ private String serialize(AuthnRequest authnRequest) {
280272
}
281273
}
282274

275+
private SignatureSigningParameters resolveSigningParameters(RelyingPartyRegistration relyingPartyRegistration) {
276+
List<Credential> credentials = resolveSigningCredentials(relyingPartyRegistration);
277+
List<String> algorithms = Collections.singletonList(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
278+
List<String> digests = Collections.singletonList(SignatureConstants.ALGO_ID_DIGEST_SHA256);
279+
String canonicalization = SignatureConstants.ALGO_ID_C14N_EXCL_OMIT_COMMENTS;
280+
SignatureSigningParametersResolver resolver = new SAMLMetadataSignatureSigningParametersResolver();
281+
CriteriaSet criteria = new CriteriaSet();
282+
BasicSignatureSigningConfiguration signingConfiguration = new BasicSignatureSigningConfiguration();
283+
signingConfiguration.setSigningCredentials(credentials);
284+
signingConfiguration.setSignatureAlgorithms(algorithms);
285+
signingConfiguration.setSignatureReferenceDigestMethods(digests);
286+
signingConfiguration.setSignatureCanonicalizationAlgorithm(canonicalization);
287+
criteria.add(new SignatureSigningConfigurationCriterion(signingConfiguration));
288+
try {
289+
SignatureSigningParameters parameters = resolver.resolveSingle(criteria);
290+
Assert.notNull(parameters, "Failed to resolve any signing credential");
291+
return parameters;
292+
}
293+
catch (Exception ex) {
294+
throw new Saml2Exception(ex);
295+
}
296+
}
297+
298+
private List<Credential> resolveSigningCredentials(RelyingPartyRegistration relyingPartyRegistration) {
299+
List<Credential> credentials = new ArrayList<>();
300+
for (Saml2X509Credential x509Credential : relyingPartyRegistration.getSigningX509Credentials()) {
301+
X509Certificate certificate = x509Credential.getCertificate();
302+
PrivateKey privateKey = x509Credential.getPrivateKey();
303+
BasicCredential credential = CredentialSupport.getSimpleCredential(certificate, privateKey);
304+
credential.setEntityId(relyingPartyRegistration.getEntityId());
305+
credential.setUsageType(UsageType.SIGNING);
306+
credentials.add(credential);
307+
}
308+
return credentials;
309+
}
310+
283311
}

saml2/saml2-service-provider/src/test/java/org/springframework/security/saml2/provider/service/authentication/OpenSamlAuthenticationRequestFactoryTests.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,20 @@
2626
import org.opensaml.saml.common.xml.SAMLConstants;
2727
import org.opensaml.saml.saml2.core.AuthnRequest;
2828
import org.opensaml.saml.saml2.core.impl.AuthnRequestUnmarshaller;
29+
import org.opensaml.xmlsec.signature.support.SignatureConstants;
2930
import org.w3c.dom.Document;
3031
import org.w3c.dom.Element;
3132

3233
import org.springframework.core.convert.converter.Converter;
3334
import org.springframework.security.saml2.Saml2Exception;
35+
import org.springframework.security.saml2.core.Saml2X509Credential;
3436
import org.springframework.security.saml2.credentials.TestSaml2X509Credentials;
3537
import org.springframework.security.saml2.provider.service.registration.RelyingPartyRegistration;
3638
import org.springframework.security.saml2.provider.service.registration.Saml2MessageBinding;
39+
import org.springframework.security.saml2.provider.service.registration.TestRelyingPartyRegistrations;
3740

3841
import static org.assertj.core.api.Assertions.assertThat;
42+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
3943
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
4044
import static org.mockito.BDDMockito.given;
4145
import static org.mockito.Mockito.mock;
@@ -110,6 +114,28 @@ public void createRedirectAuthenticationRequestWhenNotSignRequestThenNoSignature
110114
assertThat(result.getBinding()).isEqualTo(Saml2MessageBinding.REDIRECT);
111115
}
112116

117+
@Test
118+
public void createRedirectAuthenticationRequestWhenSignRequestThenSignatureIsPresent() {
119+
this.context = this.contextBuilder.relayState("Relay State Value")
120+
.relyingPartyRegistration(this.relyingPartyRegistration).build();
121+
Saml2RedirectAuthenticationRequest request = this.factory.createRedirectAuthenticationRequest(this.context);
122+
assertThat(request.getRelayState()).isEqualTo("Relay State Value");
123+
assertThat(request.getSigAlg()).isEqualTo(SignatureConstants.ALGO_ID_SIGNATURE_RSA_SHA256);
124+
assertThat(request.getSignature()).isNotNull();
125+
}
126+
127+
@Test
128+
public void createRedirectAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
129+
Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
130+
.relyingPartyVerifyingCredential();
131+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
132+
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
133+
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
134+
.build();
135+
assertThatExceptionOfType(Saml2Exception.class)
136+
.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
137+
}
138+
113139
@Test
114140
public void createPostAuthenticationRequestWhenNotSignRequestThenNoSignatureIsPresent() {
115141
this.context = this.contextBuilder.relayState("Relay State Value")
@@ -139,6 +165,18 @@ public void createPostAuthenticationRequestWhenSignRequestThenSignatureIsPresent
139165
.contains("ds:Signature");
140166
}
141167

168+
@Test
169+
public void createPostAuthenticationRequestWhenSignRequestThenCredentialIsRequired() {
170+
Saml2X509Credential credential = org.springframework.security.saml2.core.TestSaml2X509Credentials
171+
.relyingPartyVerifyingCredential();
172+
RelyingPartyRegistration registration = TestRelyingPartyRegistrations.noCredentials()
173+
.assertingPartyDetails((party) -> party.verificationX509Credentials((c) -> c.add(credential))).build();
174+
this.context = this.contextBuilder.relayState("Relay State Value").relyingPartyRegistration(registration)
175+
.build();
176+
assertThatExceptionOfType(Saml2Exception.class)
177+
.isThrownBy(() -> this.factory.createPostAuthenticationRequest(this.context));
178+
}
179+
142180
@Test
143181
public void createAuthenticationRequestWhenDefaultThenReturnsPostBinding() {
144182
AuthnRequest authn = getAuthNRequest(Saml2MessageBinding.POST);

0 commit comments

Comments
 (0)