qinfengge

qinfengge

醉后不知天在水,满船清梦压星河
github

spring AI (四) 连续对话

在之前的文章中,我们实现的都是简单的调用,只能实现 1 次对话。这既不符合现实也不优雅,哪有人对话只对一句的啊。除了下面那个

image

要是 AI 也只能对话一次,那我们也只能说 AI 大爷您先一边凉快去吧
那怎么让模型连续对话呢,重点是 记忆 把用户的提问记住,同时也把模型自己的输出也记住,这样,模型才能根据 上文 得出合理的回答。

Prompt#

还记得第一章说的吗

事实上,Prompt 的种类很多,玩法也很多样,不仅仅是提示词,同样也是多轮对话的关键。

在创建 Prompt 中可以看到其可以接收 2 种参数,一种是单条 message,还有就是 message 集合。

image
而 message 中的 MessageType 又有下面 4 中类型

image

这不就对上了

    USER("user"),  // 用户的输入

	ASSISTANT("assistant"), // 模型的输出

	SYSTEM("system"), // 模型的人设

	FUNCTION("function"); //函数

想象一下,在现实中你和某个人对话,你说一句他说一句,而且他的话一定要跟 前文 有对应关系,否则的话就是驴唇不对马嘴了。那么在模型上连续对话的关键和这个是一样的,即在每次对话时将前文传递给模型,使其理解对应的上下文关系。这就是message 集合的作用。

实现#

简单的原理以及说完了,接下来就直接开搞。不过我们不能忘记前面几篇文章所作出的努力,要融会贯通。所以接下来就实现一个完整的功能,包含流式输出函数调用连续对话

首先,初始化客户端

    private static final String BASEURL = "https://xxx";

    private static final String TOKEN = "sk-xxxx";

    /**
     * 创建OpenAiChatClient
     * @return OpenAiChatClient
     */
    private static OpenAiChatClient getClient(){
        OpenAiApi openAiApi = new OpenAiApi(BASEURL, TOKEN);
        return new OpenAiChatClient(openAiApi, OpenAiChatOptions.builder()
                .withModel("gpt-3.5-turbo-1106")
                .withTemperature(0.8F)
                .build());
    }

在创建这一步需要注意的是,OpenAI 的一些老模型是不支持函数调用和流式输出的,一些老模型的最大 token 也只有 4K。创建错误可能会导致 400 BAD REQUEST

接下来,保存历史信息
先创建一个 Map

private static Map<String, List<Message>> chatMessage = new ConcurrentHashMap<>();

这里 Map 的 key 对应的是会话 ID,value 就是历史消息了。注意会话一定要有对应的唯一 ID,不然的话就会串台了。
然后在每次会话时传入会话 ID 和用户的输入,将对应的输入放到消息集合里面

/**
     * 返回提示词
     * @param message 用户输入的消息
     * @return Prompt
     */
    private List<Message> getMessages(String id, String message) {
        String systemPrompt = "{prompt}";
        SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemPrompt);

        Message userMessage = new UserMessage(message);

        Message systemMessage = systemPromptTemplate.createMessage(MapUtil.of("prompt", "you are a helpful AI assistant"));

        List<Message> messages = chatMessage.get(id);


        // 如果未获取到消息,则创建新的消息并将系统提示和用户输入的消息添加到消息列表中
        if (messages == null){
            messages = new ArrayList<>();
            messages.add(systemMessage);
            messages.add(userMessage);
        } else {
            messages.add(userMessage);
        }

        return messages;
    }

这里如果是第一轮对话 message 列表是空的话,还会把 systemMessage 也放进去,相当于给模型初始化一个人设。

然后,创建函数

/**
     * 初始化函数调用
     * @return ChatOptions
     */
    private ChatOptions initFunc(){
        return OpenAiChatOptions.builder().withFunctionCallbacks(List.of(
                FunctionCallbackWrapper.builder(new MockWeatherService()).withName("weather").withDescription("Get the weather in location").build(),
                FunctionCallbackWrapper.builder(new WbHotService()).withName("wbHot").withDescription("Get the hot list of Weibo").build(),
                FunctionCallbackWrapper.builder(new TodayNews()).withName("todayNews").withDescription("60s watch world news").build(),
                FunctionCallbackWrapper.builder(new DailyEnglishFunc()).withName("dailyEnglish").withDescription("A daily inspirational sentence in English").build())).build();
    }

关于函数的相关信息,请查看第三章

