Event-Driven AI Response Streaming with Axon Framework: A Practical Guide

Let’s dive into a practical implementation of AI response streaming using event-driven architecture with Axon Framework and Spring. This approach offers real-time updates while maintaining clean architectural patterns.

Core Components

Commands and Events

data class TestCommand(@TargetAggregateIdentifier val id: UUID, val message: String)
data class TestEvent(val id: UUID, val message: String)
data class TestAiResponseCommand(@TargetAggregateIdentifier val id: UUID, val message: String)
data class TestAiResponseEvent(val id: UUID, val message: String)

These simple data classes form our event-driven backbone, and each serves a specific purpose in our CQRS flow: TestCommand and TestEvent kick off a user request, while TestAiResponseCommand and TestAiResponseEvent carry the completed AI answer back into the event store.
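
On the command side, these commands are handled by a small event-sourced aggregate (shown again in the complete implementation at the end of the post) that simply applies the matching events:

@Aggregate
class TestAggregate {
    @AggregateIdentifier
    private lateinit var id: UUID

    constructor() // Required no-args constructor for Axon

    @CommandHandler
    @CreationPolicy(AggregateCreationPolicy.ALWAYS)
    fun handle(command: TestCommand) {
        AggregateLifecycle.apply(TestEvent(command.id, command.message))
    }

    @CommandHandler
    fun handle(command: TestAiResponseCommand) {
        AggregateLifecycle.apply(TestAiResponseEvent(command.id, command.message))
    }

    @EventSourcingHandler
    fun on(event: TestEvent) {
        id = event.id
    }
}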

The Shared Flux Component

@Component
class SharedFlux {
    final val sink: Sinks.Many<Map<UUID, String>> = Sinks.many().multicast().onBackpressureBuffer()
    val flux = sink.asFlux()
}

This component acts as our broadcast mechanism, allowing multiple subscribers to receive real-time updates from our AI stream.
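
To see why this behaves like a broadcast, here is a tiny standalone sketch (plain Reactor; the id and subscriber labels are made up for illustration). Every subscriber that is connected when a chunk is emitted receives it:

import reactor.core.publisher.Sinks
import java.util.UUID

fun main() {
    // Same sink setup as SharedFlux above
    val sink = Sinks.many().multicast().onBackpressureBuffer<Map<UUID, String>>()
    val flux = sink.asFlux()

    val id = UUID.randomUUID()
    flux.subscribe { println("subscriber A: $it") } // e.g. the SSE endpoint
    flux.subscribe { println("subscriber B: $it") } // e.g. a second browser tab

    // Both subscribers receive the same chunk
    sink.emitNext(mapOf(id to "streamed chunk"), Sinks.EmitFailureHandler.FAIL_FAST)
    sink.tryEmitComplete()
}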

The Implementation

1. Processing AI Responses

@Component
class Processor(
    private val openAiChatModel: OpenAiChatModel,
    private val sharedFlux: SharedFlux,
    private val commandGateway: CommandGateway
) {
    @EventHandler
    fun on(event: TestEvent) {
        val prompt = Prompt(UserMessage(event.message))
        val finalResponse = StringBuilder()
        
        openAiChatModel.stream(prompt)
            .doOnNext { response ->
                val content = response.result.output.content ?: ""
                // Broadcast each chunk to every current subscriber of the shared sink
                sharedFlux.sink.emitNext(mapOf(event.id to content), Sinks.EmitFailureHandler.FAIL_FAST)
                finalResponse.append(content)
            }
            .doOnComplete {
                // Once streaming finishes, persist the full answer as an event via a command
                commandGateway.send<TestAiResponseCommand>(
                    TestAiResponseCommand(event.id, finalResponse.toString())
                )
            }
            .subscribe()
    }
}

2. API Endpoints

@RestController
@RequestMapping("/api")
class SharedFluxController(
    val queryGateway: QueryGateway,
    private val sharedFlux: SharedFlux,
    val commandGateway: CommandGateway,
) {
    @GetMapping("/generateStream")
    fun generateStream(message: String): Flux<Map<UUID, String>> {
        val id = UUID.randomUUID()
        commandGateway.send<TestCommand>(TestCommand(id, message))
        return sharedFlux.flux
    }

    @GetMapping("/liveResponse")
    fun getLiveResponse(): Flux<String> {
        return sharedFlux.flux
            .filter { it.containsKey(id) }
            .map { it[id]!! }
    }
}
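
To give a feel for the consumer side, here is a hedged sketch of a client reading the SSE stream with Spring's WebClient. The base URL, the port, and the blocking call at the end are assumptions made only to keep the example runnable; they are not part of the original code.

import org.springframework.http.MediaType
import org.springframework.web.reactive.function.client.WebClient

fun main() {
    val client = WebClient.create("http://localhost:8080") // assumed local address

    client.get()
        .uri { it.path("/api/generateStream").queryParam("message", "Tell me a joke").build() }
        .accept(MediaType.TEXT_EVENT_STREAM)
        .retrieve()
        .bodyToFlux(String::class.java)       // each SSE event arrives as its JSON payload
        .doOnNext { chunk -> println(chunk) } // e.g. {"<uuid>":"Why did the ..."}
        .blockLast()                          // block only because this is a standalone main()
}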

Key Benefits

  1. Separation of Concerns

    • Commands handle user requests
    • Events process AI responses
    • Queries retrieve final results
    • Shared Flux manages real-time updates
  2. Real-Time Streaming

    • Immediate updates as AI generates responses
    • No polling required
    • Efficient multicast to all subscribers
  3. Event Sourcing

    • Complete history of AI responses
    • Ability to replay events
    • Query final responses at any time (see the query sketch after this list)
  4. Scalability

    • Decoupled components
    • Event-driven architecture
    • Ready for distributed systems
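
As a small illustration of the event-sourcing benefit, the persisted final answer can be fetched at any later time with a plain point-to-point query. This is only a sketch: it assumes an injected QueryGateway and the id that was used when the stream was started.

import org.axonframework.queryhandling.QueryGateway
import java.util.UUID

// TestQuery and TestReadModel are the classes defined in this post
fun fetchFinalResponse(queryGateway: QueryGateway, id: UUID): TestReadModel =
    queryGateway.query(TestQuery(id), TestReadModel::class.java).join()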

Practical Use Cases

This architecture is particularly useful for:

  • Chat applications requiring real-time responses
  • AI-powered content generation tools
  • Interactive AI assistants
  • Any system requiring real-time AI interaction with multiple clients

Implementation Notes

  • Uses Spring AI’s OpenAiChatModel for AI integration
  • Leverages Axon Framework for event sourcing and CQRS
  • Implements both streaming and query-based endpoints
  • Maintains state through event sourcing

The beauty of this approach lies in its simplicity and separation of concerns, while still providing robust functionality for real-time AI response streaming.

Some further notes:

  • You may want to consider an automatic cleanup mechanism for the shared flux to avoid memory leaks; a minimal sketch follows below.
  • In a distributed system, you may want to use Kafka as the streaming layer behind the shared flux, since an in-memory sink only reaches subscribers on a single instance.
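
As a starting point for that cleanup, here is a hypothetical per-request variant of SharedFlux. The class name, the timeout, and the removal strategy are assumptions, not part of the code in this post: each conversation gets its own sink, which is dropped once its stream finishes, errors, or goes idle.

import org.springframework.stereotype.Component
import reactor.core.publisher.Flux
import reactor.core.publisher.Sinks
import java.time.Duration
import java.util.UUID
import java.util.concurrent.ConcurrentHashMap

// Hypothetical alternative to SharedFlux: one sink per request id instead of one global sink
@Component
class PerRequestFlux {
    private val sinks = ConcurrentHashMap<UUID, Sinks.Many<String>>()

    fun emit(id: UUID, chunk: String) {
        sinks[id]?.emitNext(chunk, Sinks.EmitFailureHandler.FAIL_FAST)
    }

    fun complete(id: UUID) {
        sinks[id]?.tryEmitComplete()
    }

    fun fluxFor(id: UUID): Flux<String> {
        val sink = sinks.computeIfAbsent(id) { Sinks.many().multicast().onBackpressureBuffer<String>() }
        return sink.asFlux()
            .timeout(Duration.ofMinutes(5))  // drop streams that stay idle for too long
            .doFinally { sinks.remove(id) }  // remove the sink once a subscriber finishes or cancels
    }
}

With per-id sinks, the Processor would call emit(event.id, content) for each chunk and complete(event.id) in doOnComplete, and the controller would return fluxFor(id) directly instead of filtering the global flux.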

Complete Implementation

package com.versilite.demo.liveSharedAiResponse

import org.axonframework.commandhandling.CommandHandler
import org.axonframework.commandhandling.gateway.CommandGateway
import org.axonframework.eventhandling.EventHandler
import org.axonframework.eventsourcing.EventSourcingHandler
import org.axonframework.eventsourcing.eventstore.EventStore
import org.axonframework.modelling.command.*
import org.axonframework.queryhandling.QueryGateway
import org.axonframework.queryhandling.QueryHandler
import org.axonframework.queryhandling.QueryUpdateEmitter
import org.axonframework.spring.stereotype.Aggregate
import org.springframework.ai.chat.messages.UserMessage
import org.springframework.ai.chat.prompt.Prompt
import org.springframework.ai.openai.OpenAiChatModel
import org.springframework.ai.openai.OpenAiChatOptions
import org.springframework.ai.openai.api.OpenAiApi
import org.springframework.http.MediaType
import org.springframework.stereotype.Component
import org.springframework.web.bind.annotation.GetMapping
import org.springframework.web.bind.annotation.RequestMapping
import org.springframework.web.bind.annotation.RequestParam
import org.springframework.web.bind.annotation.RestController
import reactor.core.publisher.Flux
import reactor.core.publisher.Sinks
import java.util.*

// This is how the experimentation started: streaming directly from the OpenAI chat model, without Axon involved yet
@RestController
class TestController(
    private val openAiChatModel: OpenAiChatModel
) {

    @GetMapping("/ai/generateStream")
    fun generateStream(
        @RequestParam(
            value = "message",
            defaultValue = "Tell me a joke"
        ) message: String?
    ): Flux<String> {
        // Request GPT-4o mini explicitly via the chat options
        val options = OpenAiChatOptions.builder()
            .withModel(OpenAiApi.ChatModel.GPT_4_O_MINI)
            .build()
        val prompt = Prompt(UserMessage(message), options)
        val fullResponse = StringBuilder()
        return openAiChatModel.stream(prompt)
            .doOnNext { response ->
                fullResponse.append(response.result.output.content ?: "")
            }
            .doOnComplete {
                println("Full response: $fullResponse")
            }
            .map { it.result.output.content ?: "" }
    }
}

data class TestCommand(@TargetAggregateIdentifier val id: UUID, val message: String)
data class TestEvent(val id: UUID, val message: String)

data class TestAiResponseCommand(@TargetAggregateIdentifier val id: UUID, val message: String)
data class TestAiResponseEvent(val id: UUID, val message: String)

@Aggregate
class TestAggregate {
    @AggregateIdentifier
    private lateinit var id: UUID

    constructor() // Required no-args constructor for Axon
    
    @CommandHandler
    @CreationPolicy(AggregateCreationPolicy.ALWAYS)
    fun handle(command: TestCommand) {
        println("Handling command")
        AggregateLifecycle.apply(TestEvent(command.id, command.message))
    }

    @CommandHandler
    fun handle(command: TestAiResponseCommand) {
        println("Handling ai response command")
        AggregateLifecycle.apply(TestAiResponseEvent(command.id, command.message))
    }

    @EventSourcingHandler
    fun on(event: TestEvent) {
        id = event.id
    }
}

@Component
class Processor(
    private val openAiChatModel: OpenAiChatModel,
    private val sharedFlux: SharedFlux,
    private val commandGateway: CommandGateway
) {
    @EventHandler
    fun on(event: TestEvent) {
        val prompt = Prompt(
            UserMessage(event.message),
            OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_O_MINI).build()
        )
        val finalResponse = StringBuilder()
        openAiChatModel.stream(prompt)
            .doOnNext { response ->
                val content = response.result.output.content ?: ""
                val map = mapOf(event.id to content)
                sharedFlux.sink.emitNext(map, Sinks.EmitFailureHandler.FAIL_FAST)  // Emit each response chunk to the shared sink
                finalResponse.append(content)
            }
            .doOnComplete {
                sharedFlux.sink.tryEmitComplete()  // Complete the stream after processing is done
                println(finalResponse)
                commandGateway.send<TestAiResponseCommand>(TestAiResponseCommand(event.id, finalResponse.toString()))
            }
            .subscribe()

    }
}


data class TestQuery(val id: UUID)
data class TestReadModel(val message: String)

@Component
class SharedFlux {
    final val sink: Sinks.Many<Map<UUID, String>> = Sinks.many().multicast().onBackpressureBuffer()
    val flux = sink.asFlux()  // Shared Flux that emits values to subscribers
}

@RestController
@RequestMapping("/api")
class SharedFluxController(
    val queryGateway: QueryGateway,
    private val sharedFlux: SharedFlux,
    val commandGateway: CommandGateway,
) {
    // A single id generated when the controller is created; all requests in this demo share it
    private val id: UUID = UUID.randomUUID()

    @GetMapping("/generateStream", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
    fun generateStream(
        @RequestParam(
            value = "message",
            defaultValue = "Tell me a joke"
        ) message: String
    ): Flux<Map<UUID, String>> {
        println("Our id in generateStream is $id")
        commandGateway.send<TestCommand>(TestCommand(id, message))
        return sharedFlux.flux  // Return the shared flux for streaming to the client
    }
    
    @GetMapping("/finalResponse", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
    fun getFinalResponse(
    ): Flux<TestReadModel> {
        println("Our id in finalResponse is $id")
        val query=  queryGateway.subscriptionQuery(TestQuery(id), TestReadModel::class.java, TestReadModel::class.java)
        return query.initialResult()
            .concatWith(query.updates())

    }

    @GetMapping("/liveResponse", produces = [MediaType.TEXT_EVENT_STREAM_VALUE])
    fun getLiveResponse(
    ): Flux<String> {
        println("Our id in liveResponse is $id")
        return sharedFlux.flux.filter { it.containsKey(id) }.mapNotNull { it[id] }
    }
}

// Projection: answers TestQuery from the event store and pushes live updates to subscription queries.
// Named TestQueryHandler so it does not clash with Axon's @QueryHandler annotation imported above.
@Component
class TestQueryHandler(
    private val eventStore: EventStore,
    private val queryUpdateEmitter: QueryUpdateEmitter
) {
    @QueryHandler
    fun handle(query: TestQuery): TestReadModel {
        val events = eventStore
            .readEvents(query.id.toString())
            .asStream()
            .filter { it.payload is TestAiResponseEvent }
            .map {
                it.payload as TestAiResponseEvent
            }.toList()
        if (events.isEmpty()) {
            return TestReadModel("No response yet")
        }
        return TestReadModel(events.last().message)
    }

    @EventHandler
    fun on(event: TestAiResponseEvent) {
        queryUpdateEmitter.emit(
            TestQuery::class.java,
            { query -> query.id == event.id },
            TestReadModel(event.message)
        )
    }
}

Did any thoughts come to mind?