1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19 package org.apache.shiro.guice.web;
20
21 import com.google.inject.Guice;
22 import com.google.inject.Inject;
23 import com.google.inject.Injector;
24 import com.google.inject.Key;
25 import com.google.inject.Provides;
26 import com.google.inject.binder.AnnotatedBindingBuilder;
27 import com.google.inject.name.Names;
28 import org.apache.shiro.guice.ShiroModuleTest;
29 import org.apache.shiro.env.Environment;
30 import org.apache.shiro.mgt.SecurityManager;
31 import org.apache.shiro.realm.Realm;
32 import org.apache.shiro.session.mgt.SessionManager;
33 import org.apache.shiro.web.env.EnvironmentLoader;
34 import org.apache.shiro.web.env.WebEnvironment;
35 import org.apache.shiro.web.filter.InvalidRequestFilter;
36 import org.apache.shiro.web.filter.authc.BasicHttpAuthenticationFilter;
37 import org.apache.shiro.web.filter.authc.FormAuthenticationFilter;
38 import org.apache.shiro.web.filter.authz.PermissionsAuthorizationFilter;
39 import org.apache.shiro.web.filter.authz.RolesAuthorizationFilter;
40 import org.apache.shiro.web.filter.mgt.FilterChainResolver;
41 import org.apache.shiro.web.mgt.DefaultWebSecurityManager;
42 import org.apache.shiro.web.mgt.WebSecurityManager;
43 import org.apache.shiro.web.session.mgt.DefaultWebSessionManager;
44 import org.apache.shiro.web.session.mgt.ServletContainerSessionManager;
45 import org.easymock.EasyMock;
46 import org.junit.Assume;
47 import org.junit.Test;
48
49 import javax.inject.Named;
50 import javax.servlet.Filter;
51 import javax.servlet.FilterChain;
52 import javax.servlet.FilterConfig;
53 import javax.servlet.ServletContext;
54 import javax.servlet.ServletException;
55 import javax.servlet.ServletRequest;
56 import javax.servlet.ServletResponse;
57 import javax.servlet.http.HttpServletRequest;
58 import java.io.IOException;
59 import java.util.Collection;
60 import java.util.Collections;
61 import java.util.Iterator;
62 import java.util.List;
63
64 import static org.easymock.EasyMock.*;
65 import static org.junit.Assert.*;
66 import static org.hamcrest.Matchers.*;
67
68
69 public class ShiroWebModuleTest {
70
71
72 @Test
73 public void basicInstantiation() {
74 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
75 ServletContext servletContext = createMock(ServletContext.class);
76
77 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
78 @Override
79 protected void configureShiroWeb() {
80 bindRealm().to(ShiroModuleTest.MockRealm.class);
81 expose(SessionManager.class);
82 }
83
84 @Provides
85 public ShiroModuleTest.MockRealm createRealm() {
86 return mockRealm;
87 }
88
89 });
90
91
92 SecurityManager securityManager = injector.getInstance(SecurityManager.class);
93 assertNotNull(securityManager);
94 assertTrue(securityManager instanceof WebSecurityManager);
95 SessionManager sessionManager = injector.getInstance(SessionManager.class);
96 assertNotNull(sessionManager);
97 assertTrue(sessionManager instanceof ServletContainerSessionManager);
98 assertTrue(((DefaultWebSecurityManager)securityManager).getSessionManager() instanceof ServletContainerSessionManager);
99 }
100
101 @Test
102 public void testBindGuiceFilter() throws Exception {
103
104 }
105
106 @Test
107 public void testBindWebSecurityManager() throws Exception {
108 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
109 ServletContext servletContext = createMock(ServletContext.class);
110
111 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
112 @Override
113 protected void configureShiroWeb() {
114 bindRealm().to(ShiroModuleTest.MockRealm.class);
115 expose(WebSecurityManager.class);
116 }
117
118 @Provides
119 public ShiroModuleTest.MockRealm createRealm() {
120 return mockRealm;
121 }
122
123 @Override
124 protected void bindWebSecurityManager(AnnotatedBindingBuilder<? super WebSecurityManager> bind) {
125 bind.to(MyDefaultWebSecurityManager.class).asEagerSingleton();
126 }
127 });
128 SecurityManager securityManager = injector.getInstance(SecurityManager.class);
129 assertNotNull(securityManager);
130 assertTrue(securityManager instanceof MyDefaultWebSecurityManager);
131 WebSecurityManager webSecurityManager = injector.getInstance(WebSecurityManager.class);
132 assertNotNull(webSecurityManager);
133 assertTrue(webSecurityManager instanceof MyDefaultWebSecurityManager);
134
135 assertTrue( securityManager == webSecurityManager );
136 }
137
138 @Test
139 public void testBindWebEnvironment() throws Exception {
140 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
141 ServletContext servletContext = createMock(ServletContext.class);
142
143 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
144 @Override
145 protected void configureShiroWeb() {
146 bindRealm().to(ShiroModuleTest.MockRealm.class);
147 expose(WebEnvironment.class);
148 expose(Environment.class);
149 }
150
151 @Provides
152 public ShiroModuleTest.MockRealm createRealm() {
153 return mockRealm;
154 }
155
156 @Override
157 protected void bindWebEnvironment(AnnotatedBindingBuilder<? super WebEnvironment> bind) {
158 bind.to(MyWebEnvironment.class).asEagerSingleton();
159 }
160 });
161 Environment environment = injector.getInstance(Environment.class);
162 assertNotNull(environment);
163 assertTrue(environment instanceof MyWebEnvironment);
164 WebEnvironment webEnvironment = injector.getInstance(WebEnvironment.class);
165 assertNotNull(webEnvironment);
166 assertTrue(webEnvironment instanceof MyWebEnvironment);
167
168 assertTrue( environment == webEnvironment );
169 }
170
171
172
173
174 @Test
175 public void testAddFilterChainGuice3and4() {
176
177 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
178 ServletContext servletContext = createMock(ServletContext.class);
179 HttpServletRequest request = createMock(HttpServletRequest.class);
180
181 servletContext.setAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY), EasyMock.anyObject());
182 expect(request.getAttribute("javax.servlet.include.context_path")).andReturn("").anyTimes();
183 expect(request.getCharacterEncoding()).andReturn("UTF-8").anyTimes();
184 expect(request.getAttribute("javax.servlet.include.path_info")).andReturn(null).anyTimes();
185 expect(request.getPathInfo()).andReturn(null).anyTimes();
186 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test_authc");
187 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test_custom_filter");
188 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test_authc_basic");
189 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test_perms");
190 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/multiple_configs");
191 replay(servletContext, request);
192
193 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
194 @Override
195 protected void configureShiroWeb() {
196 bindRealm().to(ShiroModuleTest.MockRealm.class);
197 expose(FilterChainResolver.class);
198 this.addFilterChain("/test_authc/**", filterConfig(AUTHC));
199 this.addFilterChain("/test_custom_filter/**", Key.get(CustomFilter.class));
200 this.addFilterChain("/test_authc_basic/**", AUTHC_BASIC);
201 this.addFilterChain("/test_perms/**", filterConfig(PERMS, "remote:invoke:lan,wan"));
202 this.addFilterChain("/multiple_configs/**", filterConfig(AUTHC), filterConfig(ROLES, "b2bClient"), filterConfig(PERMS, "remote:invoke:lan,wan"));
203 }
204
205 @Provides
206 public ShiroModuleTest.MockRealm createRealm() {
207 return mockRealm;
208 }
209 });
210
211 FilterChainResolver resolver = injector.getInstance(FilterChainResolver.class);
212 assertThat(resolver, instanceOf(SimpleFilterChainResolver.class));
213 SimpleFilterChainResolver simpleFilterChainResolver = (SimpleFilterChainResolver) resolver;
214
215
216 FilterChain filterChain = simpleFilterChainResolver.getChain(request, null, null);
217 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
218 Filter nextFilter = getNextFilter((SimpleFilterChain) filterChain);
219 assertThat(nextFilter, instanceOf(InvalidRequestFilter.class));
220 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
221 assertThat(nextFilter, instanceOf(FormAuthenticationFilter.class));
222
223
224 filterChain = simpleFilterChainResolver.getChain(request, null, null);
225 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
226 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
227 assertThat(nextFilter, instanceOf(InvalidRequestFilter.class));
228 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
229 assertThat(nextFilter, instanceOf(CustomFilter.class));
230
231
232 filterChain = simpleFilterChainResolver.getChain(request, null, null);
233 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
234 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
235 assertThat(nextFilter, instanceOf(InvalidRequestFilter.class));
236 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
237 assertThat(nextFilter, instanceOf(BasicHttpAuthenticationFilter.class));
238
239
240 filterChain = simpleFilterChainResolver.getChain(request, null, null);
241 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
242 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
243 assertThat(nextFilter, instanceOf(InvalidRequestFilter.class));
244 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
245 assertThat(nextFilter, instanceOf(PermissionsAuthorizationFilter.class));
246
247
248 filterChain = simpleFilterChainResolver.getChain(request, null, null);
249 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
250 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(InvalidRequestFilter.class));
251 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(FormAuthenticationFilter.class));
252 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(RolesAuthorizationFilter.class));
253 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(PermissionsAuthorizationFilter.class));
254
255 verify(servletContext, request);
256 }
257
258
259
260
261 @Test
262 public void testAddFilterChainGuice3Only() {
263
264 Assume.assumeTrue("This test only runs agains Guice 3.x", ShiroWebModule.isGuiceVersion3());
265
266 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
267 ServletContext servletContext = createMock(ServletContext.class);
268 HttpServletRequest request = createMock(HttpServletRequest.class);
269
270 servletContext.setAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY), EasyMock.anyObject());
271 expect(request.getAttribute("javax.servlet.include.context_path")).andReturn("").anyTimes();
272 expect(request.getCharacterEncoding()).andReturn("UTF-8").anyTimes();
273 expect(request.getAttribute("javax.servlet.include.request_uri")).andReturn("/test_authc");
274 expect(request.getAttribute("javax.servlet.include.request_uri")).andReturn("/test_custom_filter");
275 expect(request.getAttribute("javax.servlet.include.request_uri")).andReturn("/test_perms");
276 expect(request.getAttribute("javax.servlet.include.request_uri")).andReturn("/multiple_configs");
277 replay(servletContext, request);
278
279 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
280 @Override
281 protected void configureShiroWeb() {
282 bindRealm().to(ShiroModuleTest.MockRealm.class);
283 expose(FilterChainResolver.class);
284 this.addFilterChain("/test_authc/**", AUTHC);
285 this.addFilterChain("/test_custom_filter/**", Key.get(CustomFilter.class));
286 this.addFilterChain("/test_perms/**", config(PERMS, "remote:invoke:lan,wan"));
287 this.addFilterChain("/multiple_configs/**", AUTHC, config(ROLES, "b2bClient"), config(PERMS, "remote:invoke:lan,wan"));
288 }
289
290 @Provides
291 public ShiroModuleTest.MockRealm createRealm() {
292 return mockRealm;
293 }
294 });
295
296 FilterChainResolver resolver = injector.getInstance(FilterChainResolver.class);
297 assertThat(resolver, instanceOf(SimpleFilterChainResolver.class));
298 SimpleFilterChainResolver simpleFilterChainResolver = (SimpleFilterChainResolver) resolver;
299
300
301 FilterChain filterChain = simpleFilterChainResolver.getChain(request, null, null);
302 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
303 Filter nextFilter = getNextFilter((SimpleFilterChain) filterChain);
304 assertThat(nextFilter, instanceOf(FormAuthenticationFilter.class));
305
306
307 filterChain = simpleFilterChainResolver.getChain(request, null, null);
308 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
309 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
310 assertThat(nextFilter, instanceOf(CustomFilter.class));
311
312
313 filterChain = simpleFilterChainResolver.getChain(request, null, null);
314 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
315 nextFilter = getNextFilter((SimpleFilterChain) filterChain);
316 assertThat(nextFilter, instanceOf(PermissionsAuthorizationFilter.class));
317
318
319 filterChain = simpleFilterChainResolver.getChain(request, null, null);
320 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
321 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(FormAuthenticationFilter.class));
322 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(RolesAuthorizationFilter.class));
323 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(PermissionsAuthorizationFilter.class));
324
325 verify(servletContext, request);
326 }
327
328 @Test
329 public void testDefaultPath() {
330
331 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
332 ServletContext servletContext = createMock(ServletContext.class);
333 HttpServletRequest request = createMock(HttpServletRequest.class);
334
335 servletContext.setAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY), EasyMock.anyObject());
336 expect(request.getAttribute("javax.servlet.include.context_path")).andReturn("").anyTimes();
337 expect(request.getCharacterEncoding()).andReturn("UTF-8").anyTimes();
338 expect(request.getAttribute("javax.servlet.include.path_info")).andReturn(null).anyTimes();
339 expect(request.getPathInfo()).andReturn(null).anyTimes();
340 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test/foobar");
341 replay(servletContext, request);
342
343 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
344 @Override
345 protected void configureShiroWeb() {
346 bindRealm().to(ShiroModuleTest.MockRealm.class);
347 expose(FilterChainResolver.class);
348
349 }
350
351 @Provides
352 public ShiroModuleTest.MockRealm createRealm() {
353 return mockRealm;
354 }
355 });
356
357 FilterChainResolver resolver = injector.getInstance(FilterChainResolver.class);
358 assertThat(resolver, instanceOf(SimpleFilterChainResolver.class));
359 SimpleFilterChainResolver simpleFilterChainResolver = (SimpleFilterChainResolver) resolver;
360
361
362 FilterChain filterChain = simpleFilterChainResolver.getChain(request, null, null);
363 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
364
365 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(InvalidRequestFilter.class));
366 assertThat(getNextFilter((SimpleFilterChain) filterChain), nullValue());
367
368 verify(servletContext, request);
369 }
370
371 @Test
372 public void testDisableGlobalFilters() {
373
374 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
375 ServletContext servletContext = createMock(ServletContext.class);
376 HttpServletRequest request = createMock(HttpServletRequest.class);
377
378 servletContext.setAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY), EasyMock.anyObject());
379 expect(request.getAttribute("javax.servlet.include.context_path")).andReturn("").anyTimes();
380 expect(request.getCharacterEncoding()).andReturn("UTF-8").anyTimes();
381 expect(request.getAttribute("javax.servlet.include.path_info")).andReturn(null).anyTimes();
382 expect(request.getPathInfo()).andReturn(null).anyTimes();
383 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test/foobar");
384 replay(servletContext, request);
385
386 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
387 @Override
388 protected void configureShiroWeb() {
389 bindRealm().to(ShiroModuleTest.MockRealm.class);
390 expose(FilterChainResolver.class);
391 this.addFilterChain("/**", filterConfig(AUTHC));
392 }
393
394 @Override
395 public List<FilterConfig<? extends Filter>> globalFilters() {
396 return Collections.emptyList();
397 }
398
399 @Provides
400 public ShiroModuleTest.MockRealm createRealm() {
401 return mockRealm;
402 }
403 });
404
405 FilterChainResolver resolver = injector.getInstance(FilterChainResolver.class);
406 assertThat(resolver, instanceOf(SimpleFilterChainResolver.class));
407 SimpleFilterChainResolver simpleFilterChainResolver = (SimpleFilterChainResolver) resolver;
408
409
410 FilterChain filterChain = simpleFilterChainResolver.getChain(request, null, null);
411 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
412
413 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(FormAuthenticationFilter.class));
414 assertThat(getNextFilter((SimpleFilterChain) filterChain), nullValue());
415
416 verify(servletContext, request);
417 }
418
419 @Test
420 public void testChangeInvalidFilterConfig() {
421
422 final ShiroModuleTest.MockRealm mockRealm = createMock(ShiroModuleTest.MockRealm.class);
423 ServletContext servletContext = createMock(ServletContext.class);
424 HttpServletRequest request = createMock(HttpServletRequest.class);
425
426 servletContext.setAttribute(eq(EnvironmentLoader.ENVIRONMENT_ATTRIBUTE_KEY), EasyMock.anyObject());
427 expect(request.getAttribute("javax.servlet.include.context_path")).andReturn("").anyTimes();
428 expect(request.getCharacterEncoding()).andReturn("UTF-8").anyTimes();
429 expect(request.getAttribute("javax.servlet.include.path_info")).andReturn(null).anyTimes();
430 expect(request.getPathInfo()).andReturn(null).anyTimes();
431 expect(request.getAttribute("javax.servlet.include.servlet_path")).andReturn("/test/foobar");
432 replay(servletContext, request);
433
434 Injector injector = Guice.createInjector(new ShiroWebModule(servletContext) {
435 @Override
436 protected void configureShiroWeb() {
437
438 bindConstant().annotatedWith(Names.named("shiro.blockBackslash")).to(false);
439
440 bindRealm().to(ShiroModuleTest.MockRealm.class);
441 expose(FilterChainResolver.class);
442 this.addFilterChain("/**", filterConfig(AUTHC));
443 }
444
445 @Provides
446 public ShiroModuleTest.MockRealm createRealm() {
447 return mockRealm;
448 }
449 });
450
451 FilterChainResolver resolver = injector.getInstance(FilterChainResolver.class);
452 assertThat(resolver, instanceOf(SimpleFilterChainResolver.class));
453 SimpleFilterChainResolver simpleFilterChainResolver = (SimpleFilterChainResolver) resolver;
454
455
456 FilterChain filterChain = simpleFilterChainResolver.getChain(request, null, null);
457 assertThat(filterChain, instanceOf(SimpleFilterChain.class));
458
459 Filter invalidRequestFilter = getNextFilter((SimpleFilterChain) filterChain);
460 assertThat(invalidRequestFilter, instanceOf(InvalidRequestFilter.class));
461 assertFalse("Expected 'blockBackslash' to be false", ((InvalidRequestFilter) invalidRequestFilter).isBlockBackslash());
462 assertThat(getNextFilter((SimpleFilterChain) filterChain), instanceOf(FormAuthenticationFilter.class));
463 assertThat(getNextFilter((SimpleFilterChain) filterChain), nullValue());
464
465 verify(servletContext, request);
466 }
467
468 private Filter getNextFilter(SimpleFilterChain filterChain) {
469
470 Iterator<? extends Filter> filters = filterChain.getFilters();
471 if (filters.hasNext()) {
472 return filters.next();
473 }
474
475 return null;
476 }
477
478 public static class MyDefaultWebSecurityManager extends DefaultWebSecurityManager {
479 @Inject
480 public MyDefaultWebSecurityManager(Collection<Realm> realms) {
481 super(realms);
482 }
483 }
484
485 public static class MyDefaultWebSessionManager extends DefaultWebSessionManager {
486 }
487
488 public static class MyWebEnvironment extends WebGuiceEnvironment {
489 @Inject
490 MyWebEnvironment(FilterChainResolver filterChainResolver, @Named(ShiroWebModule.NAME) ServletContext servletContext, WebSecurityManager securityManager) {
491 super(filterChainResolver, servletContext, securityManager);
492 }
493 }
494
495 public static class CustomFilter implements Filter {
496
497 @Override
498 public void init(FilterConfig filterConfig) throws ServletException {}
499
500 @Override
501 public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException {}
502
503 @Override
504 public void destroy() {}
505 }
506 }