最后,就是输出
这里说一下,因为最终的实现效果是网页,所以使用了服务端主动推送的功能,这里使用的是 SSE,关于 SSE 的介绍可以看之前博客里面写的消息推送
总之,下面是一个 SSE 的工具类

@Component
@Slf4j
public class SseEmitterUtils {
    /**
     * 当前连接数
     */
    private static AtomicInteger count = new AtomicInteger(0);

    /**
     * 存储 SseEmitter 信息
     */
    private static Map<String, SseEmitter> sseEmitterMap = new ConcurrentHashMap<>();

    /**
     * 创建用户连接并返回 SseEmitter
     * @param key userId
     * @return SseEmitter
     */
    public static SseEmitter connect(String key) {
        if (sseEmitterMap.containsKey(key)) {
            return sseEmitterMap.get(key);
        }

        try {
            // 设置超时时间,0表示不过期。默认30秒
            SseEmitter sseEmitter = new SseEmitter(0L);
            // 注册回调
            sseEmitter.onCompletion(completionCallBack(key));
            sseEmitter.onError(errorCallBack(key));
            sseEmitter.onTimeout(timeoutCallBack(key));
            sseEmitterMap.put(key, sseEmitter);
            // 数量+1
            count.getAndIncrement();
            return sseEmitter;
        } catch (Exception e) {
            log.info("创建新的SSE连接异常,当前连接Key为:{}", key);
        }
        return null;
    }

    /**
     * 给指定用户发送消息
     * @param key userId
     * @param message 消息内容
     */
    public static void sendMessage(String key, String message) {
        if (sseEmitterMap.containsKey(key)) {
            try {
                sseEmitterMap.get(key).send(message);
            } catch (IOException e) {
                log.error("用户[{}]推送异常:{}", key, e.getMessage());
                remove(key);
            }
        }
    }

    /**
     * 向同组人发布消息,要求:key + groupId
     * @param groupId 群组id
     * @param message 消息内容
     */
    public static void groupSendMessage(String groupId, String message) {
        if (!CollectionUtils.isEmpty(sseEmitterMap)) {
            sseEmitterMap.forEach((k, v) -> {
                try {
                    if (k.startsWith(groupId)) {
                        v.send(message, MediaType.APPLICATION_JSON);
                    }
                } catch (IOException e) {
                    log.error("用户[{}]推送异常:{}", k, e.getMessage());
                    remove(k);
                }
            });
        }
    }

    /**
     * 广播群发消息
     * @param message 消息内容
     */
    public static void batchSendMessage(String message) {
        sseEmitterMap.forEach((k, v) -> {
            try {
                v.send(message, MediaType.APPLICATION_JSON);
            } catch (IOException e) {
                log.error("用户[{}]推送异常:{}", k, e.getMessage());
                remove(k);
            }
        });
    }

    /**
     * 群发消息
     * @param message 消息内容
     * @param ids 用户id集合
     */
    public static void batchSendMessage(String message, Set<String> ids) {
        ids.forEach(userId -> sendMessage(userId, message));
    }

    /**
     * 移除连接
     * @param key userId
     */
    public static void remove(String key) {
        sseEmitterMap.remove(key);
        // 数量-1
        count.getAndDecrement();
        log.info("移除连接:{}", key);
    }

    /**
     * 获取当前连接信息
     * @return Map
     */
    public static List<String> getIds() {
        return new ArrayList<>(sseEmitterMap.keySet());
    }

    /**
     * 获取当前连接数量
     * @return int
     */
    public static int getCount() {
        return count.intValue();
    }

    private static Runnable completionCallBack(String key) {
        return () -> {
            log.info("结束连接:{}", key);
            remove(key);
        };
    }

    private static Runnable timeoutCallBack(String key) {
        return () -> {
            log.info("连接超时:{}", key);
            remove(key);
        };
    }

    private static Consumer<Throwable> errorCallBack(String key) {
        return throwable -> {
            log.info("连接异常:{}", key);
            remove(key);
        };
    }
}

接下来就可以对话了吗,不,如果是网页,我们还要思考具体的实现。主流的 AI 模型的网页端主要有这 2 个方面

  1. 快速提问,用户可以在首页直接提问
  2. 保存对话信息,每轮对话都是唯一的,用户可随时返回某一个对话

具体实现就是

  1. 创建一个接口,用户访问首页时调用此接口并返回会话 ID
  2. 接下来的用户输入都绑定在第一步返回的会话 ID 上,除非刷新浏览器或创建新的会话

那么第一步就是

 /**
     * 创建连接
     */
    @SneakyThrows
    @GetMapping("/init/{message}")
    public String init() {
        return String.valueOf(UUID.randomUUID());
    }

