20
20
import java .lang .reflect .Method ;
21
21
import java .util .ArrayList ;
22
22
import java .util .Arrays ;
23
+ import java .util .Collection ;
24
+ import java .util .Collections ;
25
+ import java .util .LinkedHashSet ;
23
26
import java .util .List ;
27
+ import java .util .Set ;
24
28
25
29
import org .springframework .beans .factory .BeanFactory ;
26
30
import org .springframework .beans .factory .BeanFactoryUtils ;
31
+ import org .springframework .beans .factory .FactoryBean ;
32
+ import org .springframework .beans .factory .HierarchicalBeanFactory ;
33
+ import org .springframework .beans .factory .ListableBeanFactory ;
34
+ import org .springframework .beans .factory .config .BeanDefinition ;
27
35
import org .springframework .beans .factory .config .ConfigurableListableBeanFactory ;
28
36
import org .springframework .context .annotation .Bean ;
29
37
import org .springframework .context .annotation .Condition ;
30
38
import org .springframework .context .annotation .ConditionContext ;
31
39
import org .springframework .context .annotation .ConfigurationCondition ;
40
+ import org .springframework .core .ResolvableType ;
32
41
import org .springframework .core .type .AnnotatedTypeMetadata ;
33
42
import org .springframework .core .type .MethodMetadata ;
34
43
import org .springframework .util .Assert ;
43
52
*
44
53
* @author Phillip Webb
45
54
* @author Dave Syer
55
+ * @author Jakub Kubrynski
46
56
*/
47
57
class OnBeanCondition extends SpringBootCondition implements ConfigurationCondition {
48
58
@@ -100,8 +110,8 @@ private List<String> getMatchingBeans(ConditionContext context, BeanSearchSpec b
100
110
boolean considerHierarchy = beans .getStrategy () == SearchStrategy .ALL ;
101
111
102
112
for (String type : beans .getTypes ()) {
103
- beanNames .addAll (Arrays . asList ( getBeanNamesForType (beanFactory , type ,
104
- context .getClassLoader (), considerHierarchy ))) ;
113
+ beanNames .addAll (getBeanNamesForType (beanFactory , type ,
114
+ context .getClassLoader (), considerHierarchy ));
105
115
}
106
116
107
117
for (String annotation : beans .getAnnotations ()) {
@@ -126,25 +136,94 @@ private boolean containsBean(ConfigurableListableBeanFactory beanFactory,
126
136
return beanFactory .containsLocalBean (beanName );
127
137
}
128
138
129
- private String [] getBeanNamesForType (ConfigurableListableBeanFactory beanFactory ,
130
- String type , ClassLoader classLoader , boolean considerHierarchy )
131
- throws LinkageError {
132
- // eagerInit set to false to prevent early instantiation (some
133
- // factory beans will not be able to determine their object type at this
134
- // stage, so those are not eligible for matching this condition)
139
+ private Collection <String > getBeanNamesForType (
140
+ ConfigurableListableBeanFactory beanFactory , String type ,
141
+ ClassLoader classLoader , boolean considerHierarchy ) throws LinkageError {
135
142
try {
136
- Class <?> typeClass = ClassUtils .forName (type , classLoader );
137
- if (considerHierarchy ) {
138
- return BeanFactoryUtils .beanNamesForTypeIncludingAncestors (beanFactory ,
139
- typeClass , false , false );
140
- }
141
- return beanFactory .getBeanNamesForType (typeClass , false , false );
143
+ Set <String > result = new LinkedHashSet <String >();
144
+ collectBeanNamesForType (result , beanFactory ,
145
+ ClassUtils .forName (type , classLoader ), considerHierarchy );
146
+ return result ;
142
147
}
143
148
catch (ClassNotFoundException ex ) {
144
- return NO_BEANS ;
149
+ return Collections .emptySet ();
150
+ }
151
+ }
152
+
153
+ private void collectBeanNamesForType (Set <String > result ,
154
+ ListableBeanFactory beanFactory , Class <?> type , boolean considerHierarchy ) {
155
+ // eagerInit set to false to prevent early instantiation
156
+ result .addAll (Arrays .asList (beanFactory .getBeanNamesForType (type , true , false )));
157
+ if (beanFactory instanceof ConfigurableListableBeanFactory ) {
158
+ collectBeanNamesForTypeFromFactoryBeans (result ,
159
+ (ConfigurableListableBeanFactory ) beanFactory , type );
160
+ }
161
+ if (considerHierarchy && beanFactory instanceof HierarchicalBeanFactory ) {
162
+ BeanFactory parent = ((HierarchicalBeanFactory ) beanFactory )
163
+ .getParentBeanFactory ();
164
+ if (parent instanceof ListableBeanFactory ) {
165
+ collectBeanNamesForType (result , (ListableBeanFactory ) parent , type ,
166
+ considerHierarchy );
167
+ }
168
+ }
169
+ }
170
+
171
+ /**
172
+ * Attempt to collect bean names for type by considering FactoryBean generics. Some
173
+ * factory beans will not be able to determine their object type at this stage, so
174
+ * those are not eligible for matching this condition.
175
+ */
176
+ private void collectBeanNamesForTypeFromFactoryBeans (Set <String > result ,
177
+ ConfigurableListableBeanFactory beanFactory , Class <?> type ) {
178
+ String [] names = beanFactory .getBeanNamesForType (FactoryBean .class , true , false );
179
+ for (String name : names ) {
180
+ name = BeanFactoryUtils .transformedBeanName (name );
181
+ BeanDefinition beanDefinition = beanFactory .getBeanDefinition (name );
182
+ Class <?> generic = getFactoryBeanGeneric (beanFactory , beanDefinition );
183
+ if (generic != null && ClassUtils .isAssignable (type , generic )) {
184
+ result .add (name );
185
+ }
145
186
}
146
187
}
147
188
189
+ private Class <?> getFactoryBeanGeneric (ConfigurableListableBeanFactory beanFactory ,
190
+ BeanDefinition definition ) {
191
+ try {
192
+ if (StringUtils .hasLength (definition .getFactoryBeanName ())
193
+ && StringUtils .hasLength (definition .getFactoryMethodName ())) {
194
+ return getConfigurationClassFactoryBeanGeneric (beanFactory , definition );
195
+ }
196
+ if (StringUtils .hasLength (definition .getBeanClassName ())) {
197
+ return getDirectFactoryBeanGeneric (beanFactory , definition );
198
+ }
199
+ }
200
+ catch (Exception ex ) {
201
+ }
202
+ return null ;
203
+ }
204
+
205
+ private Class <?> getConfigurationClassFactoryBeanGeneric (
206
+ ConfigurableListableBeanFactory beanFactory , BeanDefinition definition )
207
+ throws Exception {
208
+ BeanDefinition factoryDefinition = beanFactory .getBeanDefinition (definition
209
+ .getFactoryBeanName ());
210
+ Class <?> factoryClass = ClassUtils .forName (factoryDefinition .getBeanClassName (),
211
+ beanFactory .getBeanClassLoader ());
212
+ Method method = ReflectionUtils .findMethod (factoryClass ,
213
+ definition .getFactoryMethodName ());
214
+ return ResolvableType .forMethodReturnType (method ).as (FactoryBean .class )
215
+ .resolveGeneric ();
216
+ }
217
+
218
+ private Class <?> getDirectFactoryBeanGeneric (
219
+ ConfigurableListableBeanFactory beanFactory , BeanDefinition definition )
220
+ throws ClassNotFoundException , LinkageError {
221
+ Class <?> factoryBeanClass = ClassUtils .forName (definition .getBeanClassName (),
222
+ beanFactory .getBeanClassLoader ());
223
+ return ResolvableType .forClass (factoryBeanClass ).as (FactoryBean .class )
224
+ .resolveGeneric ();
225
+ }
226
+
148
227
private String [] getBeanNamesForAnnotation (
149
228
ConfigurableListableBeanFactory beanFactory , String type ,
150
229
ClassLoader classLoader , boolean considerHierarchy ) throws LinkageError {
0 commit comments