diff --git a/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java b/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java index 53058ce6f979f8674426e7027e448be07ab84775..23776fa4167ebf007e4555a8a37a734099dd91e7 100644 --- a/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java +++ b/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java @@ -2,6 +2,7 @@ package cn.bugstack.chatglm; import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionResponse; +import cn.bugstack.chatglm.model.ChatCompletionSyncResponse; import io.reactivex.Single; import retrofit2.http.Body; import retrofit2.http.POST; @@ -16,8 +17,12 @@ import retrofit2.http.Path; public interface IOpenAiApi { String v3_completions = "api/paas/v3/model-api/{model}/sse-invoke"; + String v3_completions_sync = "api/paas/v3/model-api/{model}/invoke"; @POST(v3_completions) Single completions(@Path("model") String model, @Body ChatCompletionRequest chatCompletionRequest); + @POST(v3_completions_sync) + Single completions(@Body ChatCompletionRequest chatCompletionRequest); + } diff --git a/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java b/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java index 27c96e06f94d1b106b2ba9e244eba7ef07b75cfb..2c58ffea5ee33047a741aaf667b89a1f3ad2df36 100644 --- a/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java +++ b/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java @@ -30,14 +30,13 @@ public class OpenAiHTTPInterceptor implements Interceptor { public @NotNull Response intercept(Chain chain) throws IOException { // 1. 获取原始 Request Request original = chain.request(); - // 2. 构建请求 Request request = original.newBuilder() .url(original.url()) .header("Authorization", "Bearer " + BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret())) .header("Content-Type", Configuration.JSON_CONTENT_TYPE) .header("User-Agent", Configuration.DEFAULT_USER_AGENT) - .header("Accept", Configuration.SSE_CONTENT_TYPE) + .header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE) .method(original.method(), original.body()) .build(); diff --git a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionSyncResponse.java b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionSyncResponse.java new file mode 100644 index 0000000000000000000000000000000000000000..485f97258276dfe568e62f72e8a294ca36237970 --- /dev/null +++ b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionSyncResponse.java @@ -0,0 +1,42 @@ +package cn.bugstack.chatglm.model; + +import lombok.Data; + +import java.util.List; + +/** + * 同步调用响应 + * @author max + * @date 2023/12/14 15:41 + */ +@Data +public class ChatCompletionSyncResponse { + + private Integer code; + private String msg; + private Boolean success; + private ChatGLMData data; + + @Data + public static class ChatGLMData { + private List choices; + private String task_status; + private Usage usage; + private String task_id; + private String request_id; + } + + @Data + public static class Usage { + private int completion_tokens; + private int prompt_tokens; + private int total_tokens; + } + + @Data + public static class Choice { + private String role; + private String content; + } + +} diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java index 78c63095796f31f868417fcbcbac14ecfc50bc6a..88b25954f82f7e16fdcf6e3060b5523ae0129515 100644 --- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java +++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java @@ -2,10 +2,12 @@ package cn.bugstack.chatglm.session; import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionResponse; +import cn.bugstack.chatglm.model.ChatCompletionSyncResponse; import com.fasterxml.jackson.core.JsonProcessingException; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; +import java.io.IOException; import java.util.concurrent.CompletableFuture; /** @@ -20,4 +22,6 @@ public interface OpenAiSession { CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException; + ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException; + } diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java index ae546c2acf49e40d4e23b7c029709353afd1b6bc..687e408319b308ff4266d0e259f71bece3d1526e 100644 --- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java +++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java @@ -3,20 +3,19 @@ package cn.bugstack.chatglm.session.defaults; import cn.bugstack.chatglm.IOpenAiApi; import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionResponse; +import cn.bugstack.chatglm.model.ChatCompletionSyncResponse; import cn.bugstack.chatglm.model.EventType; import cn.bugstack.chatglm.session.Configuration; import cn.bugstack.chatglm.session.OpenAiSession; import com.alibaba.fastjson.JSON; import com.fasterxml.jackson.core.JsonProcessingException; import lombok.extern.slf4j.Slf4j; -import okhttp3.MediaType; -import okhttp3.Request; -import okhttp3.RequestBody; -import okhttp3.Response; +import okhttp3.*; import okhttp3.sse.EventSource; import okhttp3.sse.EventSourceListener; import org.jetbrains.annotations.Nullable; +import java.io.IOException; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; @@ -94,4 +93,20 @@ public class DefaultOpenAiSession implements OpenAiSession { return future; } + @Override + public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException { + // 构建请求信息 + Request request = new Request.Builder() + .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions_sync).replace("{model}", chatCompletionRequest.getModel().getCode())) + .header("Accept",Configuration.APPLICATION_JSON) + .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString())) + .build(); + OkHttpClient okHttpClient = configuration.getOkHttpClient(); + Response response = okHttpClient.newCall(request).execute(); + if(!response.isSuccessful()){ + new RuntimeException("Request failed"); + } + return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class); + } + } diff --git a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java index 13da146339592969cdcf41c35dc1167f288b2bdb..963e939c2e443c813d3dda76fe090a075f6e8024 100644 --- a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java +++ b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java @@ -17,6 +17,7 @@ import org.jetbrains.annotations.Nullable; import org.junit.Before; import org.junit.Test; +import java.io.IOException; import java.util.ArrayList; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; @@ -119,7 +120,7 @@ public class ApiTest { public void test_completions_future() throws ExecutionException, InterruptedException { // 入参;模型、请求信息 ChatCompletionRequest request = new ChatCompletionRequest(); - request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro + request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro request.setPrompt(new ArrayList() { private static final long serialVersionUID = -7988151926241837899L; @@ -137,6 +138,30 @@ public class ApiTest { log.info("测试结果:{}", response); } + /** + * 同步请求 + */ + @Test + public void test_completions_sync() throws IOException { + // 入参;模型、请求信息 + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro + request.setPrompt(new ArrayList() { + private static final long serialVersionUID = -7988151926241837899L; + + { + add(ChatCompletionRequest.Prompt.builder() + .role(Role.user.getCode()) + .content("写个java冒泡排序") + .build()); + } + }); + + ChatCompletionSyncResponse response = openAiSession.completionsSync(request); + + log.info("测试结果:{}", JSON.toJSONString(response)); + } + @Test public void test_curl() {