直接返回 UUID 给前端

最后有了会话 ID 就可以绑定到会话并输出了。

@GetMapping("chat/{id}/{message}")
    public SseEmitter chat(@PathVariable String id, @PathVariable String message, HttpServletResponse response) {

        response.setHeader("Content-type", "text/html;charset=UTF-8");
        response.setCharacterEncoding("UTF-8");

        OpenAiChatClient client = getClient();
        SseEmitter emitter = SseEmitterUtils.connect(id);
        List<Message> messages = getMessages(id, message);
        System.err.println("chatMessage大小: " + messages.size());
        System.err.println("chatMessage: " + chatMessage);

        if (messages.size() > MAX_MESSAGE){
            SseEmitterUtils.sendMessage(id, "对话次数过多,请稍后重试🤔");
        }else {
            // 获取模型的输出流
            Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

            // 把流里面的消息使用SSE发送
            Mono<String> result = stream
                    .flatMap(it -> {
                        StringBuilder sb = new StringBuilder();
                        String content = it.getResult().getOutput().getContent();
                        Optional.ofNullable(content).ifPresent(r -> {
                            SseEmitterUtils.sendMessage(id, content);
                            sb.append(content);
                        });
                        return Mono.just(sb.toString());
                    })
                    // 将消息拼接成字符串
                    .reduce((a, b) -> a + b)
                    .defaultIfEmpty("");

            // 将消息存储到chatMessage中的AssistantMessage
            result.subscribe(finalContent -> messages.add(new AssistantMessage(finalContent)));

            // 将消息存储到chatMessage中
            chatMessage.put(id, messages);

        }
        return emitter;

    }

首先使用 response 将返回编码设置为 UTF-8 防止乱码
然后使用 SseEmitterUtils连接到对应的会话
接着使用 getMessages 返回获取对应会话的历史消息
然后使用 MAX_MESSAGE 对会话轮数进行了判断,如果大于这里的值则不再调用模型输出,这里主要是降低成本

private static final Integer MAX_MESSAGE = 10;

这里写的是 10 轮,其实是 5 轮对话,因为是用历史消息的 size 判断的,而历史消息里面是包含用户的输入和模型的输出的,所以要除以 2.

chatMessage: {e2578f9e-8d71-4531-a6af-400a80fb6569=[SystemMessage{content='you are a helpful AI assistant', properties={}, messageType=SYSTEM}, UserMessage{content='你好呀', properties={}, messageType=USER}, AssistantMessage{content='你好!需要我的帮助吗?', properties={}, messageType=ASSISTANT}, UserMessage{content='你是谁啊', properties={}, messageType=USER}]}

最后就是模型的输出

Flux<ChatResponse> stream = client.stream(new Prompt(messages,initFunc()));

使用了 stream 流,然后 Prompt 传入历史消息和函数
获取到输出流后使用 SseEmitterUtils.sendMessage(id, content); 把流里面的内容发送到对应的会话啦。

还有最后一步,我们要把模型的输出也放到历史消息里面,要让模型知道以及回答过的不用再次回答。
如果不放进去,那么模型会把用户前面的所有输入全部回答。
例如,第一轮问 “介绍下杭州”,此时 AI 的回答是正常的
第二轮接着问 “杭州有哪些著名的景点”,此时 AI 不知道上一轮是否回答过了,所以它倾向于同时回答这 2 个问题,即 “介绍下杭州,杭州有哪些著名的景点”。
第三轮、第四轮同样会同时回答。

那么怎么从流里面获取完整的输出呢?
最开始我使用 stream.subscribeStringBuilder 将流里面的内容追加到 sb 里面,但是 sb 总是为 null。问了 claude 才知道 Flux 是异步的,最后使用了 Mono 进行处理。

在这段代码中,我们在 flatMap 的回调函数中创建了一个新的 StringBuilder 实例 sb。然后,我们将每个响应的内容追加到 sb 中,并返回一个 Mono 发射 sb.toString () 的结果。
接下来,我们使用 reduce 操作符将所有的 Mono 合并成一个 Mono。reduce 的参数是一个合并函数,它将前一个值和当前值合并成一个新值。在这里,我们使用字符串连接操作将所有响应内容拼接在一起。
最后,我们订阅这个 Mono, 并在其回调函数中将最终内容添加到 messages 中。如果没有任何响应,我们使用 defaultIfEmpty ("") 确保发射一个空字符串,而不是 null。
通过这种方式,我们可以正确地获取到流式响应的全部内容,并将其添加到 messages 中。

