提交 07f0aad3 编写于 作者: 小傅哥's avatar 小傅哥

feat:增加同步请求,便于有写场景需要一次处理结果

上级 8699cadf
package cn.bugstack.chatglm.session; package cn.bugstack.chatglm.session;
import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import okhttp3.sse.EventSource; import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSourceListener;
import java.util.concurrent.CompletableFuture;
/** /**
* @author 小傅哥,微信:fustack * @author 小傅哥,微信:fustack
* @description 会话服务接口 * @description 会话服务接口
...@@ -15,4 +18,6 @@ public interface OpenAiSession { ...@@ -15,4 +18,6 @@ public interface OpenAiSession {
EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException; EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException;
CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
} }
...@@ -2,14 +2,23 @@ package cn.bugstack.chatglm.session.defaults; ...@@ -2,14 +2,23 @@ package cn.bugstack.chatglm.session.defaults;
import cn.bugstack.chatglm.IOpenAiApi; import cn.bugstack.chatglm.IOpenAiApi;
import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.EventType;
import cn.bugstack.chatglm.session.Configuration; import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession; import cn.bugstack.chatglm.session.OpenAiSession;
import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType; import okhttp3.MediaType;
import okhttp3.Request; import okhttp3.Request;
import okhttp3.RequestBody; import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource; import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.Nullable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
/** /**
* @author 小傅哥,微信:fustack * @author 小傅哥,微信:fustack
...@@ -17,6 +26,7 @@ import okhttp3.sse.EventSourceListener; ...@@ -17,6 +26,7 @@ import okhttp3.sse.EventSourceListener;
* @github https://github.com/fuzhengwei * @github https://github.com/fuzhengwei
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获! * @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/ */
@Slf4j
public class DefaultOpenAiSession implements OpenAiSession { public class DefaultOpenAiSession implements OpenAiSession {
/** /**
...@@ -28,11 +38,8 @@ public class DefaultOpenAiSession implements OpenAiSession { ...@@ -28,11 +38,8 @@ public class DefaultOpenAiSession implements OpenAiSession {
*/ */
private final EventSource.Factory factory; private final EventSource.Factory factory;
private IOpenAiApi openAiApi;
public DefaultOpenAiSession(Configuration configuration) { public DefaultOpenAiSession(Configuration configuration) {
this.configuration = configuration; this.configuration = configuration;
this.openAiApi = configuration.getOpenAiApi();
this.factory = configuration.createRequestFactory(); this.factory = configuration.createRequestFactory();
} }
...@@ -48,4 +55,43 @@ public class DefaultOpenAiSession implements OpenAiSession { ...@@ -48,4 +55,43 @@ public class DefaultOpenAiSession implements OpenAiSession {
return factory.newEventSource(request, eventSourceListener); return factory.newEventSource(request, eventSourceListener);
} }
@Override
public CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
// 用于执行异步任务并获取结果
CompletableFuture<String> future = new CompletableFuture<>();
StringBuffer dataBuffer = new StringBuffer();
// 构建请求信息
Request request = new Request.Builder()
.url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
.post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
.build();
// 异步响应请求
factory.newEventSource(request, new EventSourceListener() {
@Override
public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
// type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
if (EventType.add.getCode().equals(type)) {
dataBuffer.append(response.getData());
} else if (EventType.finish.getCode().equals(type)) {
future.complete(dataBuffer.toString());
}
}
@Override
public void onClosed(EventSource eventSource) {
future.completeExceptionally(new RuntimeException("Request closed before completion"));
}
@Override
public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
future.completeExceptionally(new RuntimeException("Request closed before completion"));
}
});
return future;
}
} }
...@@ -17,7 +17,9 @@ import org.junit.Before; ...@@ -17,7 +17,9 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
/** /**
* @author 小傅哥,微信:fustack * @author 小傅哥,微信:fustack
...@@ -108,6 +110,32 @@ public class ApiTest { ...@@ -108,6 +110,32 @@ public class ApiTest {
new CountDownLatch(1).await(); new CountDownLatch(1).await();
} }
/**
* 同步请求
*/
@Test
public void test_completions_future() throws ExecutionException, InterruptedException {
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
private static final long serialVersionUID = -7988151926241837899L;
{
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
.content("写个java冒泡排序")
.build());
}
});
CompletableFuture<String> future = openAiSession.completions(request);
String response = future.get();
log.info("测试结果:{}", response);
}
@Test @Test
public void test_curl() { public void test_curl() {
// 1. 配置文件 // 1. 配置文件
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册