diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java index 86cd99d0c8fa293f8b31f37d926b759055343f21..78c63095796f31f868417fcbcbac14ecfc50bc6a 100644 --- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java +++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java @@ -1,10 +1,13 @@ package cn.bugstack.chatglm.session; import cn.bugstack.chatglm.model.ChatCompletionRequest; +import cn.bugstack.chatglm.model.ChatCompletionResponse; import com.fasterxml.jackson.core.JsonProcessingException; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; +import java.util.concurrent.CompletableFuture; + /** * @author 小傅哥,微信:fustack * @description 会话服务接口 @@ -15,4 +18,6 @@ public interface OpenAiSession { EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException; + CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException; + } diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java index 97578ad775e2ad57f7d304c5c6126886296f3a4b..ae546c2acf49e40d4e23b7c029709353afd1b6bc 100644 --- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java +++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java @@ -2,14 +2,23 @@ package cn.bugstack.chatglm.session.defaults; import cn.bugstack.chatglm.IOpenAiApi; import cn.bugstack.chatglm.model.ChatCompletionRequest; +import cn.bugstack.chatglm.model.ChatCompletionResponse; +import cn.bugstack.chatglm.model.EventType; import cn.bugstack.chatglm.session.Configuration; import cn.bugstack.chatglm.session.OpenAiSession; +import com.alibaba.fastjson.JSON; import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.extern.slf4j.Slf4j; import okhttp3.MediaType; import okhttp3.Request; import okhttp3.RequestBody; +import okhttp3.Response; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; +import org.jetbrains.annotations.Nullable; + +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CountDownLatch; /** * @author 小傅哥,微信:fustack @@ -17,6 +26,7 @@ import okhttp3.sse.EventSourceListener; * @github https://github.com/fuzhengwei * @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获! */ +@Slf4j public class DefaultOpenAiSession implements OpenAiSession { /** @@ -28,11 +38,8 @@ public class DefaultOpenAiSession implements OpenAiSession { */ private final EventSource.Factory factory; - private IOpenAiApi openAiApi; - public DefaultOpenAiSession(Configuration configuration) { this.configuration = configuration; - this.openAiApi = configuration.getOpenAiApi(); this.factory = configuration.createRequestFactory(); } @@ -48,4 +55,43 @@ public class DefaultOpenAiSession implements OpenAiSession { return factory.newEventSource(request, eventSourceListener); } + @Override + public CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException { + // 用于执行异步任务并获取结果 + CompletableFuture future = new CompletableFuture<>(); + StringBuffer dataBuffer = new StringBuffer(); + + // 构建请求信息 + Request request = new Request.Builder() + .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode())) + .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString())) + .build(); + + // 异步响应请求 + factory.newEventSource(request, new EventSourceListener() { + @Override + public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) { + ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class); + // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断 + if (EventType.add.getCode().equals(type)) { + dataBuffer.append(response.getData()); + } else if (EventType.finish.getCode().equals(type)) { + future.complete(dataBuffer.toString()); + } + } + + @Override + public void onClosed(EventSource eventSource) { + future.completeExceptionally(new RuntimeException("Request closed before completion")); + } + + @Override + public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) { + future.completeExceptionally(new RuntimeException("Request closed before completion")); + } + }); + + return future; + } + } diff --git a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java index e141178098aba9743d7c2d9a41061eb5ec8a6353..1d996a1cd8115f99133b5b3380a695360ced13b3 100644 --- a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java +++ b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java @@ -17,7 +17,9 @@ import org.junit.Before; import org.junit.Test; import java.util.ArrayList; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; /** * @author 小傅哥,微信:fustack @@ -108,6 +110,32 @@ public class ApiTest { new CountDownLatch(1).await(); } + /** + * 同步请求 + */ + @Test + public void test_completions_future() throws ExecutionException, InterruptedException { + // 入参;模型、请求信息 + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro + request.setPrompt(new ArrayList() { + private static final long serialVersionUID = -7988151926241837899L; + + { + add(ChatCompletionRequest.Prompt.builder() + .role(Role.user.getCode()) + .content("写个java冒泡排序") + .build()); + } + }); + + CompletableFuture future = openAiSession.completions(request); + String response = future.get(); + + log.info("测试结果:{}", response); + } + + @Test public void test_curl() { // 1. 配置文件