最后就大功告成了😎

哦,忘了,还差一个前端,但是由于我对前端不甚了解。
所以我选择使用openui这个 AI 工具来帮我写页面和样式,然后用 Claude 帮我写接口逻辑。于是,我最终得到了这个

<!doctype html>
<html>

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <script src="https://cdn.tailwindcss.com"></script>
</head>

<body class="bg-zinc-100 dark:bg-zinc-800 min-h-screen p-4">
    <div class="flex flex-col h-full">
        <div id="messages" class="flex-1 overflow-y-auto p-4 space-y-4">
            <div class="flex items-end">
                <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
                <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs">嗨~(⁄ ⁄•⁄ω⁄•⁄ ⁄)⁄</div>
            </div>
        </div>
        <div class="p-2">
            <input type="text" id="messageInput" placeholder="请输入消息..."
                class="w-full p-2 rounded-lg border-2 border-zinc-300 dark:border-zinc-600 focus:outline-none focus:border-blue-500 dark:focus:border-blue-400">
            <button onclick="sendMessage()"
                class="mt-2 w-full bg-blue-500 hover:bg-blue-600 dark:bg-blue-600 dark:hover:bg-blue-700 text-white p-2 rounded-lg">发送</button>
        </div>
    </div>
    <script>
        let sessionId; // 用于存储会话 ID

        // 发送 HTTP 请求并处理响应
        function sendHTTPRequest(url, method = 'GET', body = null) {
            return new Promise((resolve, reject) => {
                const xhr = new XMLHttpRequest();
                xhr.open(method, url, true);
                xhr.onload = () => {
                    if (xhr.status >= 200 && xhr.status < 300) {
                        resolve(xhr.response);
                    } else {
                        reject(xhr.statusText);
                    }
                };
                xhr.onerror = () => reject(xhr.statusText);
                if (body) {
                    xhr.setRequestHeader('Content-Type', 'application/json');
                    xhr.send(JSON.stringify(body));
                } else {
                    xhr.send();
                }
            });
        }

        // 处理服务器返回的 SSE 流
        function handleSSEStream(stream) {
            console.log('Stream started');
            console.log(stream);
            const messagesContainer = document.getElementById('messages');
            const responseDiv = document.createElement('div');
            responseDiv.className = 'flex items-end';
            responseDiv.innerHTML = `
    <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
    <div class="ml-2 p-2 bg-white dark:bg-zinc-700 rounded-lg max-w-xs"></div>
  `;
            messagesContainer.appendChild(responseDiv);

            const messageContentDiv = responseDiv.querySelector('div');

            // 监听 'message' 事件,当后端发送新的数据时触发
            stream.onmessage = function (event) {
                const data = event.data;
                console.log('Received data:', data);
                messageContentDiv.textContent += data;
                messagesContainer.scrollTop = messagesContainer.scrollHeight;
            };
        }

        // 发送消息
        function sendMessage() {
            const input = document.getElementById('messageInput');
            const message = input.value.trim();
            if (message) {
                const messagesContainer = document.getElementById('messages');
                const newMessageDiv = document.createElement('div');
                newMessageDiv.className = 'flex items-end justify-end';
                newMessageDiv.innerHTML = `
          <div class="mr-2 p-2 bg-green-200 dark:bg-green-700 rounded-lg max-w-xs">
            ${message}
          </div>
          <img src="https://placehold.co/40x40" alt="avatar" class="rounded-full">
        `;
                messagesContainer.appendChild(newMessageDiv);
                input.value = '';
                messagesContainer.scrollTop = messagesContainer.scrollHeight;

                // 第一次发送消息时,发送 init 请求获取会话 ID
                if (!this.sessionId) {
                    console.log('init');
                    sendHTTPRequest(`http://127.0.0.1:8868/pro/init/${message}`, 'GET')
                        .then(response => {
                            this.sessionId = response; // 存储会话 ID
                            return handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
                        });

                } else {
                    // 之后的请求直接发送到 chat 接口
                    handleSSEStream(new EventSource(`http://127.0.0.1:8868/pro/chat/${this.sessionId}/${message}`))
                }
            }
        }
    </script>
</body>

</html>

最终效果#

image

PS: 其实前端可以再优化一下,比如显示历史会话,使用 markdown 渲染输出等。有兴趣的可以使用 AI 工具修改下。

Spring AI 连续对话

Loading...
Ownership of this post data is guaranteed by blockchain and smart contracts to the creator alone.