kotlin · advanced

Spring Security with JWT Authentication

Configure Spring Security in Kotlin for JWT authentication: token generation and validation, role-based access control, and a custom security filter.

kotlin
import io.jsonwebtoken.Claims
import io.jsonwebtoken.JwtException
import io.jsonwebtoken.Jwts
import io.jsonwebtoken.SignatureAlgorithm
import io.jsonwebtoken.security.Keys
import jakarta.servlet.FilterChain
import jakarta.servlet.http.HttpServletRequest
import jakarta.servlet.http.HttpServletResponse
import org.springframework.context.annotation.Bean
import org.springframework.context.annotation.Configuration
import org.springframework.http.HttpStatus
import org.springframework.security.access.prepost.PreAuthorize
import org.springframework.security.authentication.UsernamePasswordAuthenticationToken
import org.springframework.security.config.annotation.method.configuration.EnableMethodSecurity
import org.springframework.security.config.annotation.web.builders.HttpSecurity
import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity
import org.springframework.security.config.http.SessionCreationPolicy
import org.springframework.security.core.authority.SimpleGrantedAuthority
import org.springframework.security.core.context.SecurityContextHolder
import org.springframework.security.web.SecurityFilterChain
import org.springframework.security.web.authentication.HttpStatusEntryPoint
import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter
import org.springframework.stereotype.Component
import org.springframework.web.bind.annotation.*
import org.springframework.web.filter.OncePerRequestFilter
import org.springframework.web.server.ResponseStatusException
import java.security.MessageDigest
import java.util.Date
import javax.crypto.SecretKey

// JWT Service
@Component
class JwtService {
    // NOTE(review): a new random HS256 key is generated every time this bean is created, so tokens
    // issued before a restart stop validating and tokens from one instance are rejected by another.
    // Load a shared secret from configuration for anything beyond a demo.
    private val key: SecretKey = Keys.secretKeyFor(SignatureAlgorithm.HS256)
    private val expirationMs = 3600_000L // 1 hour

    // Built once and reused for every validation; verifies signatures against `key`.
    private val parser = Jwts.parserBuilder()
        .setSigningKey(key)
        .build()

    /**
     * Issues a JWT signed with this service's key.
     *
     * The token carries [userId] as its subject and [roles] in a `roles` claim. Its expiry is set
     * [expirationMs] after its issued-at time.
     */
    fun generateToken(userId: String, roles: List<String>): String {
        // One timestamp for both iat and exp so the lifetime is exactly expirationMs.
        val issuedAt = Date()
        return Jwts.builder()
            .setSubject(userId)
            .claim("roles", roles)
            .setIssuedAt(issuedAt)
            .setExpiration(Date(issuedAt.time + expirationMs))
            .signWith(key)
            .compact()
    }

    /**
     * Parses and verifies [token] with this service's key.
     *
     * @return the token's claims, or `null` when JJWT rejects it (bad signature, expired,
     *   malformed or unsupported) or the string is empty.
     */
    fun validateToken(token: String): Claims? = try {
        parser.parseClaimsJws(token).body
    } catch (e: JwtException) {
        null // Invalid or expired token: the caller treats it as unauthenticated.
    } catch (e: IllegalArgumentException) {
        null // JJWT rejects null/empty token strings with IllegalArgumentException.
    }

    /** Subject (user id) of a valid [token], or `null` if the token does not validate. */
    fun getUserId(token: String): String? = validateToken(token)?.subject

    /**
     * Role names from the `roles` claim of a valid [token].
     *
     * Returns an empty list when the token does not validate or the claim is missing or not a list.
     * Non-string entries are dropped. The untyped `get` is used because the typed `get(name, Class)`
     * would throw on a non-list claim.
     */
    fun getRoles(token: String): List<String> =
        (validateToken(token)?.get("roles") as? List<*>)
            .orEmpty()
            .filterIsInstance<String>()
}

// JWT Filter
@Component
class JwtAuthFilter(private val jwtService: JwtService) : OncePerRequestFilter() {
    /**
     * Authenticates the request from an `Authorization: Bearer <jwt>` header, when one is present.
     *
     * If [JwtService.validateToken] accepts the token and it has a subject, a
     * [UsernamePasswordAuthenticationToken] is installed in a fresh security context. Its principal
     * is the token's subject. Each string entry of the `roles` claim becomes a `ROLE_`-prefixed
     * authority.
     *
     * Requests without a usable token are left unauthenticated; the authorization rules decide
     * what they may access. The filter chain always continues.
     */
    override fun doFilterInternal(
        request: HttpServletRequest,
        response: HttpServletResponse,
        filterChain: FilterChain,
    ) {
        // The auth-scheme name is case-insensitive (RFC 7235), so "bearer " is accepted too.
        val token = request.getHeader("Authorization")
            ?.takeIf { it.startsWith(BEARER_PREFIX, ignoreCase = true) }
            ?.substring(BEARER_PREFIX.length)
            ?.trim()

        // Parse and verify once, so the subject and roles come from the same validated claims.
        // Parsing twice could pass the first check and then fail the second if the token expired
        // in between.
        val claims = token?.takeIf { it.isNotEmpty() }?.let(jwtService::validateToken)
        val userId = claims?.subject
        if (claims != null && userId != null) {
            val authorities = (claims["roles"] as? List<*>)
                .orEmpty()
                .filterIsInstance<String>()
                .map { SimpleGrantedAuthority("ROLE_$it") }
            // Install a new context rather than mutating the shared one, as Spring Security recommends.
            val context = SecurityContextHolder.createEmptyContext()
            context.authentication = UsernamePasswordAuthenticationToken(userId, null, authorities)
            SecurityContextHolder.setContext(context)
        }
        filterChain.doFilter(request, response)
    }

