Recently I learned about access and refresh tokens, and thought they were cool but confusing at first. We had to implement this in our Spring Boot assignment, but first I want to go through an overview of what the access/refresh token flow looks like.

Access/Refresh Tokens

Example Process 1

  1. The user sends a login request with username and password.
POST /api/login HTTP/1.1
Host: example.com
Content-Type: application/json
{
    "username": "alex",
    "password": "user_password123"
}
  2. The server verifies the user’s credentials.
  3. If verification succeeds, the server issues the tokens and sends them back in the HTTP response.
  • The server generates both a short-lived accessToken and a long-lived refreshToken.

    • It sends the accessToken in the JSON body.
    • It sends the refreshToken in a secure, HttpOnly cookie.
HTTP/1.1 200 OK
Content-Type: application/json
Set-Cookie: refreshToken=aVeryLongAndSecureRefreshTokenString; HttpOnly; Secure; Path=/api/refresh
{
    "message": "Login successful!",
    "accessToken": "generated_access_token"
}
  4. The client stores the tokens.
  • accessToken - typically kept in memory (a variable or a state management store); localStorage is another option, but it is readable by any script running on the page
  • refreshToken - the browser automatically stores this as a Secure, HttpOnly cookie
  5. For subsequent requests, the client includes the accessToken in the Authorization header.
  • For regular API calls like this, the client only sends the accessToken, without the refreshToken.
GET /api/user/profile HTTP/1.1
Host: example.com
Authorization: Bearer generated_access_token
  6. The server validates the access token and processes the request.

Example Process 2 - accessToken expires

  7. The original accessToken (which might only be valid for 15 minutes) has now expired.
  8. The client tries to make another API call.
  • The client doesn't know the token has expired yet, so it sends the same request as in Step 5, but with the expired token.
GET /api/user/orders HTTP/1.1
Host: example.com
Authorization: Bearer generated_access_token
  9. The server rejects the expired token.
  • The server validates the token, sees that it's expired, and rejects the request with a 401 Unauthorized status code.
  10. The client then handles the 401 error (the "refresh" logic).
  • The client is programmed to handle 401 errors. Instead of logging the user out, it first tries to get a new access token by making a POST request to the refresh endpoint (e.g., /api/refresh).

  • NO Authorization header here -> The browser automatically attaches the refreshToken cookie to this request

    • The refreshToken is only sent when the accessToken has expired and you need a new one (the cookie's Path=/api/refresh keeps it off every other request).
POST /api/refresh HTTP/1.1
Host: example.com
  11. The server validates the refresh token and issues a new accessToken and refreshToken.
  • The server receives the request at /api/refresh

  • The server looks at the refreshToken it received and verifies that it is valid and has not expired. If it's good:

    • it generates a brand new accessToken & sends this new token back in the JSON body
    • it immediately invalidates that old refreshToken and issues a brand new one along with the new access token
HTTP/1.1 200 OK
Content-Type: application/json
Set-Cookie: refreshToken=a_BRAND_NEW_refresh_token_string; HttpOnly; Secure; Path=/api/refresh
{
    "accessToken": "a_NEW_access_token"
}

  12. The client receives the new accessToken

  • The client receives the new accessToken and updates the one it has in memory.
  • It now automatically retries the failed request from Step 8, but this time with the new, valid token.
GET /api/user/orders HTTP/1.1
Host: example.com
Authorization: Bearer a_NEW_access_token
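
The refresh-and-retry logic from Example Process 2 can be sketched in plain Java. This is only a rough illustration with made-up names: RefreshingClient, apiCall (takes an access token and returns an HTTP status code) and refreshCall (stands in for POST /api/refresh and returns the new access token) are all hypothetical and not part of the assignment code.

```java
import java.util.function.Function;
import java.util.function.Supplier;

// Hypothetical client-side helper: retries a call once after refreshing the access token.
class RefreshingClient {

    private String accessToken;                 // kept in memory only
    private final Supplier<String> refreshCall; // stands in for POST /api/refresh

    RefreshingClient(String initialAccessToken, Supplier<String> refreshCall) {
        this.accessToken = initialAccessToken;
        this.refreshCall = refreshCall;
    }

    // apiCall maps an access token to the HTTP status the server answered with.
    int call(Function<String, Integer> apiCall) {
        int status = apiCall.apply(accessToken);
        if (status != 401) {
            return status;
        }
        // 401: ask for a new access token, remember it, and retry exactly once.
        accessToken = refreshCall.get();
        return apiCall.apply(accessToken);
    }

    String currentAccessToken() {
        return accessToken;
    }
}
```

Retrying only once matters here: if the retry also comes back as 401, the refresh token itself is probably no longer valid, and the client should send the user back to the login page instead of looping.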

My (badly) drawn picture of the flow

The Assignment

Now I just needed to organize what I did in my assignment (implementing those features).

JwtTokenProvider

@Getter
@Component
public class JwtTokenProvider {

    @Value("${jwt.secret-key}")
    private String secretKey;

    @Value("${jwt.access-token-expiration-minutes}")
    private int accessTokenExpirationMinutes;

    @Value("${jwt.refresh-token-expiration-minutes}")
    private int refreshTokenExpirationMinutes;

    /**
     * Generate access token
     */
    public String generateAccessToken(Map<String, Object> claims, String subject) {
    try {
        JWSSigner signer = new MACSigner(secretKey.getBytes(StandardCharsets.UTF_8));
        Date expiration = new Date(
            System.currentTimeMillis() +
                (long) accessTokenExpirationMinutes * 60 * 1000);

        JWTClaimsSet claimsSet = new Builder()
            .subject(subject)
            .claim("email", claims.get("email"))
            .claim("username", claims.get("username"))
            .claim("role", claims.get("role"))
            .expirationTime(expiration)
            .issueTime(new Date())
            .build();

        SignedJWT signedJwt = new SignedJWT(
            new JWSHeader(JWSAlgorithm.HS256),
            claimsSet
        );

        signedJwt.sign(signer);
        return signedJwt.serialize();

    } catch (Exception e) {
        throw new RuntimeException("JWT token issue failure", e);
    }
    }

    /**
     * Generate refresh token (same as above but without the claims)
     */
    public String generateRefreshToken(String subject) {
    try {
        JWSSigner signer = new MACSigner(secretKey.getBytes(StandardCharsets.UTF_8));
        Date expiration = new Date(
            System.currentTimeMillis() +
                (long) refreshTokenExpirationMinutes * 60 * 1000);

        JWTClaimsSet claimsSet = new Builder()
            .subject(subject)
            .expirationTime(expiration)
            .issueTime(new Date())
            .build();

        SignedJWT signedJwt = new SignedJWT(
            new JWSHeader(JWSAlgorithm.HS256),
            claimsSet
        );

        signedJwt.sign(signer);
        return signedJwt.serialize();

    } catch (Exception e) {
        throw new RuntimeException("JWT token issue failure", e);
    }
    }

    public Map<String, Object> getClaims(String token) {
    try {
        SignedJWT signedJWT = SignedJWT.parse(token);
        JWSVerifier verifier = new MACVerifier(secretKey.getBytes(StandardCharsets.UTF_8));

        // signature verification
        if (!signedJWT.verify(verifier)) {
            throw new RuntimeException("JWT signature verification failed");
        }
        JWTClaimsSet claimsSet = signedJWT.getJWTClaimsSet();

        // expiration verification (a token without an expiration time is rejected too)
        Date expiration = claimsSet.getExpirationTime();
        if (expiration == null || new Date().after(expiration)) {
            throw new RuntimeException("JWT expired.");
        }

        return claimsSet.getClaims();

    } catch (Exception e) {
        throw new RuntimeException("JWT parsing failed", e);
    }
    }
}
  • JwtTokenProvider class

    • A specialized class responsible for creating, reading, and validating JSON Web Tokens (JWTs).
  • generateAccessToken(), generateRefreshToken()

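    • Each creates a signed token (HS256, using the configured secret key) with the Nimbus JOSE + JWT library; the refresh token carries only the subject, with no extra claims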
  • getClaims()

    • verifies if the token is legitimate before returning its claims
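
One detail in generateAccessToken() and generateRefreshToken() that is easy to miss is the (long) cast in the expiration arithmetic. Here is a small sketch of why it matters. ExpirySketch and its method names are made up for illustration, and the 30-day lifetime (43,200 minutes) used in the comments is just an example value, not the assignment's actual configuration.

```java
import java.util.Date;

// Hypothetical helper showing the expiration arithmetic used by the token provider.
class ExpirySketch {

    // Same formula as in the provider: widen to long BEFORE multiplying.
    static Date expiresAt(long issuedAtMillis, int minutes) {
        return new Date(issuedAtMillis + (long) minutes * 60 * 1000);
    }

    // The same multiplication done entirely in int, shown only to illustrate the problem.
    // Anything above 35,791 minutes (about 24.8 days) no longer fits in an int,
    // so a 30-day lifetime of 43,200 minutes wraps around to a negative number.
    static int millisWithoutCast(int minutes) {
        return minutes * 60 * 1000;
    }
}
```

With the cast, a long refresh token lifetime produces a date in the future. Without it, the token could be issued with an expiration that is already in the past.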

Login + subsequent calls

Making sessions stateless and adding filters

@Configuration
@EnableWebSecurity
@EnableMethodSecurity
@RequiredArgsConstructor
public class SecurityConfig {
    ...
    @Bean
    public SecurityFilterChain filterChain(HttpSecurity http) throws Exception {
    JwtLoginFilter jwtLoginFilter = new JwtLoginFilter(authenticationManager, jwtTokenProvider, refreshTokenRepository);
    JwtAuthenticationFilter jwtAuthenticationFilter = new JwtAuthenticationFilter(jwtTokenProvider, userDetailsService);

    ...

    http
        ...
        // since we're using tokens, we make session management STATELESS
        // Spring Security won't create an HttpSession or a JSESSIONID cookie
        .sessionManagement(management -> management
            .sessionCreationPolicy(SessionCreationPolicy.STATELESS))
        // custom filters
        .addFilter(jwtLoginFilter) // Add the login filter
        .addFilterBefore(jwtAuthenticationFilter,
            JwtLoginFilter.class) // Add verification filter before login filter
  • sessionCreationPolicy(SessionCreationPolicy.STATELESS)

    • Previously, I had implemented the sessionManagement part with maximumSessions set to 1, so each user could have only one active session at a time. There the server was maintaining state: it had to keep track of every user's sessions to enforce that limit.
    • Now that we're using tokens, request handling is stateless because the server doesn't need to remember the access tokens it hands out. It creates an access token, sends it to the client without storing it, and for each request carrying that token it only needs to validate the signature and the expiration date.
  • Added custom filters - JwtLoginFilter and JwtAuthenticationFilter

    • JwtLoginFilter - handles the initial login
    • JwtAuthenticationFilter - handles all subsequent requests
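
To see why the server can validate an access token without storing it, here is a rough sketch of what HS256 signing and verification boil down to, using only the JDK. It is not how Nimbus is implemented and it skips claim parsing and the expiration check entirely; Hs256Sketch and its methods are hypothetical names for illustration. Real code should keep using a JWT library.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

// Hypothetical helper for illustration only.
class Hs256Sketch {

    private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding();

    // Builds header.payload.signature, where the signature is an HMAC over "header.payload".
    static String sign(String payloadJson, byte[] secret) {
        String header = encode("{\"alg\":\"HS256\",\"typ\":\"JWT\"}");
        String payload = encode(payloadJson);
        String signingInput = header + "." + payload;
        return signingInput + "." + ENCODER.encodeToString(hmac(signingInput, secret));
    }

    // Recomputes the HMAC from the token itself; no database or session lookup is involved.
    static boolean verify(String token, byte[] secret) {
        String[] parts = token.split("\\.");
        if (parts.length != 3) {
            return false;
        }
        byte[] expected = hmac(parts[0] + "." + parts[1], secret);
        try {
            byte[] presented = Base64.getUrlDecoder().decode(parts[2]);
            return MessageDigest.isEqual(expected, presented);
        } catch (IllegalArgumentException e) {
            return false; // the signature part is not valid Base64URL
        }
    }

    private static String encode(String text) {
        return ENCODER.encodeToString(text.getBytes(StandardCharsets.UTF_8));
    }

    private static byte[] hmac(String data, byte[] secret) {
        try {
            Mac mac = Mac.getInstance("HmacSHA256");
            mac.init(new SecretKeySpec(secret, "HmacSHA256"));
            return mac.doFinal(data.getBytes(StandardCharsets.UTF_8));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("HmacSHA256 is not available", e);
        }
    }
}
```

Everything verify() needs is the token and the secret key, which is why any server instance holding the key can check a token it has never seen before. If someone edits the payload, the recomputed HMAC no longer matches the signature and the token is rejected.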

JwtLoginFilter

@RequiredArgsConstructor
public class JwtLoginFilter extends UsernamePasswordAuthenticationFilter {

    private final AuthenticationManager authenticationManager;
    private final JwtTokenProvider jwtTokenProvider;
    private final ObjectMapper objectMapper = new ObjectMapper();
    private final RefreshTokenRepository refreshTokenRepository;

    @SneakyThrows
    @Override
    public Authentication attemptAuthentication(HttpServletRequest request,
        HttpServletResponse response) {
    String username = request.getParameter("username");
    String password = request.getParameter("password");
    username = (username != null) ? username : "";
    password = (password != null) ? password : "";

    UsernamePasswordAuthenticationToken authenticationToken = new UsernamePasswordAuthenticationToken(
        username,
        password
    );
    return authenticationManager.authenticate(authenticationToken);
    }

    @Override
    protected void successfulAuthentication(HttpServletRequest request,
        HttpServletResponse response,
        FilterChain chain,
        Authentication authResult) throws IOException {

    DiscodeitUserDetails userDetails = (DiscodeitUserDetails) authResult.getPrincipal();
    UserDto user = userDetails.getUserDto();
    String accessToken = delegateAccessToken(user);
    String refreshToken = delegateRefreshToken(user);

    Cookie refreshTokenCookie = new Cookie("REFRESH_TOKEN", refreshToken);
    refreshTokenCookie.setHttpOnly(true);
    refreshTokenCookie.setPath("/");
    response.addCookie(refreshTokenCookie);

    Optional<RefreshToken> existingRefreshToken = refreshTokenRepository.findByUserId(
        user.userId());

    RefreshToken refreshTokenToSave;
    if (existingRefreshToken.isPresent()) {
        RefreshToken foundToken = existingRefreshToken.get();
        foundToken.updateToken(refreshToken, jwtTokenProvider.getRefreshTokenExpirationMinutes());
        refreshTokenToSave = foundToken;
        System.out.println("Updating existing refresh token for user: " + user.userId());
    } else {
        LocalDateTime expirationDate = LocalDateTime.now().plusMinutes(
            jwtTokenProvider.getRefreshTokenExpirationMinutes());
        refreshTokenToSave = RefreshToken.builder()
            .userId(user.userId())
            .token(refreshToken)
            .expiredAt(expirationDate)
            .rotated(false)
            .build();
        System.out.println("Creating new refresh token for user: " + user.userId());
    }

    refreshTokenRepository.save(refreshTokenToSave);

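    // setting response body (declare it as JSON before writing)
    JwtDto jwtDto = new JwtDto(user, accessToken);
    String responseBody = objectMapper.writeValueAsString(jwtDto);
    response.setContentType("application/json");
    response.setCharacterEncoding("UTF-8");
    response.getWriter().write(responseBody);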

    }

    // ====================== private methods ======================
    // create access token
    private String delegateAccessToken(UserDto user) {
    Map<String, Object> claims = new HashMap<>();
    claims.put("username", user.username());
    claims.put("email", user.email());
    claims.put("role", user.role());
    String subject = user.userId().toString();

    return jwtTokenProvider.generateAccessToken(claims, subject);
    }

    // create refresh token
    private String delegateRefreshToken(UserDto user) {
    String subject = user.userId().toString();

    return jwtTokenProvider.generateRefreshToken(subject);
    }
}
  • JwtLoginFilter

    • This filter is called when the user initially logs in.
    • It extends UsernamePasswordAuthenticationFilter, which is the default class for handling form-based logins. By extending this we can get its functionality and add our own logic.
  • attemptAuthentication

    • It delegates to the AuthenticationManager, which in turn calls the custom DiscodeitUserDetailsService that was implemented earlier. That service returns a UserDetails, and because our implementation of it is DiscodeitUserDetails, that is the concrete type we get back.
  • successfulAuthentication

    • This method is called automatically when authentication succeeds. From the Authentication authResult parameter, we can retrieve the DiscodeitUserDetails.

    • With the user details, we create the access token and the refresh token.

    • However for the refresh token, we check for 2 cases

      • first time login - application creates a new refresh token and saves in the db
      • repeat login (e.g. the user is logged in on multiple devices and one of them timed out) - since a refresh token already exists in the db for this user, we just update the existing row with the new token value and a new expiration date
    • For the response, we send back the UserDto and the accessToken in the response body and send the refresh token in a cookie.
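
The "create or update" handling of the refresh token can be sketched with an in-memory map standing in for the database. This is a simplified illustration under my own assumptions: RefreshTokenStoreSketch, StoredToken, saveOnLogin and isCurrent are hypothetical names, and the real code goes through RefreshTokenRepository instead.

```java
import java.time.LocalDateTime;
import java.util.HashMap;
import java.util.Map;
import java.util.UUID;

// Hypothetical in-memory stand-in for the refresh token table (one row per user).
class RefreshTokenStoreSketch {

    static final class StoredToken {
        final String token;
        final LocalDateTime expiredAt;

        StoredToken(String token, LocalDateTime expiredAt) {
            this.token = token;
            this.expiredAt = expiredAt;
        }
    }

    private final Map<UUID, StoredToken> byUserId = new HashMap<>();

    // First login inserts a row; a repeat login overwrites the existing row for that user.
    void saveOnLogin(UUID userId, String token, LocalDateTime expiredAt) {
        byUserId.put(userId, new StoredToken(token, expiredAt));
    }

    // True only if this is the token currently stored for the user.
    boolean isCurrent(UUID userId, String token) {
        StoredToken stored = byUserId.get(userId);
        return stored != null && stored.token.equals(token);
    }
}
```

One consequence of keeping a single row per user is that a new login replaces the previous refresh token, so a device still holding the old value can no longer use it to refresh.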

JwtAuthenticationFilter

@RequiredArgsConstructor
public class JwtAuthenticationFilter extends OncePerRequestFilter {

    private final JwtTokenProvider jwtTokenProvider;
    private final DiscodeitUserDetailsService userDetailsService;

    @Override
    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
        FilterChain filterChain) throws ServletException, IOException {

    // get claims
    Map<String, Object> claims = verifyJws(request);

    // get userDetail
    String username = claims.get("username").toString();
    UserDetails userDetails = userDetailsService.loadUserByUsername(username);

    UsernamePasswordAuthenticationToken authentication =
        new UsernamePasswordAuthenticationToken(
            userDetails,
            null,
            userDetails.getAuthorities()
        );
    SecurityContextHolder.getContext().setAuthentication(authentication);

    // continue to next chain
    filterChain.doFilter(request, response);
    }

    @Override
    protected boolean shouldNotFilter(HttpServletRequest request) throws ServletException {
    String authorization = request.getHeader("Authorization");
    return authorization == null || !authorization.startsWith("Bearer ");
    }

    // ======================== private methods ========================
    private Map<String, Object> verifyJws(HttpServletRequest request) {
    String jws = request.getHeader("Authorization").replace("Bearer ", "");
    return jwtTokenProvider.getClaims(jws); // includes verification
    }
}
  • JwtAuthenticationFilter

    • This filter is called for every API request to the backend, and it is responsible for validating the access token that comes per request.
    • At this point the user has already logged in, and the user's information travels with each request inside the access token's claims.
    • It extends OncePerRequestFilter, which is fitting because it should be called once per request.
  • shouldNotFilter

    • This method first checks whether the request has an Authorization: Bearer... header. If it does, doFilterInternal is called; if not, this filter is skipped for that request.
  • doFilterInternal

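    • This method extracts the access token and verifies its signature and expiration. If verification throws an exception, the request stops here and never reaches the controller; otherwise the filter chain continues.

      • If the access token is not valid, the server should respond with 401 Unauthorized, which lets the client call /refresh. Note that getClaims() throws a plain RuntimeException, so something has to translate it into a 401 (for example by catching it in this filter and setting the status); an uncaught exception from a filter would normally surface as a 500 instead.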
    • It takes the username from the valid token, uses it to get the UserDetails, then authenticates the user by creating an Authentication object and placing it in the SecurityContextHolder.
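
The header handling that shouldNotFilter and verifyJws share can be sketched as one small helper. This is a hedged illustration, not the assignment's code: BearerHeader and extractToken are hypothetical names, and treating an empty token as "no token" is my own assumption.

```java
import java.util.Optional;

// Hypothetical helper that pulls the token out of an Authorization header value.
class BearerHeader {

    private static final String PREFIX = "Bearer ";

    // Returns the token if the header looks like "Bearer <token>", otherwise empty.
    static Optional<String> extractToken(String authorizationHeader) {
        if (authorizationHeader == null || !authorizationHeader.startsWith(PREFIX)) {
            return Optional.empty();
        }
        String token = authorizationHeader.substring(PREFIX.length()).trim();
        return token.isEmpty() ? Optional.empty() : Optional.of(token);
    }
}
```

Using substring on the known prefix, rather than replace, only strips the scheme at the start of the header and leaves the rest of the value untouched.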

/refresh

  • This endpoint is called when the old access token has expired. It returns a new access token in the response body and a refresh token as a cookie.
  • It's called only once by the client after getting a 401 Unauthorized error.

AuthController

    @PostMapping("/refresh")
    public ResponseEntity<?> refresh(
        @CookieValue(value = "REFRESH_TOKEN", required = false) String refreshTokenValue,
        HttpServletResponse response
    ) {

    if (refreshTokenValue == null) {
        return ResponseEntity.status(HttpStatus.UNAUTHORIZED)
            .body("Refresh token is missing.");
    }

    TokenResponse tokenResponse = tokenService.reissueToken(refreshTokenValue);

    //creating new cookie
    Cookie refreshTokenCookie = new Cookie("REFRESH_TOKEN", tokenResponse.getRefreshToken());
    refreshTokenCookie.setHttpOnly(true);
    refreshTokenCookie.setPath("/");
    response.addCookie(refreshTokenCookie);
    JwtDto jwtDto = new JwtDto(tokenResponse.getUserDto(), tokenResponse.getAccessToken());
    return ResponseEntity.ok(jwtDto);
    }

TokenService

@Service
@RequiredArgsConstructor
public class TokenService {

    private final RefreshTokenRepository refreshTokenRepository;
    private final JwtTokenProvider jwtTokenProvider;
    private final UserRepository userRepository;
    private final UserMapper userMapper;

    @Transactional
    public TokenResponse reissueToken(String refreshTokenValue) {
    RefreshToken refreshToken = refreshTokenRepository.findByToken(refreshTokenValue)
        .orElseThrow(() -> new IllegalArgumentException("Invalid Refresh Token"));

    if (refreshToken.getExpiredAt().isBefore(LocalDateTime.now())) {
        throw new IllegalArgumentException("Expired Refresh Token - Please login again");
    }
    if (refreshToken.isRotated()) {
        throw new IllegalArgumentException("Rotated Refresh Token - Please login again");
    }

    User user = userRepository.findById(refreshToken.getUserId())
        .orElseThrow(UserNotFoundException::new);
    UUID userId = user.getId();

    // new access Token
    Map<String, Object> claims = new HashMap<>();
    claims.put("userId", userId);
    claims.put("email", user.getEmail());
    claims.put("username", user.getUsername());
    claims.put("role", user.getRole().name());
    String subject = user.getId().toString();
    String newAccessToken = jwtTokenProvider.generateAccessToken(claims, subject);

    String newRefreshTokenValue = refreshToken.getToken(); // old token

    // rotating by update
    if (shouldRotate(refreshToken)) {
        long expirationMinutes = jwtTokenProvider.getRefreshTokenExpirationMinutes();
        LocalDateTime newExpiration = LocalDateTime.now().plusMinutes(expirationMinutes);
        newRefreshTokenValue = jwtTokenProvider.generateRefreshToken(userId.toString());

        refreshToken.setToken(newRefreshTokenValue);
        refreshToken.setExpiredAt(newExpiration);
    }
    return new TokenResponse(userMapper.toDto(user), newAccessToken, newRefreshTokenValue);
    }

    private boolean shouldRotate(RefreshToken token) {
    return token.getExpiredAt().isBefore(LocalDateTime.now().plusDays(3));
    }
}
  • reissueToken()

    • This function is called when access token has expired.
    • It first checks the refresh token's validity, and if it's invalid then it throws an error.
    • It creates a new access token using the user ID stored with the refresh token.
    • It updates the refresh token's value and expiration date depending on shouldRotate(), which checks whether the current refresh token has less than 3 days left. If so, the token gets a new lifespan: its value is replaced with a newly generated one and a new expiration date is set. Otherwise the existing refresh token is returned unchanged.
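
The two time checks in reissueToken() can be isolated into a small sketch. RotationSketch is a hypothetical name, and passing the current time in as a parameter (instead of calling LocalDateTime.now() inside) is my own change to make the logic easy to try out with fixed dates.

```java
import java.time.Duration;
import java.time.LocalDateTime;

// Hypothetical helper mirroring the expiry and rotation checks in TokenService.
class RotationSketch {

    // Rotate once the refresh token has less than this much time left.
    static final Duration ROTATION_WINDOW = Duration.ofDays(3);

    // An expired token is rejected outright; the user has to log in again.
    static boolean isExpired(LocalDateTime expiredAt, LocalDateTime now) {
        return expiredAt.isBefore(now);
    }

    // A still-valid token is rotated only when it is close to expiring.
    static boolean shouldRotate(LocalDateTime expiredAt, LocalDateTime now) {
        return expiredAt.isBefore(now.plus(ROTATION_WINDOW));
    }
}
```

For example, with a fixed "now", a token expiring in 2 days would be rotated while one expiring in 5 days would be reused as-is.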