diff --git a/pom.xml b/pom.xml
index 1b0609a16801464206752671f3200ff22d597d5d..1976f50248764a550d17e7b2e4bb80a8c4801842 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
cn.bugstack
chatglm-sdk-java
- 1.1
+ 2.0
chatglm-sdk-java
OpenAI Java SDK, ZhiPuAi ChatGLM Java SDK . Copyright © 2023 bugstack虫洞栈 All rights reserved. 版权所有(C)小傅哥 https://github.com/fuzhengwei/chatglm-sdk-java
@@ -134,6 +134,18 @@
provided
true
+
+
+ org.apache.httpcomponents
+ httpclient
+ 4.5.14
+
+
+ org.apache.httpcomponents
+ httpmime
+ 4.5.10
+
+
diff --git a/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java b/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java
index 23776fa4167ebf007e4555a8a37a734099dd91e7..2fb309f9708a9923c88fccd38988a3c71ff88b27 100644
--- a/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java
+++ b/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java
@@ -1,8 +1,6 @@
package cn.bugstack.chatglm;
-import cn.bugstack.chatglm.model.ChatCompletionRequest;
-import cn.bugstack.chatglm.model.ChatCompletionResponse;
-import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
+import cn.bugstack.chatglm.model.*;
import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;
@@ -11,7 +9,7 @@ import retrofit2.http.Path;
/**
* @author 小傅哥,微信:fustack
* @description OpenAi 接口,用于扩展通用类服务
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public interface IOpenAiApi {
@@ -25,4 +23,11 @@ public interface IOpenAiApi {
@POST(v3_completions_sync)
Single completions(@Body ChatCompletionRequest chatCompletionRequest);
+ String v4 = "api/paas/v4/chat/completions";
+
+ String cogview3 = "api/paas/v4/images/generations";
+
+ @POST(cogview3)
+ Single genImages(@Body ImageCompletionRequest imageCompletionRequest);
+
}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/Executor.java b/src/main/java/cn/bugstack/chatglm/executor/Executor.java
new file mode 100644
index 0000000000000000000000000000000000000000..6124785b100cf217d1d747e99fa6557194f066d5
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/Executor.java
@@ -0,0 +1,50 @@
+package cn.bugstack.chatglm.executor;
+
+import cn.bugstack.chatglm.model.ChatCompletionRequest;
+import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
+import cn.bugstack.chatglm.model.ImageCompletionRequest;
+import cn.bugstack.chatglm.model.ImageCompletionResponse;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+
+public interface Executor {
+
+ /**
+ * 问答模式,流式反馈
+ *
+ * @param chatCompletionRequest 请求信息
+ * @param eventSourceListener 实现监听;通过监听的 onEvent 方法接收数据
+ * @return 应答结果
+ * @throws Exception 异常
+ */
+ EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception;
+
+ /**
+ * 问答模式,同步反馈 —— 用流式转化 Future
+ *
+ * @param chatCompletionRequest 请求信息
+ * @return 应答结果
+ */
+ CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
+
+ /**
+ * 同步应答接口
+ *
+ * @param chatCompletionRequest 请求信息
+ * @return ChatCompletionSyncResponse
+ * @throws IOException 异常
+ */
+ ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception;
+
+ /**
+ * 图片生成接口
+ *
+ * @param request 请求信息
+ * @return 应答结果
+ */
+ ImageCompletionResponse genImages(ImageCompletionRequest request) throws Exception;
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java
new file mode 100644
index 0000000000000000000000000000000000000000..4ec602333ea175016a4c66d609e4796c96ed1db8
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java
@@ -0,0 +1,178 @@
+package cn.bugstack.chatglm.executor.aigc;
+
+import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.executor.result.ResultHandler;
+import cn.bugstack.chatglm.model.*;
+import cn.bugstack.chatglm.session.Configuration;
+import com.alibaba.fastjson.JSON;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.jetbrains.annotations.Nullable;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * 智谱AI 通用大模型 glm-3-turbo、glm-4 执行器
+ * https://open.bigmodel.cn/dev/api
+ */
+@Slf4j
+public class GLMExecutor implements Executor, ResultHandler {
+
+ /**
+ * OpenAi 接口
+ */
+ private final Configuration configuration;
+ /**
+ * 工厂事件
+ */
+ private final EventSource.Factory factory;
+ /**
+ * 统一接口
+ */
+ private IOpenAiApi openAiApi;
+
+ private OkHttpClient okHttpClient;
+
+ public GLMExecutor(Configuration configuration) {
+ this.configuration = configuration;
+ this.factory = configuration.createRequestFactory();
+ this.openAiApi = configuration.getOpenAiApi();
+ this.okHttpClient = configuration.getOkHttpClient();
+ }
+
+ @Override
+ public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception {
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v4))
+ .post(RequestBody.create(MediaType.parse(Configuration.JSON_CONTENT_TYPE), chatCompletionRequest.toString()))
+ .build();
+
+ // 返回事件结果
+ return factory.newEventSource(request, chatCompletionRequest.getIsCompatible() ? eventSourceListener(eventSourceListener) : eventSourceListener);
+ }
+
+ @Override
+ public CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
+ // 用于执行异步任务并获取结果
+ CompletableFuture future = new CompletableFuture<>();
+ StringBuffer dataBuffer = new StringBuffer();
+
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v4))
+ .post(RequestBody.create(MediaType.parse("application/json;charset=utf-8"), chatCompletionRequest.toString()))
+ .build();
+
+ factory.newEventSource(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ if ("[DONE]".equals(data)) {
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(data));
+ return;
+ }
+
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response));
+ List choices = response.getChoices();
+ for (ChatCompletionResponse.Choice choice : choices) {
+ if (!"stop".equals(choice.getFinishReason())) {
+ dataBuffer.append(choice.getDelta().getContent());
+ }
+ }
+
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ future.complete(dataBuffer.toString());
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ future.completeExceptionally(new RuntimeException("Request closed before completion"));
+ }
+ });
+
+ return future;
+ }
+
+ @Override
+ public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception {
+ // sync 同步请求,stream 为 false
+ chatCompletionRequest.setStream(false);
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v4))
+ .post(RequestBody.create(MediaType.parse(Configuration.JSON_CONTENT_TYPE), chatCompletionRequest.toString()))
+ .build();
+ OkHttpClient okHttpClient = configuration.getOkHttpClient();
+ Response response = okHttpClient.newCall(request).execute();
+ if(!response.isSuccessful()){
+ throw new RuntimeException("Request failed");
+ }
+ return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class);
+ }
+
+ /**
+ * 图片生成,注释的方式留作扩展使用
+ *
+ * Request request = new Request.Builder()
+ * .url(configuration.getApiHost().concat(IOpenAiApi.cogview3))
+ * .post(RequestBody.create(MediaType.parse(Configuration.JSON_CONTENT_TYPE), imageCompletionRequest.toString()))
+ * .build();
+ * // 请求结果
+ * Call call = okHttpClient.newCall(request);
+ * Response execute = call.execute();
+ * ResponseBody body = execute.body();
+ * if (execute.isSuccessful() && body != null) {
+ * String responseBody = body.string();
+ * ObjectMapper objectMapper = new ObjectMapper();
+ * return objectMapper.readValue(responseBody, ImageCompletionResponse.class);
+ * } else {
+ * throw new IOException("Failed to get image response");
+ * }
+ * @param imageCompletionRequest 请求信息
+ * @return
+ * @throws Exception
+ */
+ @Override
+ public ImageCompletionResponse genImages(ImageCompletionRequest imageCompletionRequest) throws Exception {
+ return openAiApi.genImages(imageCompletionRequest).blockingGet();
+ }
+
+ @Override
+ public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
+ return new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ if ("[DONE]".equals(data)) {
+ return;
+ }
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ if (response.getChoices() != null && 1 == response.getChoices().size() && "stop".equals(response.getChoices().get(0).getFinishReason())) {
+ eventSourceListener.onEvent(eventSource, id, EventType.finish.getCode(), data);
+ return;
+ }
+ eventSourceListener.onEvent(eventSource, id, EventType.add.getCode(), data);
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ eventSourceListener.onClosed(eventSource);
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ eventSourceListener.onFailure(eventSource, t, response);
+ }
+ };
+ }
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java
new file mode 100644
index 0000000000000000000000000000000000000000..0c2e4211dd78ce2d0a024133519bc4fd1b29c992
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java
@@ -0,0 +1,108 @@
+package cn.bugstack.chatglm.executor.aigc;
+
+import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.model.*;
+import cn.bugstack.chatglm.session.Configuration;
+import com.alibaba.fastjson.JSON;
+import okhttp3.*;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.jetbrains.annotations.Nullable;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * 智谱AI旧版接口模型; chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
+ * https://open.bigmodel.cn/dev/api
+ */
+public class GLMOldExecutor implements Executor {
+
+ /**
+ * OpenAi 接口
+ */
+ private final Configuration configuration;
+ /**
+ * 工厂事件
+ */
+ private final EventSource.Factory factory;
+
+ public GLMOldExecutor(Configuration configuration) {
+ this.configuration = configuration;
+ this.factory = configuration.createRequestFactory();
+ }
+
+ @Override
+ public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception {
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
+ .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
+ .build();
+
+ // 返回事件结果
+ return factory.newEventSource(request, eventSourceListener);
+ }
+
+ @Override
+ public CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
+ // 用于执行异步任务并获取结果
+ CompletableFuture future = new CompletableFuture<>();
+ StringBuffer dataBuffer = new StringBuffer();
+
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
+ .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
+ .build();
+
+ // 异步响应请求
+ factory.newEventSource(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
+ if (EventType.add.getCode().equals(type)) {
+ dataBuffer.append(response.getData());
+ } else if (EventType.finish.getCode().equals(type)) {
+ future.complete(dataBuffer.toString());
+ }
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ future.completeExceptionally(new RuntimeException("Request closed before completion"));
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ future.completeExceptionally(new RuntimeException("Request closed before completion"));
+ }
+ });
+
+ return future;
+ }
+
+ @Override
+ public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException {
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions_sync).replace("{model}", chatCompletionRequest.getModel().getCode()))
+ .header("Accept",Configuration.APPLICATION_JSON)
+ .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
+ .build();
+ OkHttpClient okHttpClient = configuration.getOkHttpClient();
+ Response response = okHttpClient.newCall(request).execute();
+ if(!response.isSuccessful()){
+ new RuntimeException("Request failed");
+ }
+ return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class);
+ }
+
+ @Override
+ public ImageCompletionResponse genImages(ImageCompletionRequest request) {
+ throw new RuntimeException("旧版无图片生成接口");
+ }
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java b/src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java
new file mode 100644
index 0000000000000000000000000000000000000000..6abd7c5459d9f88787dab9374d134e85ec48be95
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java
@@ -0,0 +1,12 @@
+package cn.bugstack.chatglm.executor.result;
+
+import okhttp3.sse.EventSourceListener;
+
+/**
+ * 结果封装器
+ */
+public interface ResultHandler {
+
+ EventSourceListener eventSourceListener(EventSourceListener eventSourceListener);
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java b/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java
index 2c58ffea5ee33047a741aaf667b89a1f3ad2df36..378d2b348b993930b4b50d4a29ba8df816545510 100644
--- a/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java
+++ b/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java
@@ -12,7 +12,7 @@ import java.io.IOException;
/**
* @author 小傅哥,微信:fustack
* @description 接口拦截器
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public class OpenAiHTTPInterceptor implements Interceptor {
@@ -36,7 +36,7 @@ public class OpenAiHTTPInterceptor implements Interceptor {
.header("Authorization", "Bearer " + BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret()))
.header("Content-Type", Configuration.JSON_CONTENT_TYPE)
.header("User-Agent", Configuration.DEFAULT_USER_AGENT)
- .header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE)
+// .header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE)
.method(original.method(), original.body())
.build();
diff --git a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java
index af9302f919802120d70b0aee7b2b4ba547959138..1dc83d54c23785f88b1efd3393e193a4cd62c51f 100644
--- a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java
+++ b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java
@@ -1,5 +1,6 @@
package cn.bugstack.chatglm.model;
+import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -14,7 +15,7 @@ import java.util.Map;
/**
* @author 小傅哥,微信:fustack
* @description 请求参数
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Data
@@ -25,16 +26,36 @@ import java.util.Map;
@AllArgsConstructor
public class ChatCompletionRequest {
+ /**
+ * 是否对返回结果数据做兼容,24年1月发布的 GLM_3_5_TURBO、GLM_4 模型,与之前的模型在返回结果上有差异。开启 true 可以做兼容。
+ */
+ private Boolean isCompatible = true;
/**
* 模型
*/
- private Model model = Model.CHATGLM_6B_SSE;
-
+ private Model model = Model.GLM_3_5_TURBO;
+ /**
+ * 请求参数 {"role": "user", "content": "你好"}
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private List messages;
/**
* 请求ID
*/
@JsonProperty("request_id")
private String requestId = String.format("xfg-%d", System.currentTimeMillis());
+ /**
+ * do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ @JsonProperty("do_sample")
+ private Boolean doSample = true;
+ /**
+ * 使用同步调用时,此参数应当设置为 Fasle 或者省略。表示模型生成完所有内容后一次性返回所有内容。
+ * 如果设置为 True,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data: [DONE]消息。
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private Boolean stream = true;
/**
* 控制温度【随机性】
*/
@@ -44,6 +65,28 @@ public class ChatCompletionRequest {
*/
@JsonProperty("top_p")
private float topP = 0.7f;
+ /**
+ * 模型输出最大tokens
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ @JsonProperty("max_tokens")
+ private Integer maxTokens = 2048;
+ /**
+ * 模型在遇到stop所制定的字符时将停止生成,目前仅支持单个停止词,格式为["stop_word1"]
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private List stop;
+ /**
+ * 可供模型调用的工具列表,tools字段会计算 tokens ,同样受到tokens长度的限制
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private List tools;
+ /**
+ * 用于控制模型是如何选择要调用的函数,仅当工具类型为function时补充。默认为auto,当前仅支持auto。
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ @JsonProperty("tool_choice")
+ private String toolChoice = "auto";
/**
* 输入给模型的会话信息
* 用户输入的内容;role=user
@@ -60,24 +103,191 @@ public class ChatCompletionRequest {
private String sseFormat = "data";
@Data
- @Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Prompt {
private String role;
private String content;
+
+ public static PromptBuilder builder() {
+ return new PromptBuilder();
+ }
+
+ public static class PromptBuilder {
+ private String role;
+ private String content;
+
+ PromptBuilder() {
+ }
+
+ public PromptBuilder role(String role) {
+ this.role = role;
+ return this;
+ }
+
+ public PromptBuilder content(String content) {
+ this.content = content;
+ return this;
+ }
+
+ public PromptBuilder content(Content content) {
+ this.content = JSON.toJSONString(content);
+ return this;
+ }
+
+ public Prompt build() {
+ return new Prompt(this.role, this.content);
+ }
+
+ public String toString() {
+ return "ChatCompletionRequest.Prompt.PromptBuilder(role=" + this.role + ", content=" + this.content + ")";
+ }
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Content {
+ private String type = Type.text.code;
+ private String text;
+ @JsonProperty("image_url")
+ private ImageUrl imageUrl;
+
+ @Getter
+ @AllArgsConstructor
+ public static enum Type {
+ text("text", "文本"),
+ image_url("image_url", "图"),
+ ;
+ private final String code;
+ private final String info;
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class ImageUrl {
+ private String url;
+ }
+
+ }
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Tool {
+ private Type type;
+ private Function function;
+ private Retrieval retrieval;
+ @JsonProperty("web_search")
+ private WebSearch webSearch;
+
+ public String getType() {
+ return type.code;
+ }
+
+ @Getter
+ @AllArgsConstructor
+ public static enum Type {
+ function("function", "函数功能"),
+ retrieval("retrieval", "知识库"),
+ web_search("web_search", "联网"),
+ ;
+ private final String code;
+ private final String info;
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Function {
+ // 函数名称,只能包含a-z,A-Z,0-9,下划线和中横线。最大长度限制为64
+ private String name;
+ // 用于描述函数功能。模型会根据这段描述决定函数调用方式。
+ private String description;
+ // parameter 字段需要传入一个 Json Schema 对象,以准确地定义函数所接受的参数。https://open.bigmodel.cn/dev/api#glm-3-turbo
+ private Object parameters;
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Retrieval {
+ // 当涉及到知识库ID时,请前往开放平台的知识库模块进行创建或获取。
+ @JsonProperty("knowledge_id")
+ private String knowledgeId;
+ // 请求模型时的知识库模板,默认模板:
+ @JsonProperty("prompt_template")
+ private String promptTemplate = "\"\"\"\n" +
+ "{{ knowledge}}\n" +
+ "\"\"\"\n" +
+ "中找问题\n" +
+ "\"\"\"\n" +
+ "{{question}}\n" +
+ "\"\"\"";
+ }
+
+ // 仅当工具类型为web_search时补充,如果tools中存在类型retrieval,此时web_search不生效。
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class WebSearch {
+ // 是否启用搜索,默认启用搜索 enable = true/false
+ private Boolean enable = true;
+ // 强制搜索自定义关键内容,此时模型会根据自定义搜索关键内容返回的结果作为背景知识来回答用户发起的对话。
+ @JsonProperty("search_query")
+ private String searchQuery;
+ }
+
+
}
@Override
public String toString() {
- Map paramsMap = new HashMap<>();
- paramsMap.put("request_id", requestId);
- paramsMap.put("prompt", prompt);
- paramsMap.put("incremental", incremental);
- paramsMap.put("temperature", temperature);
- paramsMap.put("top_p", topP);
- paramsMap.put("sseFormat", sseFormat);
try {
+ // 24年1月发布新模型后调整
+ if (Model.GLM_3_5_TURBO.equals(this.model) || Model.GLM_4.equals(this.model) || Model.GLM_4V.equals(this.model)) {
+ Map paramsMap = new HashMap<>();
+ paramsMap.put("model", this.model.getCode());
+ if (null == this.messages && null == this.prompt) {
+ throw new RuntimeException("One of messages or prompt must not be empty!");
+ }
+ paramsMap.put("messages", this.messages != null ? this.messages : this.prompt);
+ if (null != this.requestId) {
+ paramsMap.put("request_id", this.requestId);
+ }
+ if (null != this.doSample) {
+ paramsMap.put("do_sample", this.doSample);
+ }
+ paramsMap.put("stream", this.stream);
+ paramsMap.put("temperature", this.temperature);
+ paramsMap.put("top_p", this.topP);
+ paramsMap.put("max_tokens", this.maxTokens);
+ if (null != this.stop && this.stop.size() > 0) {
+ paramsMap.put("stop", this.stop);
+ }
+ if (null != this.tools && this.tools.size() > 0) {
+ paramsMap.put("tools", this.tools);
+ paramsMap.put("tool_choice", this.toolChoice);
+ }
+ return new ObjectMapper().writeValueAsString(paramsMap);
+ }
+
+ // 默认
+ Map paramsMap = new HashMap<>();
+ paramsMap.put("request_id", requestId);
+ paramsMap.put("prompt", prompt);
+ paramsMap.put("incremental", incremental);
+ paramsMap.put("temperature", temperature);
+ paramsMap.put("top_p", topP);
+ paramsMap.put("sseFormat", sseFormat);
return new ObjectMapper().writeValueAsString(paramsMap);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
diff --git a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java
index 6245d800a94172551558b173ea7f3eb800cc92f1..29e7b945d05cf26fc183ecbe299ec4397bba0699 100644
--- a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java
+++ b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java
@@ -1,20 +1,74 @@
package cn.bugstack.chatglm.model;
+import com.alibaba.fastjson2.JSON;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
/**
* @author 小傅哥,微信:fustack
* @description 返回结果
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Data
public class ChatCompletionResponse {
+ // 旧版获得的数据方式
private String data;
private String meta;
+ // 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ private String id;
+ private Long created;
+ private String model;
+ private List choices;
+ private Usage usage;
+
+ // 封装 setChoices 对 data 属性赋值,兼容旧版使用方式
+ public void setChoices(List choices) {
+ this.choices = choices;
+ for (Choice choice : choices) {
+ if ("stop".equals(choice.finishReason)) {
+ continue;
+ }
+ if (null == this.data) {
+ this.data = "";
+ }
+ this.data = this.data.concat(choice.getDelta().getContent());
+ }
+ }
+
+ // 封装 setChoices 对 meta 属性赋值,兼容旧版使用方式
+ public void setUsage(Usage usage) {
+ this.usage = usage;
+ if (null != usage) {
+ this.meta = JSON.toJSONString(Meta.builder().usage(usage).build());
+ }
+ }
+
+ @Data
+ public static class Choice {
+ private Long index;
+ @JsonProperty("finish_reason")
+ private String finishReason;
+ private Delta delta;
+ }
+
+ @Data
+ public static class Delta {
+ private String role;
+ private String content;
+ }
+
@Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
public static class Meta {
private String task_status;
private Usage usage;
diff --git a/src/main/java/cn/bugstack/chatglm/model/EventType.java b/src/main/java/cn/bugstack/chatglm/model/EventType.java
index 46fa9a0bdee3ace4e721aa42d1b4ec8d3ec15fd4..3e18c1d9f9675082ca74c7ffb6b78311d68837b4 100644
--- a/src/main/java/cn/bugstack/chatglm/model/EventType.java
+++ b/src/main/java/cn/bugstack/chatglm/model/EventType.java
@@ -6,7 +6,7 @@ import lombok.Getter;
/**
* @author 小傅哥,微信:fustack
* @description 消息类型 chatglm_lite
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Getter
diff --git a/src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java
new file mode 100644
index 0000000000000000000000000000000000000000..94080bddec6b712e6092a825d7eb2f288152bbf5
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java
@@ -0,0 +1,55 @@
+package cn.bugstack.chatglm.model;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * CogView 根据用户的文字描述生成图像,使用同步调用方式请求接口
+ */
+@Data
+@Builder
+@NoArgsConstructor
+@AllArgsConstructor
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class ImageCompletionRequest {
+
+ /**
+ * 模型;24年1月发布了 cogview-3 生成图片模型
+ */
+ private Model model = Model.COGVIEW_3;
+
+ /**
+ * 所需图像的文本描述
+ */
+ private String prompt;
+
+ public String getModel() {
+ return model.getCode();
+ }
+
+ public Model getModelEnum() {
+ return model;
+ }
+
+ @Override
+ public String toString() {
+ Map paramsMap = new HashMap<>();
+ paramsMap.put("model", model.getCode());
+ paramsMap.put("prompt", prompt);
+ try {
+ return new ObjectMapper().writeValueAsString(paramsMap);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+}
+
diff --git a/src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java
new file mode 100644
index 0000000000000000000000000000000000000000..4ef1ac1c7f640f3d8081baf763ae21a31ed5fcf3
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java
@@ -0,0 +1,25 @@
+package cn.bugstack.chatglm.model;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * CogView 根据用户的文字描述生成图像,使用同步调用方式请求接口
+ */
+@Data
+public class ImageCompletionResponse {
+
+ /**
+ * 请求创建时间,是以秒为单位的Unix时间戳。
+ */
+ private Long created;
+
+ private List data;
+
+ @Data
+ public static class Image{
+ private String url;
+ }
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/model/Model.java b/src/main/java/cn/bugstack/chatglm/model/Model.java
index d10bf013c9aa503c6095de88ca9dc37593719e3b..7df64e3cae373b67c9c2490eee05dc1c07c07c28 100644
--- a/src/main/java/cn/bugstack/chatglm/model/Model.java
+++ b/src/main/java/cn/bugstack/chatglm/model/Model.java
@@ -6,7 +6,7 @@ import lombok.Getter;
/**
* @author 小傅哥,微信:fustack
* @description 会话模型
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Getter
@@ -23,9 +23,13 @@ public enum Model {
CHATGLM_STD("chatglm_std", "适用于对知识量、推理能力、创造力要求较高的场景"),
@Deprecated
CHATGLM_PRO("chatglm_pro", "适用于对知识量、推理能力、创造力要求较高的场景"),
- /** 智谱AI最新模型 */
+ /** 智谱AI 23年06月发布 */
CHATGLM_TURBO("chatglm_turbo", "适用于对知识量、推理能力、创造力要求较高的场景"),
-
+ /** 智谱AI 24年01月发布 */
+ GLM_3_5_TURBO("glm-3-turbo","适用于对知识量、推理能力、创造力要求较高的场景"),
+ GLM_4("glm-4","适用于复杂的对话交互和深度内容创作设计的场景"),
+ GLM_4V("glm-4v","根据输入的自然语言指令和图像信息完成任务,推荐使用 SSE 或同步调用方式请求接口"),
+ COGVIEW_3("cogview-3","根据用户的文字描述生成图像,使用同步调用方式请求接口"),
;
private final String code;
private final String info;
diff --git a/src/main/java/cn/bugstack/chatglm/model/Role.java b/src/main/java/cn/bugstack/chatglm/model/Role.java
index 2b7aa9af7b51e288e81ddfd5122ff5a3e05af049..caa49827c39711d8416b4e9829906669ef0f1935 100644
--- a/src/main/java/cn/bugstack/chatglm/model/Role.java
+++ b/src/main/java/cn/bugstack/chatglm/model/Role.java
@@ -6,7 +6,7 @@ import lombok.Getter;
/**
* @author 小傅哥,微信:fustack
* @description 角色
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Getter
diff --git a/src/main/java/cn/bugstack/chatglm/session/Configuration.java b/src/main/java/cn/bugstack/chatglm/session/Configuration.java
index 9d22ac186fa82eb67566690ef2337077663193ed..07ec3c3a959f5d5e3981171997c3ffc747903c6d 100644
--- a/src/main/java/cn/bugstack/chatglm/session/Configuration.java
+++ b/src/main/java/cn/bugstack/chatglm/session/Configuration.java
@@ -1,6 +1,10 @@
package cn.bugstack.chatglm.session;
import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.executor.aigc.GLMOldExecutor;
+import cn.bugstack.chatglm.executor.aigc.GLMExecutor;
+import cn.bugstack.chatglm.model.Model;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
@@ -8,10 +12,12 @@ import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
+import java.util.HashMap;
+
/**
* @author 小傅哥,微信:fustack
* @description 配置文件
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
@@ -22,7 +28,7 @@ public class Configuration {
// 智普Ai ChatGlM 请求地址
@Getter
@Setter
- private String apiHost = "https://open.bigmodel.cn/api/paas/";
+ private String apiHost = "https://open.bigmodel.cn/";
// 智普Ai https://open.bigmodel.cn/usercenter/apikeys - apiSecretKey = {apiKey}.{apiSecret}
private String apiSecretKey;
@@ -69,10 +75,31 @@ public class Configuration {
@Getter
private long readTimeout = 450;
+ private HashMap executorGroup;
+
// http keywords
public static final String SSE_CONTENT_TYPE = "text/event-stream";
public static final String DEFAULT_USER_AGENT = "Mozilla/4.0 (compatible; MSIE 5.0; Windows NT; DigExt)";
public static final String APPLICATION_JSON = "application/json";
public static final String JSON_CONTENT_TYPE = APPLICATION_JSON + "; charset=utf-8";
+ public HashMap newExecutorGroup() {
+ this.executorGroup = new HashMap<>();
+ // 旧版模型,兼容
+ Executor glmOldExecutor = new GLMOldExecutor(this);
+ this.executorGroup.put(Model.CHATGLM_6B_SSE, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_LITE, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_LITE_32K, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_STD, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_PRO, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_TURBO, glmOldExecutor);
+ // 新版模型,配置
+ Executor glmExecutor = new GLMExecutor(this);
+ this.executorGroup.put(Model.GLM_3_5_TURBO, glmExecutor);
+ this.executorGroup.put(Model.GLM_4, glmExecutor);
+ this.executorGroup.put(Model.GLM_4V, glmExecutor);
+ this.executorGroup.put(Model.COGVIEW_3, glmExecutor);
+ return this.executorGroup;
+ }
+
}
diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java
index 88b25954f82f7e16fdcf6e3060b5523ae0129515..1f388df4a1f4d737d2641bce974491269a8f80df 100644
--- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java
+++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java
@@ -1,8 +1,6 @@
package cn.bugstack.chatglm.session;
-import cn.bugstack.chatglm.model.ChatCompletionRequest;
-import cn.bugstack.chatglm.model.ChatCompletionResponse;
-import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
+import cn.bugstack.chatglm.model.*;
import com.fasterxml.jackson.core.JsonProcessingException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
@@ -13,15 +11,19 @@ import java.util.concurrent.CompletableFuture;
/**
* @author 小傅哥,微信:fustack
* @description 会话服务接口
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public interface OpenAiSession {
- EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException;
+ EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception;
- CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
+ CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws Exception;
- ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException;
+ ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception;
+
+ ImageCompletionResponse genImages(ImageCompletionRequest imageCompletionRequest) throws Exception;
+
+ Configuration configuration();
}
diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java
index f60a4b78037d42306b737af78cc1fb926349bdec..c823404bf0f2cd70b10cd40e6875c1cbea5ecfec 100644
--- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java
+++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java
@@ -3,7 +3,7 @@ package cn.bugstack.chatglm.session;
/**
* @author 小傅哥,微信:fustack
* @description 工厂接口
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public interface OpenAiSessionFactory {
diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java
index 687e408319b308ff4266d0e259f71bece3d1526e..41da71c05b05babce17d418f1ce685a9c31e8bb3 100644
--- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java
+++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java
@@ -1,112 +1,64 @@
package cn.bugstack.chatglm.session.defaults;
-import cn.bugstack.chatglm.IOpenAiApi;
-import cn.bugstack.chatglm.model.ChatCompletionRequest;
-import cn.bugstack.chatglm.model.ChatCompletionResponse;
-import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
-import cn.bugstack.chatglm.model.EventType;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.model.*;
import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession;
-import com.alibaba.fastjson.JSON;
-import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
-import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
-import org.jetbrains.annotations.Nullable;
import java.io.IOException;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CountDownLatch;
/**
* @author 小傅哥,微信:fustack
* @description 会话服务
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
public class DefaultOpenAiSession implements OpenAiSession {
- /**
- * OpenAi 接口
- */
private final Configuration configuration;
- /**
- * 工厂事件
- */
- private final EventSource.Factory factory;
+ private final Map executorGroup;
- public DefaultOpenAiSession(Configuration configuration) {
+ public DefaultOpenAiSession(Configuration configuration, Map executorGroup) {
this.configuration = configuration;
- this.factory = configuration.createRequestFactory();
+ this.executorGroup = executorGroup;
}
-
@Override
- public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException {
- // 构建请求信息
- Request request = new Request.Builder()
- .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
- .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
- .build();
-
- // 返回事件结果
- return factory.newEventSource(request, eventSourceListener);
+ public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception {
+ Executor executor = executorGroup.get(chatCompletionRequest.getModel());
+ if (null == executor) throw new RuntimeException(chatCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.completions(chatCompletionRequest, eventSourceListener);
}
@Override
- public CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
- // 用于执行异步任务并获取结果
- CompletableFuture future = new CompletableFuture<>();
- StringBuffer dataBuffer = new StringBuffer();
-
- // 构建请求信息
- Request request = new Request.Builder()
- .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
- .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
- .build();
-
- // 异步响应请求
- factory.newEventSource(request, new EventSourceListener() {
- @Override
- public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
- ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
- // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
- if (EventType.add.getCode().equals(type)) {
- dataBuffer.append(response.getData());
- } else if (EventType.finish.getCode().equals(type)) {
- future.complete(dataBuffer.toString());
- }
- }
-
- @Override
- public void onClosed(EventSource eventSource) {
- future.completeExceptionally(new RuntimeException("Request closed before completion"));
- }
+ public CompletableFuture completions(ChatCompletionRequest chatCompletionRequest) throws Exception {
+ Executor executor = executorGroup.get(chatCompletionRequest.getModel());
+ if (null == executor) throw new RuntimeException(chatCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.completions(chatCompletionRequest);
+ }
- @Override
- public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
- future.completeExceptionally(new RuntimeException("Request closed before completion"));
- }
- });
+ @Override
+ public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception {
+ Executor executor = executorGroup.get(chatCompletionRequest.getModel());
+ if (null == executor) throw new RuntimeException(chatCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.completionsSync(chatCompletionRequest);
+ }
- return future;
+ @Override
+ public ImageCompletionResponse genImages(ImageCompletionRequest imageCompletionRequest) throws Exception {
+ Executor executor = executorGroup.get(imageCompletionRequest.getModelEnum());
+ if (null == executor) throw new RuntimeException(imageCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.genImages(imageCompletionRequest);
}
@Override
- public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException {
- // 构建请求信息
- Request request = new Request.Builder()
- .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions_sync).replace("{model}", chatCompletionRequest.getModel().getCode()))
- .header("Accept",Configuration.APPLICATION_JSON)
- .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
- .build();
- OkHttpClient okHttpClient = configuration.getOkHttpClient();
- Response response = okHttpClient.newCall(request).execute();
- if(!response.isSuccessful()){
- new RuntimeException("Request failed");
- }
- return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class);
+ public Configuration configuration() {
+ return configuration;
}
}
diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java
index b72dd9710ae0a654710d34c9855f37358205fe0c..a2398eaf05e5f15d4aee71d5475c014de6a06a2c 100644
--- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java
+++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java
@@ -1,7 +1,9 @@
package cn.bugstack.chatglm.session.defaults;
import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
import cn.bugstack.chatglm.interceptor.OpenAiHTTPInterceptor;
+import cn.bugstack.chatglm.model.Model;
import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession;
import cn.bugstack.chatglm.session.OpenAiSessionFactory;
@@ -11,12 +13,13 @@ import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
+import java.util.HashMap;
import java.util.concurrent.TimeUnit;
/**
* @author 小傅哥,微信:fustack
* @description 会话工厂
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {
@@ -55,7 +58,10 @@ public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {
configuration.setOpenAiApi(openAiApi);
- return new DefaultOpenAiSession(configuration);
+ // 4. 实例化执行器
+ HashMap executorGroup = configuration.newExecutorGroup();
+
+ return new DefaultOpenAiSession(configuration, executorGroup);
}
}
diff --git a/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java b/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java
index b19927c3d7e355b763df78daa0b0c12bb4646b20..199ff6920a72f9060c0d3641a533530758c0802f 100644
--- a/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java
+++ b/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java
@@ -15,7 +15,7 @@ import java.util.concurrent.TimeUnit;
/**
* @author 小傅哥,微信:fustack
* @description 签名工具包;过期时间30分钟,缓存时间29分钟
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
diff --git a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java
index 1808b1d8eaed2c90ac6023d7d2324850e17cd188..f0d7224be4f86b4f10324663c353d253e36177e7 100644
--- a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java
+++ b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java
@@ -9,6 +9,7 @@ import cn.bugstack.chatglm.utils.BearerTokenUtils;
import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
+import okhttp3.Response;
import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
@@ -26,7 +27,7 @@ import java.util.concurrent.ExecutionException;
/**
* @author 小傅哥,微信:fustack
* @description 在官网申请 ApiSecretKey ApiSecretKey
- * @github https://github.com/fuzhengwei
+ * @github chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
@@ -39,7 +40,7 @@ public class ApiTest {
// 1. 配置文件
Configuration configuration = new Configuration();
configuration.setApiHost("https://open.bigmodel.cn/");
- configuration.setApiSecretKey("d570f7c5d289cdac2abdfdc562e39f3f.trqz1dH8ZK6ED7Pg");
+ configuration.setApiSecretKey("62ddec38b1d0b9a7b0fddaf271e6ed90.HpD0SUBUlvqd05ey");
configuration.setLevel(HttpLoggingInterceptor.Level.BODY);
// 2. 会话工厂
OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
@@ -48,76 +49,213 @@ public class ApiTest {
}
/**
- * 流式对话
+ * 流式对话;
+ * 1. 默认 isCompatible = true 会兼容新旧版数据格式
+ * 2. GLM_3_5_TURBO、GLM_4 支持联网等插件
*/
@Test
- public void test_completions() throws JsonProcessingException, InterruptedException {
+ public void test_completions() throws Exception {
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
- request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
+ request.setModel(Model.GLM_3_5_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setIncremental(false);
+ request.setIsCompatible(true); // 是否对返回结果数据做兼容,24年1月发布的 GLM_3_5_TURBO、GLM_4 模型,与之前的模型在返回结果上有差异。开启 true 可以做兼容。
+ // 24年1月发布的 glm-3-turbo、glm-4 支持函数、知识库、联网功能
+ request.setTools(new ArrayList() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ add(ChatCompletionRequest.Tool.builder()
+ .type(ChatCompletionRequest.Tool.Type.web_search)
+ .webSearch(ChatCompletionRequest.Tool.WebSearch.builder().enable(true).searchQuery("小傅哥").build())
+ .build());
+ }
+ });
request.setPrompt(new ArrayList() {
private static final long serialVersionUID = -7988151926241837899L;
{
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("1+2")
+ .content("小傅哥的是谁")
+ .build());
+ }
+ });
+
+ // 请求
+ openAiSession.completions(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果 onEvent:{}", response.getData());
+ // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
+ if (EventType.finish.getCode().equals(type)) {
+ ChatCompletionResponse.Meta meta = JSON.parseObject(response.getMeta(), ChatCompletionResponse.Meta.class);
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(meta));
+ }
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ log.info("对话完成");
+ countDownLatch.countDown();
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ log.info("对话异常");
+ countDownLatch.countDown();
+ }
+ });
+
+ // 等待
+ countDownLatch.await();
+ }
+
+ /**
+ * 流式对话;
+ * 1. 与 test_completions 测试类相比,只是设置 isCompatible = false 这样就是使用了新的数据结构。onEvent 处理接收数据有差异
+ * 2. 不兼容旧版格式的话,仅支持 GLM_3_5_TURBO、GLM_4 其他模型会有解析错误
+ */
+ @Test
+ public void test_completions_new() throws Exception {
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+
+ // 入参;模型、请求信息
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setModel(Model.GLM_3_5_TURBO); // GLM_3_5_TURBO、GLM_4
+ request.setIsCompatible(false);
+ // 24年1月发布的 glm-3-turbo、glm-4 支持函数、知识库、联网功能
+ request.setTools(new ArrayList() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ add(ChatCompletionRequest.Tool.builder()
+ .type(ChatCompletionRequest.Tool.Type.web_search)
+ .webSearch(ChatCompletionRequest.Tool.WebSearch.builder().enable(true).searchQuery("小傅哥").build())
.build());
+ }
+ });
+ request.setMessages(new ArrayList() {
+ private static final long serialVersionUID = -7988151926241837899L;
+ {
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("Okay")
+ .content("小傅哥的是谁")
.build());
+ }
+ });
- /* system 和 user 为一组出现。如果有参数类型为 system 则 system + user 一组一起传递。*/
+ // 请求
+ openAiSession.completions(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ if ("[DONE]".equals(data)) {
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(data));
+ return;
+ }
+
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response));
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ log.info("对话完成");
+ countDownLatch.countDown();
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ log.error("对话失败", t);
+ countDownLatch.countDown();
+ }
+ });
+
+ // 等待
+ countDownLatch.await();
+ }
+
+ /**
+ * 模型编码:glm-4v
+ * 根据输入的自然语言指令和图像信息完成任务,推荐使用 SSE 或同步调用方式请求接口
+ * https://open.bigmodel.cn/dev/api#glm-4v
+ */
+ @Test
+ public void test_completions_4v() throws Exception {
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+ // 入参;模型、请求信息
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setModel(Model.GLM_4V); // GLM_3_5_TURBO、GLM_4
+ request.setStream(true);
+ request.setMessages(new ArrayList() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ // content 字符串格式
add(ChatCompletionRequest.Prompt.builder()
- .role(Role.system.getCode())
- .content("1+1=2")
+ .role(Role.user.getCode())
+ .content("这个图片写了什么")
.build());
+ // content 对象格式
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("Okay")
+ .content(ChatCompletionRequest.Prompt.Content.builder()
+ .type(ChatCompletionRequest.Prompt.Content.Type.text.getCode())
+ .text("这是什么图片")
+ .build())
.build());
+ // content 对象格式,上传图片;图片支持url、basde64
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("1+2")
+ .content(ChatCompletionRequest.Prompt.Content.builder()
+ .type(ChatCompletionRequest.Prompt.Content.Type.image_url.getCode())
+ .imageUrl(ChatCompletionRequest.Prompt.Content.ImageUrl.builder().url("https://bugstack.cn/images/article/project/chatgpt/chatgpt-extra-231011-01.png").build())
+ .build())
.build());
-
}
});
- // 请求
openAiSession.completions(request, new EventSourceListener() {
@Override
public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
- ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
- log.info("测试结果 onEvent:{}", response.getData());
- // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
- if (EventType.finish.getCode().equals(type)) {
- ChatCompletionResponse.Meta meta = JSON.parseObject(response.getMeta(), ChatCompletionResponse.Meta.class);
- log.info("[输出结束] Tokens {}", JSON.toJSONString(meta));
+ if ("[DONE]".equals(data)) {
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(data));
+ return;
}
+
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response));
}
@Override
public void onClosed(EventSource eventSource) {
log.info("对话完成");
+ countDownLatch.countDown();
}
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ log.error("对话失败", t);
+ countDownLatch.countDown();
+ }
});
// 等待
- new CountDownLatch(1).await();
+ countDownLatch.await();
+
}
/**
* 同步请求
*/
@Test
- public void test_completions_future() throws ExecutionException, InterruptedException {
+ public void test_completions_future() throws Exception {
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
@@ -142,17 +280,29 @@ public class ApiTest {
* 同步请求
*/
@Test
- public void test_completions_sync() throws IOException {
+ public void test_completions_sync() throws Exception {
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
- request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
+ request.setModel(Model.GLM_4V); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setPrompt(new ArrayList() {
private static final long serialVersionUID = -7988151926241837899L;
{
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("1+1")
+ .content("小傅哥是谁")
+ .build());
+ }
+ });
+
+ // 24年1月发布的 glm-3-turbo、glm-4 支持函数、知识库、联网功能
+ request.setTools(new ArrayList() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ add(ChatCompletionRequest.Tool.builder()
+ .type(ChatCompletionRequest.Tool.Type.web_search)
+ .webSearch(ChatCompletionRequest.Tool.WebSearch.builder().enable(true).searchQuery("小傅哥").build())
.build());
}
});
@@ -162,13 +312,21 @@ public class ApiTest {
log.info("测试结果:{}", JSON.toJSONString(response));
}
+ @Test
+ public void test_genImages() throws Exception {
+ ImageCompletionRequest request = new ImageCompletionRequest();
+ request.setModel(Model.COGVIEW_3);
+ request.setPrompt("画个小狗");
+ ImageCompletionResponse response = openAiSession.genImages(request);
+ log.info("测试结果:{}", JSON.toJSONString(response));
+ }
@Test
public void test_curl() {
// 1. 配置文件
Configuration configuration = new Configuration();
configuration.setApiHost("https://open.bigmodel.cn/");
- configuration.setApiSecretKey("4d00226f242793b9c267a64ab2eaf5cb.aIwQNiG59MhSWJbn");
+ configuration.setApiSecretKey("62ddec38b1d0b9a7b0fddaf271e6ed90.HpD0SUBUlvqd05ey");
// 2. 获取Token
String token = BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret());
diff --git a/src/test/java/cn/bugstack/chatglm/test/JSONTest.java b/src/test/java/cn/bugstack/chatglm/test/JSONTest.java
new file mode 100644
index 0000000000000000000000000000000000000000..79039bd0c2a88e93c04bdc807c4198d20b585c43
--- /dev/null
+++ b/src/test/java/cn/bugstack/chatglm/test/JSONTest.java
@@ -0,0 +1,68 @@
+package cn.bugstack.chatglm.test;
+
+import cn.bugstack.chatglm.model.ChatCompletionResponse;
+import cn.bugstack.chatglm.model.ImageCompletionResponse;
+import com.alibaba.fastjson.JSON;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Test;
+
+@Slf4j
+public class JSONTest {
+
+ @Test
+ public void test_glm_json() {
+ String json01 = "{\n" +
+ " \"id\": \"8305987191663349153\",\n" +
+ " \"created\": 1705487423,\n" +
+ " \"model\": \"glm-3-turbo\",\n" +
+ " \"choices\": [\n" +
+ " {\n" +
+ " \"index\": 0,\n" +
+ " \"delta\": {\n" +
+ " \"role\": \"assistant\",\n" +
+ " \"content\": \"1\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ String json02 = "{\n" +
+ " \"id\": \"8308763664682731117\",\n" +
+ " \"created\": 1705490859,\n" +
+ " \"model\": \"glm-3-turbo\",\n" +
+ " \"choices\": [\n" +
+ " {\n" +
+ " \"index\": 0,\n" +
+ " \"finish_reason\": \"stop\",\n" +
+ " \"delta\": {\n" +
+ " \"role\": \"assistant\",\n" +
+ " \"content\": \"\"\n" +
+ " }\n" +
+ " }\n" +
+ " ],\n" +
+ " \"usage\": {\n" +
+ " \"prompt_tokens\": 8,\n" +
+ " \"completion_tokens\": 12,\n" +
+ " \"total_tokens\": 20\n" +
+ " }\n" +
+ "}";
+
+ ChatCompletionResponse response = JSON.parseObject(json01, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response.getChoices()));
+ }
+
+ @Test
+ public void test_image_json(){
+ String json = "{\n" +
+ " \"created\": 1705549253,\n" +
+ " \"data\": [\n" +
+ " {\n" +
+ " \"url\": \"https://sfile.chatglm.cn/testpath/cbffcbf4-ac63-50a3-9d1e-b644c77ffaa2_0.png\"\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+ ImageCompletionResponse response = JSON.parseObject(json, ImageCompletionResponse.class);
+ log.info("测试结果:{}", response.getData().get(0).getUrl());
+ }
+
+}