    private companion object {
        const val BEARER_PREFIX = "Bearer "
    }
}

// Security Configuration
@Configuration
@EnableWebSecurity
@EnableMethodSecurity
class SecurityConfig(private val jwtFilter: JwtAuthFilter) {
    /**
     * Stateless, bearer-token security chain.
     *
     * - `/api/auth/**` and `/error` are public.
     * - `/api/admin/**` requires the ADMIN role.
     * - Everything else requires authentication.
     *
     * Unauthenticated requests receive 401 rather than Spring Security's default 403. [jwtFilter]
     * runs before [UsernamePasswordAuthenticationFilter] to populate the security context from the
     * JWT.
     */
    @Bean
    fun securityFilterChain(http: HttpSecurity): SecurityFilterChain = http
        // Authentication comes from a request header, not a session cookie, so CSRF protection is off.
        .csrf { it.disable() }
        .sessionManagement { it.sessionCreationPolicy(SessionCreationPolicy.STATELESS) }
        .authorizeHttpRequests {
            // "/error" must be public. The JWT filter does not run on the ERROR dispatch, so without
            // this every error response (including the login endpoint's 401) would be re-denied.
            it.requestMatchers("/api/auth/**", "/error").permitAll()
                .requestMatchers("/api/admin/**").hasRole("ADMIN")
                .anyRequest().authenticated()
        }
        // No login form or HTTP Basic is configured, so the default entry point would answer 403.
        // A missing or invalid token should be reported as 401.
        .exceptionHandling { it.authenticationEntryPoint(HttpStatusEntryPoint(HttpStatus.UNAUTHORIZED)) }
        .addFilterBefore(jwtFilter, UsernamePasswordAuthenticationFilter::class.java)
        .build()
}

// Auth Controller
/** Credentials posted to the login endpoint. */
data class LoginRequest(
    val username: String,
    val password: String,
) {
    // The generated data-class toString() would print the plain-text password into any log line
    // that includes the request, so the password is masked here.
    override fun toString(): String = "LoginRequest(username=$username, password=****)"
}

/**
 * Body returned after a successful login.
 *
 * @property token the signed bearer token.
 * @property expiresIn advertised token lifetime. The default of 3600 matches JwtService's
 *   one-hour expiry when read as seconds.
 */
data class TokenResponse(
    val token: String,
    val expiresIn: Long = 3600,
)

@RestController
@RequestMapping("/api/auth")
class AuthController(private val jwtService: JwtService) {
    // Demo-only credential store: username -> (plain-text password, roles).
    // NOTE(review): a real service should look up password hashes (e.g. via a PasswordEncoder)
    // instead of storing plain text.
    private val users = mapOf(
        "admin" to Pair("admin123", listOf("ADMIN", "USER")),
        "user" to Pair("user123", listOf("USER")),
    )

    /**
     * Exchanges a username/password for a JWT carrying the user's roles.
     *
     * @throws ResponseStatusException with 401 and the same "Invalid credentials" message for
     *   unknown users and wrong passwords.
     */
    @PostMapping("/login")
    fun login(@RequestBody request: LoginRequest): TokenResponse {
        val account = users[request.username]
        // Always run the comparison, even for unknown usernames, so both failure paths do the same
        // work and response timing does not reveal which usernames exist.
        val passwordMatches = constantTimeEquals(account?.first ?: UNKNOWN_USER_PASSWORD, request.password)
        if (account == null || !passwordMatches) {
            throw ResponseStatusException(HttpStatus.UNAUTHORIZED, "Invalid credentials")
        }
        return TokenResponse(jwtService.generateToken(request.username, account.second))
    }

    /**
     * Compares secrets without short-circuiting on the first differing byte.
     * Per its contract, MessageDigest.isEqual's running time depends only on the length of [expected].
     */
    private fun constantTimeEquals(expected: String, actual: String): Boolean =
        MessageDigest.isEqual(expected.toByteArray(Charsets.UTF_8), actual.toByteArray(Charsets.UTF_8))

    private companion object {
        // Placeholder compared against when the username is unknown. A match can never log anyone
        // in, because the account lookup has already failed.
        const val UNKNOWN_USER_PASSWORD = "\u0000"
    }
}

// Protected Controller
@RestController
@RequestMapping("/api")
class UserController {
    /**
     * Returns the caller's user id and a greeting.
     *
     * @throws ResponseStatusException 401 if no authentication is present in the security context.
     */
    @GetMapping("/profile")
    fun profile(): Map<String, Any> {
        val authentication = SecurityContextHolder.getContext().authentication
            ?: throw ResponseStatusException(HttpStatus.UNAUTHORIZED, "Not authenticated")
        // Authentication.name returns the principal string itself for String principals and
        // avoids the ClassCastException a blind `principal as String` cast risks.
        return mapOf("userId" to authentication.name, "message" to "Welcome!")
    }

    /** Lists the demo account names. Requires the ADMIN role. */
    @PreAuthorize("hasRole('ADMIN')")
    @GetMapping("/admin/users")
    fun listUsers(): Map<String, List<String>> = mapOf("users" to listOf("admin", "user"))
}

Sponsored

Try Auth0

Use Cases

  • JWT-based stateless authentication
  • Role-based access control
  • Secure API endpoint protection

Tags

Related Snippets

Similar patterns you can reuse in the same workflow.