BaseAuthenticationSuccessHandler.java 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. /**
  2. *
  3. */
  4. package com.yihu.base.security.hander;
  5. import com.fasterxml.jackson.databind.ObjectMapper;
  6. import com.yihu.base.security.rbas.ClientServiceProvider;
  7. import org.apache.commons.codec.binary.StringUtils;
  8. import org.apache.commons.collections.MapUtils;
  9. import org.slf4j.Logger;
  10. import org.slf4j.LoggerFactory;
  11. import org.springframework.beans.factory.annotation.Autowired;
  12. import org.springframework.beans.factory.annotation.Qualifier;
  13. import org.springframework.security.authentication.BadCredentialsException;
  14. import org.springframework.security.core.Authentication;
  15. import org.springframework.security.crypto.codec.Base64;
  16. import org.springframework.security.oauth2.common.OAuth2AccessToken;
  17. import org.springframework.security.oauth2.common.exceptions.UnapprovedClientAuthenticationException;
  18. import org.springframework.security.oauth2.provider.*;
  19. import org.springframework.security.oauth2.provider.token.AuthorizationServerTokenServices;
  20. import org.springframework.security.oauth2.provider.token.DefaultTokenServices;
  21. import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler;
  22. import org.springframework.stereotype.Component;
  23. import javax.annotation.Resource;
  24. import javax.servlet.ServletException;
  25. import javax.servlet.http.HttpServletRequest;
  26. import javax.servlet.http.HttpServletResponse;
  27. import java.io.IOException;
  28. import java.io.UnsupportedEncodingException;
  29. /**
  30. * @author chenweida
  31. * <p>
  32. * 账号密码提交需要在 head 中添加 Basic clientID:cliengSecurty
  33. */
  34. @Component("BaseAuthenticationSuccessHandler")
  35. public class BaseAuthenticationSuccessHandler extends SavedRequestAwareAuthenticationSuccessHandler {
  36. private Logger logger = LoggerFactory.getLogger(getClass());
  37. @Autowired
  38. private ObjectMapper objectMapper;
  39. @Autowired
  40. private ClientServiceProvider clientDetailsService;
  41. @Autowired
  42. private AuthorizationServerTokenServices defaultTokenServices;
  43. /*
  44. * (non-Javadoc)
  45. *
  46. * @see org.springframework.security.web.authentication.
  47. * AuthenticationSuccessHandler#onAuthenticationSuccess(javax.servlet.http.
  48. * HttpServletRequest, javax.servlet.http.HttpServletResponse,
  49. * org.springframework.security.core.Authentication)
  50. */
  51. @Override
  52. public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response,
  53. Authentication authentication) throws IOException, ServletException {
  54. String header = request.getHeader("Authorization");
  55. if (org.springframework.util.StringUtils.isEmpty(header) || (!header.startsWith("Basic "))) {
  56. throw new UnapprovedClientAuthenticationException("请求头没有client信息");
  57. }
  58. //解析头部的basic信息
  59. String[] tokens = extractAndDecodeHeader(header, request);
  60. assert tokens.length == 2;
  61. String clientId = tokens[0];
  62. String clientSecurity = tokens[1];
  63. //得到ClientDetails
  64. ClientDetails clientDetails = clientDetailsService.loadClientByClientId(clientId);
  65. if (clientDetails == null) {
  66. throw new UnapprovedClientAuthenticationException("clientId不存在 client:" + clientId);
  67. } else if (!StringUtils.equals(clientDetails.getClientSecret(), clientSecurity)) {
  68. throw new UnapprovedClientAuthenticationException("clientSecurity 不匹配 client:" + clientId);
  69. }
  70. TokenRequest tokenRequest = new TokenRequest(MapUtils.EMPTY_MAP, clientId, clientDetails.getScope(), "custom_password");
  71. OAuth2Request oAuth2Request = tokenRequest.createOAuth2Request(clientDetails);
  72. OAuth2Authentication oAuth2Authentication = new OAuth2Authentication(oAuth2Request, authentication);
  73. OAuth2AccessToken token = defaultTokenServices.createAccessToken(oAuth2Authentication);
  74. response.setContentType("application/json;charset=UTF-8");
  75. response.getWriter().write(objectMapper.writeValueAsString(token));
  76. }
  77. /**
  78. * 解析
  79. *
  80. * @param header
  81. * @param request
  82. * @return
  83. * @throws IOException
  84. */
  85. private String[] extractAndDecodeHeader(String header, HttpServletRequest request)
  86. throws IOException {
  87. byte[] base64Token = header.substring(6).getBytes("UTF-8");
  88. byte[] decoded;
  89. try {
  90. decoded = Base64.decode(base64Token);
  91. } catch (IllegalArgumentException e) {
  92. throw new BadCredentialsException(
  93. "Failed to decode basic authentication token");
  94. }
  95. String token = new String(decoded, "UTF-8");
  96. int delim = token.indexOf(":");
  97. if (delim == -1) {
  98. throw new BadCredentialsException("Basic 信息不合法");
  99. }
  100. return new String[]{token.substring(0, delim), token.substring(delim + 1)};
  101. }
  102. }