From 07f0aad301507802d4ef02a0dd2214060565575d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=B0=8F=E5=82=85=E5=93=A5?= <184172133@qq.com> Date: Fri, 10 Nov 2023 21:48:56 +0800 Subject: [PATCH] =?UTF-8?q?feat=EF=BC=9A=E5=A2=9E=E5=8A=A0=E5=90=8C?= =?UTF-8?q?=E6=AD=A5=E8=AF=B7=E6=B1=82=EF=BC=8C=E4=BE=BF=E4=BA=8E=E6=9C=89?= =?UTF-8?q?=E5=86=99=E5=9C=BA=E6=99=AF=E9=9C=80=E8=A6=81=E4=B8=80=E6=AC=A1?= =?UTF-8?q?=E5=A4=84=E7=90=86=E7=BB=93=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../chatglm/session/OpenAiSession.java | 5 ++ .../defaults/DefaultOpenAiSession.java | 52 +++++++++++++++++-- .../cn/bugstack/chatglm/test/ApiTest.java | 28 ++++++++++ 3 files changed, 82 insertions(+), 3 deletions(-) diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java index 86cd99d..78c6309 100644 --- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java +++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java @@ -1,10 +1,13 @@ package cn.bugstack.chatglm.session; import cn.bugstack.chatglm.model.ChatCompletionRequest; +import cn.bugstack.chatglm.model.ChatCompletionResponse; import com.fasterxml.jackson.core.JsonProcessingException; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; +import java.util.concurrent.CompletableFuture; + /** * @author 小傅哥,微信:fustack * @description 会话服务接口 @@ -15,4 +18,6 @@ public interface OpenAiSession { EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException; + CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException; + } diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java index 97578ad..ae546c2 100644 --- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java +++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java @@ -2,14 +2,23 @@ package cn.bugstack.chatglm.session.defaults; import cn.bugstack.chatglm.IOpenAiApi; import cn.bugstack.chatglm.model.ChatCompletionRequest; +import cn.bugstack.chatglm.model.ChatCompletionResponse; +import cn.bugstack.chatglm.model.EventType; import cn.bugstack.chatglm.session.Configuration; import cn.bugstack.chatglm.session.OpenAiSession; +import com.alibaba.fastjson.JSON; import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.extern.slf4j.Slf4j; import okhttp3.MediaType; import okhttp3.Request; import okhttp3.RequestBody; +import okhttp3.Response; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; +import org.jetbrains.annotations.Nullable; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; /** * @author 小傅哥,微信:fustack @@ -17,6 +26,7 @@ import okhttp3.sse.EventSourceListener; * @github https://github.com/fuzhengwei * @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获! */ +@Slf4j public class DefaultOpenAiSession implements OpenAiSession { /** @@ -28,11 +38,8 @@ public class DefaultOpenAiSession implements OpenAiSession { */ private final EventSource.Factory factory; - private IOpenAiApi openAiApi; - public DefaultOpenAiSession(Configuration configuration) { this.configuration = configuration; - this.openAiApi = configuration.getOpenAiApi(); this.factory = configuration.createRequestFactory(); } @@ -48,4 +55,43 @@ public class DefaultOpenAiSession implements OpenAiSession { return factory.newEventSource(request, eventSourceListener); } + @Override + public CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException { + // 用于执行异步任务并获取结果 + CompletableFuture future = new CompletableFuture<>(); + StringBuffer dataBuffer = new StringBuffer(); + + // 构建请求信息 + Request request = new Request.Builder() + .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode())) + .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString())) + .build(); + + // 异步响应请求 + factory.newEventSource(request, new EventSourceListener() { + @Override + public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) { + ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class); + // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断 + if (EventType.add.getCode().equals(type)) { + dataBuffer.append(response.getData()); + } else if (EventType.finish.getCode().equals(type)) { + future.complete(dataBuffer.toString()); + } + } + + @Override + public void onClosed(EventSource eventSource) { + future.completeExceptionally(new RuntimeException("Request closed before completion")); + } + + @Override + public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) { + future.completeExceptionally(new RuntimeException("Request closed before completion")); + } + }); + + return future; + } + } diff --git a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java index e141178..1d996a1 100644 --- a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java +++ b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java @@ -17,7 +17,9 @@ import org.junit.Before; import org.junit.Test; import java.util.ArrayList; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; /** * @author 小傅哥,微信:fustack @@ -108,6 +110,32 @@ public class ApiTest { new CountDownLatch(1).await(); } + /** + * 同步请求 + */ + @Test + public void test_completions_future() throws ExecutionException, InterruptedException { + // 入参;模型、请求信息 + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro + request.setPrompt(new ArrayList() { + private static final long serialVersionUID = -7988151926241837899L; + + { + add(ChatCompletionRequest.Prompt.builder() + .role(Role.user.getCode()) + .content("写个java冒泡排序") + .build()); + } + }); + + CompletableFuture future = openAiSession.completions(request); + String response = future.get(); + + log.info("测试结果:{}", response); + } + + @Test public void test_curl() { // 1. 配置文件 -- GitLab