From b6c05df8fd08454217baa83af791e0f98e143c98 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=B0=8F=E5=82=85=E5=93=A5?= <184172133@qq.com>
Date: Sun, 21 Jan 2024 13:19:21 +0800
Subject: [PATCH] feat: integrate models glm-3.0, glm-4.0, glm-4v, cogview-3
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
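Notes (below the "---" marker, so not part of the commit message): a minimal usage
sketch of the API surface this patch adds, assembled from the ApiTest and JSONTest
changes further down. It assumes the same imports and session setup as ApiTest; the
apiSecretKey value is a placeholder, and factory.openSession() is assumed unchanged
from the existing SDK, since this diff does not touch that method.

    // Sketch of a test-method body declared "throws Exception"
    // (ApiTest additionally waits on a CountDownLatch to keep the stream alive).

    // 1. Configure the SDK and open a session
    Configuration configuration = new Configuration();
    configuration.setApiHost("https://open.bigmodel.cn/");
    configuration.setApiSecretKey("{apiKey}.{apiSecret}"); // placeholder, not a real credential
    OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
    OpenAiSession openAiSession = factory.openSession();

    // 2. Streaming chat against the new v4 endpoint (glm-3-turbo / glm-4).
    //    With the default isCompatible = true the GLMExecutor wrapper re-tags chunks
    //    as EventType.add / EventType.finish and swallows the trailing "[DONE]".
    ChatCompletionRequest request = new ChatCompletionRequest();
    request.setModel(Model.GLM_4);
    request.setMessages(new ArrayList<ChatCompletionRequest.Prompt>() {{
        add(ChatCompletionRequest.Prompt.builder()
                .role(Role.user.getCode())
                .content("1+1")
                .build());
    }});
    openAiSession.completions(request, new EventSourceListener() {
        @Override
        public void onEvent(EventSource eventSource, String id, String type, String data) {
            // data is the raw JSON chunk returned by api/paas/v4/chat/completions
            System.out.println(type + " -> " + data);
        }
    });

    // 3. Image generation with cogview-3 (synchronous call)
    ImageCompletionRequest imageRequest = new ImageCompletionRequest();
    imageRequest.setModel(Model.COGVIEW_3);
    imageRequest.setPrompt("画个小狗");
    ImageCompletionResponse imageResponse = openAiSession.genImages(imageRequest);
    System.out.println(imageResponse.getData().get(0).getUrl());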
pom.xml | 14 +-
.../java/cn/bugstack/chatglm/IOpenAiApi.java | 13 +-
.../bugstack/chatglm/executor/Executor.java | 50 ++++
.../chatglm/executor/aigc/GLMExecutor.java | 178 ++++++++++++++
.../chatglm/executor/aigc/GLMOldExecutor.java | 108 ++++++++
.../executor/result/ResultHandler.java | 12 +
.../interceptor/OpenAiHTTPInterceptor.java | 4 +-
.../chatglm/model/ChatCompletionRequest.java | 232 +++++++++++++++++-
.../chatglm/model/ChatCompletionResponse.java | 56 ++++-
.../cn/bugstack/chatglm/model/EventType.java | 2 +-
.../chatglm/model/ImageCompletionRequest.java | 55 +++++
.../model/ImageCompletionResponse.java | 25 ++
.../java/cn/bugstack/chatglm/model/Model.java | 10 +-
.../java/cn/bugstack/chatglm/model/Role.java | 2 +-
.../chatglm/session/Configuration.java | 31 ++-
.../chatglm/session/OpenAiSession.java | 16 +-
.../chatglm/session/OpenAiSessionFactory.java | 2 +-
.../defaults/DefaultOpenAiSession.java | 106 +++-----
.../defaults/DefaultOpenAiSessionFactory.java | 10 +-
.../chatglm/utils/BearerTokenUtils.java | 2 +-
.../cn/bugstack/chatglm/test/ApiTest.java | 210 ++++++++++++++--
.../cn/bugstack/chatglm/test/JSONTest.java | 68 +++++
22 files changed, 1066 insertions(+), 140 deletions(-)
create mode 100644 src/main/java/cn/bugstack/chatglm/executor/Executor.java
create mode 100644 src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java
create mode 100644 src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java
create mode 100644 src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java
create mode 100644 src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java
create mode 100644 src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java
create mode 100644 src/test/java/cn/bugstack/chatglm/test/JSONTest.java
diff --git a/pom.xml b/pom.xml
index 1b0609a..1976f50 100644
--- a/pom.xml
+++ b/pom.xml
@@ -6,7 +6,7 @@
<groupId>cn.bugstack</groupId>
<artifactId>chatglm-sdk-java</artifactId>
- <version>1.1</version>
+ <version>2.0</version>
<name>chatglm-sdk-java</name>
<description>OpenAI Java SDK, ZhiPuAi ChatGLM Java SDK . Copyright © 2023 bugstack虫洞栈 All rights reserved. 版权所有(C)小傅哥 https://github.com/fuzhengwei/chatglm-sdk-java</description>
@@ -134,6 +134,18 @@
<scope>provided</scope>
<optional>true</optional>
+
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpclient</artifactId>
+ <version>4.5.14</version>
+ </dependency>
+ <dependency>
+ <groupId>org.apache.httpcomponents</groupId>
+ <artifactId>httpmime</artifactId>
+ <version>4.5.10</version>
+ </dependency>
+
diff --git a/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java b/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java
index 23776fa..2fb309f 100644
--- a/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java
+++ b/src/main/java/cn/bugstack/chatglm/IOpenAiApi.java
@@ -1,8 +1,6 @@
package cn.bugstack.chatglm;
-import cn.bugstack.chatglm.model.ChatCompletionRequest;
-import cn.bugstack.chatglm.model.ChatCompletionResponse;
-import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
+import cn.bugstack.chatglm.model.*;
import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;
@@ -11,7 +9,7 @@ import retrofit2.http.Path;
/**
* @author 小傅哥,微信:fustack
* @description OpenAi 接口,用于扩展通用类服务
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public interface IOpenAiApi {
@@ -25,4 +23,11 @@ public interface IOpenAiApi {
@POST(v3_completions_sync)
Single<ChatCompletionSyncResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);
+ String v4 = "api/paas/v4/chat/completions";
+
+ String cogview3 = "api/paas/v4/images/generations";
+
+ @POST(cogview3)
+ Single<ImageCompletionResponse> genImages(@Body ImageCompletionRequest imageCompletionRequest);
+
}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/Executor.java b/src/main/java/cn/bugstack/chatglm/executor/Executor.java
new file mode 100644
index 0000000..6124785
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/Executor.java
@@ -0,0 +1,50 @@
+package cn.bugstack.chatglm.executor;
+
+import cn.bugstack.chatglm.model.ChatCompletionRequest;
+import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
+import cn.bugstack.chatglm.model.ImageCompletionRequest;
+import cn.bugstack.chatglm.model.ImageCompletionResponse;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+
+public interface Executor {
+
+ /**
+ * 问答模式,流式反馈
+ *
+ * @param chatCompletionRequest 请求信息
+ * @param eventSourceListener 实现监听;通过监听的 onEvent 方法接收数据
+ * @return 应答结果
+ * @throws Exception 异常
+ */
+ EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception;
+
+ /**
+ * 问答模式,同步反馈 —— 用流式转化 Future
+ *
+ * @param chatCompletionRequest 请求信息
+ * @return 应答结果
+ */
+ CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
+
+ /**
+ * 同步应答接口
+ *
+ * @param chatCompletionRequest 请求信息
+ * @return ChatCompletionSyncResponse
+ * @throws IOException 异常
+ */
+ ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception;
+
+ /**
+ * 图片生成接口
+ *
+ * @param request 请求信息
+ * @return 应答结果
+ */
+ ImageCompletionResponse genImages(ImageCompletionRequest request) throws Exception;
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java
new file mode 100644
index 0000000..4ec6023
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMExecutor.java
@@ -0,0 +1,178 @@
+package cn.bugstack.chatglm.executor.aigc;
+
+import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.executor.result.ResultHandler;
+import cn.bugstack.chatglm.model.*;
+import cn.bugstack.chatglm.session.Configuration;
+import com.alibaba.fastjson.JSON;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.extern.slf4j.Slf4j;
+import okhttp3.*;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.jetbrains.annotations.Nullable;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * 智谱AI 通用大模型 glm-3-turbo、glm-4 执行器
+ * https://open.bigmodel.cn/dev/api
+ */
+@Slf4j
+public class GLMExecutor implements Executor, ResultHandler {
+
+ /**
+ * OpenAi 接口
+ */
+ private final Configuration configuration;
+ /**
+ * 工厂事件
+ */
+ private final EventSource.Factory factory;
+ /**
+ * 统一接口
+ */
+ private IOpenAiApi openAiApi;
+
+ private OkHttpClient okHttpClient;
+
+ public GLMExecutor(Configuration configuration) {
+ this.configuration = configuration;
+ this.factory = configuration.createRequestFactory();
+ this.openAiApi = configuration.getOpenAiApi();
+ this.okHttpClient = configuration.getOkHttpClient();
+ }
+
+ @Override
+ public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception {
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v4))
+ .post(RequestBody.create(MediaType.parse(Configuration.JSON_CONTENT_TYPE), chatCompletionRequest.toString()))
+ .build();
+
+ // 返回事件结果
+ return factory.newEventSource(request, chatCompletionRequest.getIsCompatible() ? eventSourceListener(eventSourceListener) : eventSourceListener);
+ }
+
+ @Override
+ public CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
+ // 用于执行异步任务并获取结果
+ CompletableFuture<String> future = new CompletableFuture<>();
+ StringBuffer dataBuffer = new StringBuffer();
+
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v4))
+ .post(RequestBody.create(MediaType.parse("application/json;charset=utf-8"), chatCompletionRequest.toString()))
+ .build();
+
+ factory.newEventSource(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ if ("[DONE]".equals(data)) {
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(data));
+ return;
+ }
+
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response));
+ List<ChatCompletionResponse.Choice> choices = response.getChoices();
+ for (ChatCompletionResponse.Choice choice : choices) {
+ if (!"stop".equals(choice.getFinishReason())) {
+ dataBuffer.append(choice.getDelta().getContent());
+ }
+ }
+
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ future.complete(dataBuffer.toString());
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ future.completeExceptionally(new RuntimeException("Request closed before completion"));
+ }
+ });
+
+ return future;
+ }
+
+ @Override
+ public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception {
+ // sync 同步请求,stream 为 false
+ chatCompletionRequest.setStream(false);
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v4))
+ .post(RequestBody.create(MediaType.parse(Configuration.JSON_CONTENT_TYPE), chatCompletionRequest.toString()))
+ .build();
+ OkHttpClient okHttpClient = configuration.getOkHttpClient();
+ Response response = okHttpClient.newCall(request).execute();
+ if(!response.isSuccessful()){
+ throw new RuntimeException("Request failed");
+ }
+ return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class);
+ }
+
+ /**
+ * 图片生成,注释的方式留作扩展使用
+ *
+ * Request request = new Request.Builder()
+ * .url(configuration.getApiHost().concat(IOpenAiApi.cogview3))
+ * .post(RequestBody.create(MediaType.parse(Configuration.JSON_CONTENT_TYPE), imageCompletionRequest.toString()))
+ * .build();
+ * // 请求结果
+ * Call call = okHttpClient.newCall(request);
+ * Response execute = call.execute();
+ * ResponseBody body = execute.body();
+ * if (execute.isSuccessful() && body != null) {
+ * String responseBody = body.string();
+ * ObjectMapper objectMapper = new ObjectMapper();
+ * return objectMapper.readValue(responseBody, ImageCompletionResponse.class);
+ * } else {
+ * throw new IOException("Failed to get image response");
+ * }
+ * @param imageCompletionRequest 请求信息
+ * @return
+ * @throws Exception
+ */
+ @Override
+ public ImageCompletionResponse genImages(ImageCompletionRequest imageCompletionRequest) throws Exception {
+ return openAiApi.genImages(imageCompletionRequest).blockingGet();
+ }
+
+ @Override
+ public EventSourceListener eventSourceListener(EventSourceListener eventSourceListener) {
+ return new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ if ("[DONE]".equals(data)) {
+ return;
+ }
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ if (response.getChoices() != null && 1 == response.getChoices().size() && "stop".equals(response.getChoices().get(0).getFinishReason())) {
+ eventSourceListener.onEvent(eventSource, id, EventType.finish.getCode(), data);
+ return;
+ }
+ eventSourceListener.onEvent(eventSource, id, EventType.add.getCode(), data);
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ eventSourceListener.onClosed(eventSource);
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ eventSourceListener.onFailure(eventSource, t, response);
+ }
+ };
+ }
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java
new file mode 100644
index 0000000..0c2e421
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/aigc/GLMOldExecutor.java
@@ -0,0 +1,108 @@
+package cn.bugstack.chatglm.executor.aigc;
+
+import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.model.*;
+import cn.bugstack.chatglm.session.Configuration;
+import com.alibaba.fastjson.JSON;
+import okhttp3.*;
+import okhttp3.sse.EventSource;
+import okhttp3.sse.EventSourceListener;
+import org.jetbrains.annotations.Nullable;
+
+import java.io.IOException;
+import java.util.concurrent.CompletableFuture;
+
+/**
+ * 智谱AI旧版接口模型; chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
+ * https://open.bigmodel.cn/dev/api
+ */
+public class GLMOldExecutor implements Executor {
+
+ /**
+ * OpenAi 接口
+ */
+ private final Configuration configuration;
+ /**
+ * 工厂事件
+ */
+ private final EventSource.Factory factory;
+
+ public GLMOldExecutor(Configuration configuration) {
+ this.configuration = configuration;
+ this.factory = configuration.createRequestFactory();
+ }
+
+ @Override
+ public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception {
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
+ .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
+ .build();
+
+ // 返回事件结果
+ return factory.newEventSource(request, eventSourceListener);
+ }
+
+ @Override
+ public CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
+ // 用于执行异步任务并获取结果
+ CompletableFuture<String> future = new CompletableFuture<>();
+ StringBuffer dataBuffer = new StringBuffer();
+
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
+ .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
+ .build();
+
+ // 异步响应请求
+ factory.newEventSource(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
+ if (EventType.add.getCode().equals(type)) {
+ dataBuffer.append(response.getData());
+ } else if (EventType.finish.getCode().equals(type)) {
+ future.complete(dataBuffer.toString());
+ }
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ future.completeExceptionally(new RuntimeException("Request closed before completion"));
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ future.completeExceptionally(new RuntimeException("Request closed before completion"));
+ }
+ });
+
+ return future;
+ }
+
+ @Override
+ public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException {
+ // 构建请求信息
+ Request request = new Request.Builder()
+ .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions_sync).replace("{model}", chatCompletionRequest.getModel().getCode()))
+ .header("Accept",Configuration.APPLICATION_JSON)
+ .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
+ .build();
+ OkHttpClient okHttpClient = configuration.getOkHttpClient();
+ Response response = okHttpClient.newCall(request).execute();
+ if (!response.isSuccessful()) {
+ throw new RuntimeException("Request failed");
+ }
+ return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class);
+ }
+
+ @Override
+ public ImageCompletionResponse genImages(ImageCompletionRequest request) {
+ throw new RuntimeException("旧版无图片生成接口");
+ }
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java b/src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java
new file mode 100644
index 0000000..6abd7c5
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/executor/result/ResultHandler.java
@@ -0,0 +1,12 @@
+package cn.bugstack.chatglm.executor.result;
+
+import okhttp3.sse.EventSourceListener;
+
+/**
+ * 结果封装器
+ */
+public interface ResultHandler {
+
+ EventSourceListener eventSourceListener(EventSourceListener eventSourceListener);
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java b/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java
index 2c58ffe..378d2b3 100644
--- a/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java
+++ b/src/main/java/cn/bugstack/chatglm/interceptor/OpenAiHTTPInterceptor.java
@@ -12,7 +12,7 @@ import java.io.IOException;
/**
* @author 小傅哥,微信:fustack
* @description 接口拦截器
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public class OpenAiHTTPInterceptor implements Interceptor {
@@ -36,7 +36,7 @@ public class OpenAiHTTPInterceptor implements Interceptor {
.header("Authorization", "Bearer " + BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret()))
.header("Content-Type", Configuration.JSON_CONTENT_TYPE)
.header("User-Agent", Configuration.DEFAULT_USER_AGENT)
- .header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE)
+// .header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE)
.method(original.method(), original.body())
.build();
diff --git a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java
index af9302f..1dc83d5 100644
--- a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java
+++ b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionRequest.java
@@ -1,5 +1,6 @@
package cn.bugstack.chatglm.model;
+import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
@@ -14,7 +15,7 @@ import java.util.Map;
/**
* @author 小傅哥,微信:fustack
* @description 请求参数
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Data
@@ -25,16 +26,36 @@ import java.util.Map;
@AllArgsConstructor
public class ChatCompletionRequest {
+ /**
+ * 是否对返回结果数据做兼容,24年1月发布的 GLM_3_5_TURBO、GLM_4 模型,与之前的模型在返回结果上有差异。开启 true 可以做兼容。
+ */
+ private Boolean isCompatible = true;
/**
* 模型
*/
- private Model model = Model.CHATGLM_6B_SSE;
-
+ private Model model = Model.GLM_3_5_TURBO;
+ /**
+ * 请求参数 {"role": "user", "content": "你好"}
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private List<Prompt> messages;
/**
* 请求ID
*/
@JsonProperty("request_id")
private String requestId = String.format("xfg-%d", System.currentTimeMillis());
+ /**
+ * do_sample 为 true 时启用采样策略,do_sample 为 false 时采样策略 temperature、top_p 将不生效
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ @JsonProperty("do_sample")
+ private Boolean doSample = true;
+ /**
+ * 使用同步调用时,此参数应当设置为 False 或者省略。表示模型生成完所有内容后一次性返回所有内容。
+ * 如果设置为 True,模型将通过标准 Event Stream ,逐块返回模型生成内容。Event Stream 结束时会返回一条data: [DONE]消息。
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private Boolean stream = true;
/**
* 控制温度【随机性】
*/
@@ -44,6 +65,28 @@ public class ChatCompletionRequest {
*/
@JsonProperty("top_p")
private float topP = 0.7f;
+ /**
+ * 模型输出最大tokens
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ @JsonProperty("max_tokens")
+ private Integer maxTokens = 2048;
+ /**
+ * 模型在遇到stop所制定的字符时将停止生成,目前仅支持单个停止词,格式为["stop_word1"]
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private List<String> stop;
+ /**
+ * 可供模型调用的工具列表,tools字段会计算 tokens ,同样受到tokens长度的限制
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ private List<Tool> tools;
+ /**
+ * 用于控制模型是如何选择要调用的函数,仅当工具类型为function时补充。默认为auto,当前仅支持auto。
+ * 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ */
+ @JsonProperty("tool_choice")
+ private String toolChoice = "auto";
/**
* 输入给模型的会话信息
* 用户输入的内容;role=user
@@ -60,24 +103,191 @@ public class ChatCompletionRequest {
private String sseFormat = "data";
@Data
- @Builder
@NoArgsConstructor
@AllArgsConstructor
public static class Prompt {
private String role;
private String content;
+
+ public static PromptBuilder builder() {
+ return new PromptBuilder();
+ }
+
+ public static class PromptBuilder {
+ private String role;
+ private String content;
+
+ PromptBuilder() {
+ }
+
+ public PromptBuilder role(String role) {
+ this.role = role;
+ return this;
+ }
+
+ public PromptBuilder content(String content) {
+ this.content = content;
+ return this;
+ }
+
+ public PromptBuilder content(Content content) {
+ this.content = JSON.toJSONString(content);
+ return this;
+ }
+
+ public Prompt build() {
+ return new Prompt(this.role, this.content);
+ }
+
+ public String toString() {
+ return "ChatCompletionRequest.Prompt.PromptBuilder(role=" + this.role + ", content=" + this.content + ")";
+ }
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Content {
+ private String type = Type.text.code;
+ private String text;
+ @JsonProperty("image_url")
+ private ImageUrl imageUrl;
+
+ @Getter
+ @AllArgsConstructor
+ public static enum Type {
+ text("text", "文本"),
+ image_url("image_url", "图"),
+ ;
+ private final String code;
+ private final String info;
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class ImageUrl {
+ private String url;
+ }
+
+ }
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Tool {
+ private Type type;
+ private Function function;
+ private Retrieval retrieval;
+ @JsonProperty("web_search")
+ private WebSearch webSearch;
+
+ public String getType() {
+ return type.code;
+ }
+
+ @Getter
+ @AllArgsConstructor
+ public static enum Type {
+ function("function", "函数功能"),
+ retrieval("retrieval", "知识库"),
+ web_search("web_search", "联网"),
+ ;
+ private final String code;
+ private final String info;
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Function {
+ // 函数名称,只能包含a-z,A-Z,0-9,下划线和中横线。最大长度限制为64
+ private String name;
+ // 用于描述函数功能。模型会根据这段描述决定函数调用方式。
+ private String description;
+ // parameter 字段需要传入一个 Json Schema 对象,以准确地定义函数所接受的参数。https://open.bigmodel.cn/dev/api#glm-3-turbo
+ private Object parameters;
+ }
+
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class Retrieval {
+ // 当涉及到知识库ID时,请前往开放平台的知识库模块进行创建或获取。
+ @JsonProperty("knowledge_id")
+ private String knowledgeId;
+ // 请求模型时的知识库模板,默认模板:
+ @JsonProperty("prompt_template")
+ private String promptTemplate = "\"\"\"\n" +
+ "{{ knowledge}}\n" +
+ "\"\"\"\n" +
+ "中找问题\n" +
+ "\"\"\"\n" +
+ "{{question}}\n" +
+ "\"\"\"";
+ }
+
+ // 仅当工具类型为web_search时补充,如果tools中存在类型retrieval,此时web_search不生效。
+ @Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
+ public static class WebSearch {
+ // 是否启用搜索,默认启用搜索 enable = true/false
+ private Boolean enable = true;
+ // 强制搜索自定义关键内容,此时模型会根据自定义搜索关键内容返回的结果作为背景知识来回答用户发起的对话。
+ @JsonProperty("search_query")
+ private String searchQuery;
+ }
+
+
}
@Override
public String toString() {
- Map<String, Object> paramsMap = new HashMap<>();
- paramsMap.put("request_id", requestId);
- paramsMap.put("prompt", prompt);
- paramsMap.put("incremental", incremental);
- paramsMap.put("temperature", temperature);
- paramsMap.put("top_p", topP);
- paramsMap.put("sseFormat", sseFormat);
try {
+ // 24年1月发布新模型后调整
+ if (Model.GLM_3_5_TURBO.equals(this.model) || Model.GLM_4.equals(this.model) || Model.GLM_4V.equals(this.model)) {
+ Map<String, Object> paramsMap = new HashMap<>();
+ paramsMap.put("model", this.model.getCode());
+ if (null == this.messages && null == this.prompt) {
+ throw new RuntimeException("One of messages or prompt must not be empty!");
+ }
+ paramsMap.put("messages", this.messages != null ? this.messages : this.prompt);
+ if (null != this.requestId) {
+ paramsMap.put("request_id", this.requestId);
+ }
+ if (null != this.doSample) {
+ paramsMap.put("do_sample", this.doSample);
+ }
+ paramsMap.put("stream", this.stream);
+ paramsMap.put("temperature", this.temperature);
+ paramsMap.put("top_p", this.topP);
+ paramsMap.put("max_tokens", this.maxTokens);
+ if (null != this.stop && this.stop.size() > 0) {
+ paramsMap.put("stop", this.stop);
+ }
+ if (null != this.tools && this.tools.size() > 0) {
+ paramsMap.put("tools", this.tools);
+ paramsMap.put("tool_choice", this.toolChoice);
+ }
+ return new ObjectMapper().writeValueAsString(paramsMap);
+ }
+
+ // 默认
+ Map<String, Object> paramsMap = new HashMap<>();
+ paramsMap.put("request_id", requestId);
+ paramsMap.put("prompt", prompt);
+ paramsMap.put("incremental", incremental);
+ paramsMap.put("temperature", temperature);
+ paramsMap.put("top_p", topP);
+ paramsMap.put("sseFormat", sseFormat);
return new ObjectMapper().writeValueAsString(paramsMap);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
diff --git a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java
index 6245d80..29e7b94 100644
--- a/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java
+++ b/src/main/java/cn/bugstack/chatglm/model/ChatCompletionResponse.java
@@ -1,20 +1,74 @@
package cn.bugstack.chatglm.model;
+import com.alibaba.fastjson2.JSON;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.List;
/**
* @author 小傅哥,微信:fustack
* @description 返回结果
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Data
public class ChatCompletionResponse {
+ // 旧版获得的数据方式
private String data;
private String meta;
+ // 24年1月发布的 GLM_3_5_TURBO、GLM_4 模型时新增
+ private String id;
+ private Long created;
+ private String model;
+ private List<Choice> choices;
+ private Usage usage;
+
+ // 封装 setChoices 对 data 属性赋值,兼容旧版使用方式
+ public void setChoices(List<Choice> choices) {
+ this.choices = choices;
+ for (Choice choice : choices) {
+ if ("stop".equals(choice.finishReason)) {
+ continue;
+ }
+ if (null == this.data) {
+ this.data = "";
+ }
+ this.data = this.data.concat(choice.getDelta().getContent());
+ }
+ }
+
+ // 封装 setChoices 对 meta 属性赋值,兼容旧版使用方式
+ public void setUsage(Usage usage) {
+ this.usage = usage;
+ if (null != usage) {
+ this.meta = JSON.toJSONString(Meta.builder().usage(usage).build());
+ }
+ }
+
+ @Data
+ public static class Choice {
+ private Long index;
+ @JsonProperty("finish_reason")
+ private String finishReason;
+ private Delta delta;
+ }
+
+ @Data
+ public static class Delta {
+ private String role;
+ private String content;
+ }
+
@Data
+ @Builder
+ @NoArgsConstructor
+ @AllArgsConstructor
public static class Meta {
private String task_status;
private Usage usage;
diff --git a/src/main/java/cn/bugstack/chatglm/model/EventType.java b/src/main/java/cn/bugstack/chatglm/model/EventType.java
index 46fa9a0..3e18c1d 100644
--- a/src/main/java/cn/bugstack/chatglm/model/EventType.java
+++ b/src/main/java/cn/bugstack/chatglm/model/EventType.java
@@ -6,7 +6,7 @@ import lombok.Getter;
/**
* @author 小傅哥,微信:fustack
* @description 消息类型 chatglm_lite
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Getter
diff --git a/src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java
new file mode 100644
index 0000000..94080bd
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionRequest.java
@@ -0,0 +1,55 @@
+package cn.bugstack.chatglm.model;
+
+import com.fasterxml.jackson.annotation.JsonInclude;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Data;
+import lombok.NoArgsConstructor;
+
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * CogView 根据用户的文字描述生成图像,使用同步调用方式请求接口
+ */
+@Data
+@Builder
+@NoArgsConstructor
+@AllArgsConstructor
+@JsonInclude(JsonInclude.Include.NON_NULL)
+public class ImageCompletionRequest {
+
+ /**
+ * 模型;24年1月发布了 cogview-3 生成图片模型
+ */
+ private Model model = Model.COGVIEW_3;
+
+ /**
+ * 所需图像的文本描述
+ */
+ private String prompt;
+
+ public String getModel() {
+ return model.getCode();
+ }
+
+ public Model getModelEnum() {
+ return model;
+ }
+
+ @Override
+ public String toString() {
+ Map paramsMap = new HashMap<>();
+ paramsMap.put("model", model.getCode());
+ paramsMap.put("prompt", prompt);
+ try {
+ return new ObjectMapper().writeValueAsString(paramsMap);
+ } catch (JsonProcessingException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+}
+
diff --git a/src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java
new file mode 100644
index 0000000..4ef1ac1
--- /dev/null
+++ b/src/main/java/cn/bugstack/chatglm/model/ImageCompletionResponse.java
@@ -0,0 +1,25 @@
+package cn.bugstack.chatglm.model;
+
+import lombok.Data;
+
+import java.util.List;
+
+/**
+ * CogView 根据用户的文字描述生成图像,使用同步调用方式请求接口
+ */
+@Data
+public class ImageCompletionResponse {
+
+ /**
+ * 请求创建时间,是以秒为单位的Unix时间戳。
+ */
+ private Long created;
+
+ private List<Image> data;
+
+ @Data
+ public static class Image{
+ private String url;
+ }
+
+}
diff --git a/src/main/java/cn/bugstack/chatglm/model/Model.java b/src/main/java/cn/bugstack/chatglm/model/Model.java
index d10bf01..7df64e3 100644
--- a/src/main/java/cn/bugstack/chatglm/model/Model.java
+++ b/src/main/java/cn/bugstack/chatglm/model/Model.java
@@ -6,7 +6,7 @@ import lombok.Getter;
/**
* @author 小傅哥,微信:fustack
* @description 会话模型
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Getter
@@ -23,9 +23,13 @@ public enum Model {
CHATGLM_STD("chatglm_std", "适用于对知识量、推理能力、创造力要求较高的场景"),
@Deprecated
CHATGLM_PRO("chatglm_pro", "适用于对知识量、推理能力、创造力要求较高的场景"),
- /** 智谱AI最新模型 */
+ /** 智谱AI 23年06月发布 */
CHATGLM_TURBO("chatglm_turbo", "适用于对知识量、推理能力、创造力要求较高的场景"),
-
+ /** 智谱AI 24年01月发布 */
+ GLM_3_5_TURBO("glm-3-turbo","适用于对知识量、推理能力、创造力要求较高的场景"),
+ GLM_4("glm-4","适用于复杂的对话交互和深度内容创作设计的场景"),
+ GLM_4V("glm-4v","根据输入的自然语言指令和图像信息完成任务,推荐使用 SSE 或同步调用方式请求接口"),
+ COGVIEW_3("cogview-3","根据用户的文字描述生成图像,使用同步调用方式请求接口"),
;
private final String code;
private final String info;
diff --git a/src/main/java/cn/bugstack/chatglm/model/Role.java b/src/main/java/cn/bugstack/chatglm/model/Role.java
index 2b7aa9a..caa4982 100644
--- a/src/main/java/cn/bugstack/chatglm/model/Role.java
+++ b/src/main/java/cn/bugstack/chatglm/model/Role.java
@@ -6,7 +6,7 @@ import lombok.Getter;
/**
* @author 小傅哥,微信:fustack
* @description 角色
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Getter
diff --git a/src/main/java/cn/bugstack/chatglm/session/Configuration.java b/src/main/java/cn/bugstack/chatglm/session/Configuration.java
index 9d22ac1..07ec3c3 100644
--- a/src/main/java/cn/bugstack/chatglm/session/Configuration.java
+++ b/src/main/java/cn/bugstack/chatglm/session/Configuration.java
@@ -1,6 +1,10 @@
package cn.bugstack.chatglm.session;
import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.executor.aigc.GLMOldExecutor;
+import cn.bugstack.chatglm.executor.aigc.GLMExecutor;
+import cn.bugstack.chatglm.model.Model;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
@@ -8,10 +12,12 @@ import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSources;
+import java.util.HashMap;
+
/**
* @author 小傅哥,微信:fustack
* @description 配置文件
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
@@ -22,7 +28,7 @@ public class Configuration {
// 智普Ai ChatGlM 请求地址
@Getter
@Setter
- private String apiHost = "https://open.bigmodel.cn/api/paas/";
+ private String apiHost = "https://open.bigmodel.cn/";
// 智普Ai https://open.bigmodel.cn/usercenter/apikeys - apiSecretKey = {apiKey}.{apiSecret}
private String apiSecretKey;
@@ -69,10 +75,31 @@ public class Configuration {
@Getter
private long readTimeout = 450;
+ private HashMap<Model, Executor> executorGroup;
+
// http keywords
public static final String SSE_CONTENT_TYPE = "text/event-stream";
public static final String DEFAULT_USER_AGENT = "Mozilla/4.0 (compatible; MSIE 5.0; Windows NT; DigExt)";
public static final String APPLICATION_JSON = "application/json";
public static final String JSON_CONTENT_TYPE = APPLICATION_JSON + "; charset=utf-8";
+ public HashMap<Model, Executor> newExecutorGroup() {
+ this.executorGroup = new HashMap<>();
+ // 旧版模型,兼容
+ Executor glmOldExecutor = new GLMOldExecutor(this);
+ this.executorGroup.put(Model.CHATGLM_6B_SSE, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_LITE, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_LITE_32K, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_STD, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_PRO, glmOldExecutor);
+ this.executorGroup.put(Model.CHATGLM_TURBO, glmOldExecutor);
+ // 新版模型,配置
+ Executor glmExecutor = new GLMExecutor(this);
+ this.executorGroup.put(Model.GLM_3_5_TURBO, glmExecutor);
+ this.executorGroup.put(Model.GLM_4, glmExecutor);
+ this.executorGroup.put(Model.GLM_4V, glmExecutor);
+ this.executorGroup.put(Model.COGVIEW_3, glmExecutor);
+ return this.executorGroup;
+ }
+
}
diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java
index 88b2595..1f388df 100644
--- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java
+++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSession.java
@@ -1,8 +1,6 @@
package cn.bugstack.chatglm.session;
-import cn.bugstack.chatglm.model.ChatCompletionRequest;
-import cn.bugstack.chatglm.model.ChatCompletionResponse;
-import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
+import cn.bugstack.chatglm.model.*;
import com.fasterxml.jackson.core.JsonProcessingException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
@@ -13,15 +11,19 @@ import java.util.concurrent.CompletableFuture;
/**
* @author 小傅哥,微信:fustack
* @description 会话服务接口
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public interface OpenAiSession {
- EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException;
+ EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception;
- CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
+ CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws Exception;
- ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException;
+ ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception;
+
+ ImageCompletionResponse genImages(ImageCompletionRequest imageCompletionRequest) throws Exception;
+
+ Configuration configuration();
}
diff --git a/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java b/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java
index f60a4b7..c823404 100644
--- a/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java
+++ b/src/main/java/cn/bugstack/chatglm/session/OpenAiSessionFactory.java
@@ -3,7 +3,7 @@ package cn.bugstack.chatglm.session;
/**
* @author 小傅哥,微信:fustack
* @description 工厂接口
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public interface OpenAiSessionFactory {
diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java
index 687e408..41da71c 100644
--- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java
+++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSession.java
@@ -1,112 +1,64 @@
package cn.bugstack.chatglm.session.defaults;
-import cn.bugstack.chatglm.IOpenAiApi;
-import cn.bugstack.chatglm.model.ChatCompletionRequest;
-import cn.bugstack.chatglm.model.ChatCompletionResponse;
-import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
-import cn.bugstack.chatglm.model.EventType;
+import cn.bugstack.chatglm.executor.Executor;
+import cn.bugstack.chatglm.model.*;
import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession;
-import com.alibaba.fastjson.JSON;
-import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
-import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
-import org.jetbrains.annotations.Nullable;
import java.io.IOException;
+import java.util.Map;
import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.CountDownLatch;
/**
* @author 小傅哥,微信:fustack
* @description 会话服务
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
public class DefaultOpenAiSession implements OpenAiSession {
- /**
- * OpenAi 接口
- */
private final Configuration configuration;
- /**
- * 工厂事件
- */
- private final EventSource.Factory factory;
+ private final Map<Model, Executor> executorGroup;
- public DefaultOpenAiSession(Configuration configuration) {
+ public DefaultOpenAiSession(Configuration configuration, Map<Model, Executor> executorGroup) {
this.configuration = configuration;
- this.factory = configuration.createRequestFactory();
+ this.executorGroup = executorGroup;
}
-
@Override
- public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws JsonProcessingException {
- // 构建请求信息
- Request request = new Request.Builder()
- .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
- .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
- .build();
-
- // 返回事件结果
- return factory.newEventSource(request, eventSourceListener);
+ public EventSource completions(ChatCompletionRequest chatCompletionRequest, EventSourceListener eventSourceListener) throws Exception {
+ Executor executor = executorGroup.get(chatCompletionRequest.getModel());
+ if (null == executor) throw new RuntimeException(chatCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.completions(chatCompletionRequest, eventSourceListener);
}
@Override
- public CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException {
- // 用于执行异步任务并获取结果
- CompletableFuture<String> future = new CompletableFuture<>();
- StringBuffer dataBuffer = new StringBuffer();
-
- // 构建请求信息
- Request request = new Request.Builder()
- .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions).replace("{model}", chatCompletionRequest.getModel().getCode()))
- .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
- .build();
-
- // 异步响应请求
- factory.newEventSource(request, new EventSourceListener() {
- @Override
- public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
- ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
- // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
- if (EventType.add.getCode().equals(type)) {
- dataBuffer.append(response.getData());
- } else if (EventType.finish.getCode().equals(type)) {
- future.complete(dataBuffer.toString());
- }
- }
-
- @Override
- public void onClosed(EventSource eventSource) {
- future.completeExceptionally(new RuntimeException("Request closed before completion"));
- }
+ public CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws Exception {
+ Executor executor = executorGroup.get(chatCompletionRequest.getModel());
+ if (null == executor) throw new RuntimeException(chatCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.completions(chatCompletionRequest);
+ }
- @Override
- public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
- future.completeExceptionally(new RuntimeException("Request closed before completion"));
- }
- });
+ @Override
+ public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws Exception {
+ Executor executor = executorGroup.get(chatCompletionRequest.getModel());
+ if (null == executor) throw new RuntimeException(chatCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.completionsSync(chatCompletionRequest);
+ }
- return future;
+ @Override
+ public ImageCompletionResponse genImages(ImageCompletionRequest imageCompletionRequest) throws Exception {
+ Executor executor = executorGroup.get(imageCompletionRequest.getModelEnum());
+ if (null == executor) throw new RuntimeException(imageCompletionRequest.getModel() + " 模型执行器尚未实现!");
+ return executor.genImages(imageCompletionRequest);
}
@Override
- public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException {
- // 构建请求信息
- Request request = new Request.Builder()
- .url(configuration.getApiHost().concat(IOpenAiApi.v3_completions_sync).replace("{model}", chatCompletionRequest.getModel().getCode()))
- .header("Accept",Configuration.APPLICATION_JSON)
- .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
- .build();
- OkHttpClient okHttpClient = configuration.getOkHttpClient();
- Response response = okHttpClient.newCall(request).execute();
- if(!response.isSuccessful()){
- new RuntimeException("Request failed");
- }
- return JSON.parseObject(response.body().string(),ChatCompletionSyncResponse.class);
+ public Configuration configuration() {
+ return configuration;
}
}
diff --git a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java
index b72dd97..a2398ea 100644
--- a/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java
+++ b/src/main/java/cn/bugstack/chatglm/session/defaults/DefaultOpenAiSessionFactory.java
@@ -1,7 +1,9 @@
package cn.bugstack.chatglm.session.defaults;
import cn.bugstack.chatglm.IOpenAiApi;
+import cn.bugstack.chatglm.executor.Executor;
import cn.bugstack.chatglm.interceptor.OpenAiHTTPInterceptor;
+import cn.bugstack.chatglm.model.Model;
import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession;
import cn.bugstack.chatglm.session.OpenAiSessionFactory;
@@ -11,12 +13,13 @@ import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;
+import java.util.HashMap;
import java.util.concurrent.TimeUnit;
/**
* @author 小傅哥,微信:fustack
* @description 会话工厂
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {
@@ -55,7 +58,10 @@ public class DefaultOpenAiSessionFactory implements OpenAiSessionFactory {
configuration.setOpenAiApi(openAiApi);
- return new DefaultOpenAiSession(configuration);
+ // 4. 实例化执行器
+ HashMap<Model, Executor> executorGroup = configuration.newExecutorGroup();
+
+ return new DefaultOpenAiSession(configuration, executorGroup);
}
}
diff --git a/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java b/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java
index b19927c..199ff69 100644
--- a/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java
+++ b/src/main/java/cn/bugstack/chatglm/utils/BearerTokenUtils.java
@@ -15,7 +15,7 @@ import java.util.concurrent.TimeUnit;
/**
* @author 小傅哥,微信:fustack
* @description 签名工具包;过期时间30分钟,缓存时间29分钟
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
diff --git a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java
index 1808b1d..f0d7224 100644
--- a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java
+++ b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java
@@ -9,6 +9,7 @@ import cn.bugstack.chatglm.utils.BearerTokenUtils;
import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
+import okhttp3.Response;
import okhttp3.logging.HttpLoggingInterceptor;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
@@ -26,7 +27,7 @@ import java.util.concurrent.ExecutionException;
/**
* @author 小傅哥,微信:fustack
* @description 在官网申请 ApiSecretKey ApiSecretKey
- * @github https://github.com/fuzhengwei
+ * @github https://github.com/fuzhengwei/chatglm-sdk-java
* @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获!
*/
@Slf4j
@@ -39,7 +40,7 @@ public class ApiTest {
// 1. 配置文件
Configuration configuration = new Configuration();
configuration.setApiHost("https://open.bigmodel.cn/");
- configuration.setApiSecretKey("d570f7c5d289cdac2abdfdc562e39f3f.trqz1dH8ZK6ED7Pg");
+ configuration.setApiSecretKey("62ddec38b1d0b9a7b0fddaf271e6ed90.HpD0SUBUlvqd05ey");
configuration.setLevel(HttpLoggingInterceptor.Level.BODY);
// 2. 会话工厂
OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration);
@@ -48,76 +49,213 @@ public class ApiTest {
}
/**
- * 流式对话
+ * 流式对话;
+ * 1. 默认 isCompatible = true 会兼容新旧版数据格式
+ * 2. GLM_3_5_TURBO、GLM_4 支持联网等插件
*/
@Test
- public void test_completions() throws JsonProcessingException, InterruptedException {
+ public void test_completions() throws Exception {
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
- request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
+ request.setModel(Model.GLM_3_5_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setIncremental(false);
+ request.setIsCompatible(true); // 是否对返回结果数据做兼容,24年1月发布的 GLM_3_5_TURBO、GLM_4 模型,与之前的模型在返回结果上有差异。开启 true 可以做兼容。
+ // 24年1月发布的 glm-3-turbo、glm-4 支持函数、知识库、联网功能
+ request.setTools(new ArrayList<ChatCompletionRequest.Tool>() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ add(ChatCompletionRequest.Tool.builder()
+ .type(ChatCompletionRequest.Tool.Type.web_search)
+ .webSearch(ChatCompletionRequest.Tool.WebSearch.builder().enable(true).searchQuery("小傅哥").build())
+ .build());
+ }
+ });
request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
private static final long serialVersionUID = -7988151926241837899L;
{
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("1+2")
+ .content("小傅哥的是谁")
+ .build());
+ }
+ });
+
+ // 请求
+ openAiSession.completions(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果 onEvent:{}", response.getData());
+ // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
+ if (EventType.finish.getCode().equals(type)) {
+ ChatCompletionResponse.Meta meta = JSON.parseObject(response.getMeta(), ChatCompletionResponse.Meta.class);
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(meta));
+ }
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ log.info("对话完成");
+ countDownLatch.countDown();
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ log.info("对话异常");
+ countDownLatch.countDown();
+ }
+ });
+
+ // 等待
+ countDownLatch.await();
+ }
+
+ /**
+ * 流式对话;
+ * 1. 与 test_completions 测试类相比,只是设置 isCompatible = false 这样就是使用了新的数据结构。onEvent 处理接收数据有差异
+ * 2. 不兼容旧版格式的话,仅支持 GLM_3_5_TURBO、GLM_4 其他模型会有解析错误
+ */
+ @Test
+ public void test_completions_new() throws Exception {
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+
+ // 入参;模型、请求信息
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setModel(Model.GLM_3_5_TURBO); // GLM_3_5_TURBO、GLM_4
+ request.setIsCompatible(false);
+ // 24年1月发布的 glm-3-turbo、glm-4 支持函数、知识库、联网功能
+ request.setTools(new ArrayList<ChatCompletionRequest.Tool>() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ add(ChatCompletionRequest.Tool.builder()
+ .type(ChatCompletionRequest.Tool.Type.web_search)
+ .webSearch(ChatCompletionRequest.Tool.WebSearch.builder().enable(true).searchQuery("小傅哥").build())
.build());
+ }
+ });
+ request.setMessages(new ArrayList<ChatCompletionRequest.Prompt>() {
+ private static final long serialVersionUID = -7988151926241837899L;
+ {
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("Okay")
+ .content("小傅哥的是谁")
.build());
+ }
+ });
- /* system 和 user 为一组出现。如果有参数类型为 system 则 system + user 一组一起传递。*/
+ // 请求
+ openAiSession.completions(request, new EventSourceListener() {
+ @Override
+ public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
+ if ("[DONE]".equals(data)) {
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(data));
+ return;
+ }
+
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response));
+ }
+
+ @Override
+ public void onClosed(EventSource eventSource) {
+ log.info("对话完成");
+ countDownLatch.countDown();
+ }
+
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ log.error("对话失败", t);
+ countDownLatch.countDown();
+ }
+ });
+
+ // 等待
+ countDownLatch.await();
+ }
+
+ /**
+ * 模型编码:glm-4v
+ * 根据输入的自然语言指令和图像信息完成任务,推荐使用 SSE 或同步调用方式请求接口
+ * https://open.bigmodel.cn/dev/api#glm-4v
+ */
+ @Test
+ public void test_completions_4v() throws Exception {
+ CountDownLatch countDownLatch = new CountDownLatch(1);
+ // 入参;模型、请求信息
+ ChatCompletionRequest request = new ChatCompletionRequest();
+ request.setModel(Model.GLM_4V); // GLM_3_5_TURBO、GLM_4
+ request.setStream(true);
+ request.setMessages(new ArrayList<ChatCompletionRequest.Prompt>() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ // content 字符串格式
add(ChatCompletionRequest.Prompt.builder()
- .role(Role.system.getCode())
- .content("1+1=2")
+ .role(Role.user.getCode())
+ .content("这个图片写了什么")
.build());
+ // content 对象格式
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("Okay")
+ .content(ChatCompletionRequest.Prompt.Content.builder()
+ .type(ChatCompletionRequest.Prompt.Content.Type.text.getCode())
+ .text("这是什么图片")
+ .build())
.build());
+ // content 对象格式,上传图片;图片支持url、base64
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("1+2")
+ .content(ChatCompletionRequest.Prompt.Content.builder()
+ .type(ChatCompletionRequest.Prompt.Content.Type.image_url.getCode())
+ .imageUrl(ChatCompletionRequest.Prompt.Content.ImageUrl.builder().url("https://bugstack.cn/images/article/project/chatgpt/chatgpt-extra-231011-01.png").build())
+ .build())
.build());
-
}
});
- // 请求
openAiSession.completions(request, new EventSourceListener() {
@Override
public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) {
- ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
- log.info("测试结果 onEvent:{}", response.getData());
- // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断
- if (EventType.finish.getCode().equals(type)) {
- ChatCompletionResponse.Meta meta = JSON.parseObject(response.getMeta(), ChatCompletionResponse.Meta.class);
- log.info("[输出结束] Tokens {}", JSON.toJSONString(meta));
+ if ("[DONE]".equals(data)) {
+ log.info("[输出结束] Tokens {}", JSON.toJSONString(data));
+ return;
}
+
+ ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response));
}
@Override
public void onClosed(EventSource eventSource) {
log.info("对话完成");
+ countDownLatch.countDown();
}
+ @Override
+ public void onFailure(EventSource eventSource, @Nullable Throwable t, @Nullable Response response) {
+ log.error("对话失败", t);
+ countDownLatch.countDown();
+ }
});
// 等待
- new CountDownLatch(1).await();
+ countDownLatch.await();
+
}
/**
* 同步请求
*/
@Test
- public void test_completions_future() throws ExecutionException, InterruptedException {
+ public void test_completions_future() throws Exception {
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
@@ -142,17 +280,29 @@ public class ApiTest {
* 同步请求
*/
@Test
- public void test_completions_sync() throws IOException {
+ public void test_completions_sync() throws Exception {
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
- request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
+ request.setModel(Model.GLM_4V); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
private static final long serialVersionUID = -7988151926241837899L;
{
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
- .content("1+1")
+ .content("小傅哥是谁")
+ .build());
+ }
+ });
+
+ // 24年1月发布的 glm-3-turbo、glm-4 支持函数、知识库、联网功能
+ request.setTools(new ArrayList<ChatCompletionRequest.Tool>() {
+ private static final long serialVersionUID = -7988151926241837899L;
+
+ {
+ add(ChatCompletionRequest.Tool.builder()
+ .type(ChatCompletionRequest.Tool.Type.web_search)
+ .webSearch(ChatCompletionRequest.Tool.WebSearch.builder().enable(true).searchQuery("小傅哥").build())
.build());
}
});
@@ -162,13 +312,21 @@ public class ApiTest {
log.info("测试结果:{}", JSON.toJSONString(response));
}
+ @Test
+ public void test_genImages() throws Exception {
+ ImageCompletionRequest request = new ImageCompletionRequest();
+ request.setModel(Model.COGVIEW_3);
+ request.setPrompt("画个小狗");
+ ImageCompletionResponse response = openAiSession.genImages(request);
+ log.info("测试结果:{}", JSON.toJSONString(response));
+ }
@Test
public void test_curl() {
// 1. 配置文件
Configuration configuration = new Configuration();
configuration.setApiHost("https://open.bigmodel.cn/");
- configuration.setApiSecretKey("4d00226f242793b9c267a64ab2eaf5cb.aIwQNiG59MhSWJbn");
+ configuration.setApiSecretKey("62ddec38b1d0b9a7b0fddaf271e6ed90.HpD0SUBUlvqd05ey");
// 2. 获取Token
String token = BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret());
diff --git a/src/test/java/cn/bugstack/chatglm/test/JSONTest.java b/src/test/java/cn/bugstack/chatglm/test/JSONTest.java
new file mode 100644
index 0000000..79039bd
--- /dev/null
+++ b/src/test/java/cn/bugstack/chatglm/test/JSONTest.java
@@ -0,0 +1,68 @@
+package cn.bugstack.chatglm.test;
+
+import cn.bugstack.chatglm.model.ChatCompletionResponse;
+import cn.bugstack.chatglm.model.ImageCompletionResponse;
+import com.alibaba.fastjson.JSON;
+import lombok.extern.slf4j.Slf4j;
+import org.junit.Test;
+
+@Slf4j
+public class JSONTest {
+
+ @Test
+ public void test_glm_json() {
+ String json01 = "{\n" +
+ " \"id\": \"8305987191663349153\",\n" +
+ " \"created\": 1705487423,\n" +
+ " \"model\": \"glm-3-turbo\",\n" +
+ " \"choices\": [\n" +
+ " {\n" +
+ " \"index\": 0,\n" +
+ " \"delta\": {\n" +
+ " \"role\": \"assistant\",\n" +
+ " \"content\": \"1\"\n" +
+ " }\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+
+ String json02 = "{\n" +
+ " \"id\": \"8308763664682731117\",\n" +
+ " \"created\": 1705490859,\n" +
+ " \"model\": \"glm-3-turbo\",\n" +
+ " \"choices\": [\n" +
+ " {\n" +
+ " \"index\": 0,\n" +
+ " \"finish_reason\": \"stop\",\n" +
+ " \"delta\": {\n" +
+ " \"role\": \"assistant\",\n" +
+ " \"content\": \"\"\n" +
+ " }\n" +
+ " }\n" +
+ " ],\n" +
+ " \"usage\": {\n" +
+ " \"prompt_tokens\": 8,\n" +
+ " \"completion_tokens\": 12,\n" +
+ " \"total_tokens\": 20\n" +
+ " }\n" +
+ "}";
+
+ ChatCompletionResponse response = JSON.parseObject(json01, ChatCompletionResponse.class);
+ log.info("测试结果:{}", JSON.toJSONString(response.getChoices()));
+ }
+
+ @Test
+ public void test_image_json(){
+ String json = "{\n" +
+ " \"created\": 1705549253,\n" +
+ " \"data\": [\n" +
+ " {\n" +
+ " \"url\": \"https://sfile.chatglm.cn/testpath/cbffcbf4-ac63-50a3-9d1e-b644c77ffaa2_0.png\"\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+ ImageCompletionResponse response = JSON.parseObject(json, ImageCompletionResponse.class);
+ log.info("测试结果:{}", response.getData().get(0).getUrl());
+ }
+
+}
--
GitLab