Unverified commit d3f14eed authored by 小傅哥 and committed by GitHub

Merge pull request #3 from max-holo/dev_max

feat:add chatglm sync method
@@ -2,6 +2,7 @@ package cn.bugstack.chatglm;
import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import io.reactivex.Single;
import retrofit2.http.Body;
import retrofit2.http.POST;
@@ -16,8 +17,12 @@ import retrofit2.http.Path;
public interface IOpenAiApi {
String v3_completions = "api/paas/v3/model-api/{model}/sse-invoke";
String v3_completions_sync = "api/paas/v3/model-api/{model}/invoke";
@POST(v3_completions)
Single<ChatCompletionResponse> completions(@Path("model") String model, @Body ChatCompletionRequest chatCompletionRequest);
@POST(v3_completions_sync)
Single<ChatCompletionSyncResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);
}
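
For context, this Retrofit interface would typically be turned into a callable client from an OkHttpClient; the sketch below is a minimal, hypothetical wiring. The RxJava2 call adapter, the Jackson converter, and the helper class name are assumptions for illustration, not necessarily what this project's Configuration actually uses.

import cn.bugstack.chatglm.IOpenAiApi;
import okhttp3.OkHttpClient;
import retrofit2.Retrofit;
import retrofit2.adapter.rxjava2.RxJava2CallAdapterFactory;
import retrofit2.converter.jackson.JacksonConverterFactory;

// Hypothetical helper, not part of the commit.
public class OpenAiApiClientSketch {

    // Builds an IOpenAiApi proxy from an already configured OkHttpClient.
    public static IOpenAiApi create(OkHttpClient okHttpClient, String apiHost) {
        return new Retrofit.Builder()
                .baseUrl(apiHost)       // must end with "/"
                .client(okHttpClient)   // should carry the auth interceptor
                .addCallAdapterFactory(RxJava2CallAdapterFactory.create())
                .addConverterFactory(JacksonConverterFactory.create())
                .build()
                .create(IOpenAiApi.class);
    }
}
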
@@ -30,14 +30,13 @@ public class OpenAiHTTPInterceptor implements Interceptor {
public @NotNull Response intercept(Chain chain) throws IOException {
// 1. Get the original request
Request original = chain.request();
// 2. Build the new request
Request request = original.newBuilder()
.url(original.url())
.header("Authorization", "Bearer " + BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret()))
.header("Content-Type", Configuration.JSON_CONTENT_TYPE)
.header("User-Agent", Configuration.DEFAULT_USER_AGENT)
.header("Accept", Configuration.SSE_CONTENT_TYPE)
.header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE)
.method(original.method(), original.body())
.build();
......
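
The Accept header change above is what lets both call styles share one OkHttpClient: a request that already carries an Accept header (the new synchronous path) keeps it, while a request without one still gets the SSE default. A minimal sketch of the two request shapes, using a placeholder URL and body (hypothetical, for illustration only):

import okhttp3.MediaType;
import okhttp3.Request;
import okhttp3.RequestBody;

// Hypothetical helper, not part of the commit.
public class AcceptHeaderSketch {

    static final String URL = "https://example.invalid/api/paas/v3/model-api/chatglm_turbo/invoke";
    static final RequestBody BODY = RequestBody.create(MediaType.parse("application/json"), "{}");

    // Sync-style request: Accept is set up front, so the interceptor keeps it as-is.
    static Request syncStyle() {
        return new Request.Builder().url(URL).header("Accept", "application/json").post(BODY).build();
    }

    // SSE-style request: no Accept here, so the interceptor falls back to Configuration.SSE_CONTENT_TYPE.
    static Request sseStyle() {
        return new Request.Builder().url(URL).post(BODY).build();
    }
}
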
package cn.bugstack.chatglm.model;
import lombok.Data;
import java.util.List;
/**
* Synchronous invocation response
* @author max
* @date 2023/12/14 15:41
*/
@Data
public class ChatCompletionSyncResponse {
private Integer code;
private String msg;
private Boolean success;
private ChatGLMData data;
@Data
public static class ChatGLMData {
private List<Choice> choices;
private String task_status;
private Usage usage;
private String task_id;
private String request_id;
}
@Data
public static class Usage {
private int completion_tokens;
private int prompt_tokens;
private int total_tokens;
}
@Data
public static class Choice {
private String role;
private String content;
}
}
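
For reference, the class above maps a JSON payload of roughly the following shape. The values in this sketch are invented for illustration; fastjson is used here only because the session implementation further down parses the body with JSON.parseObject.

import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import com.alibaba.fastjson.JSON;

// Hypothetical helper, not part of the commit.
public class SyncResponseMappingSketch {

    public static void main(String[] args) {
        // Field names follow ChatCompletionSyncResponse; all values are made up.
        String body = "{\"code\":200,\"msg\":\"ok\",\"success\":true,"
                + "\"data\":{\"task_status\":\"SUCCESS\",\"task_id\":\"task-1\",\"request_id\":\"req-1\","
                + "\"choices\":[{\"role\":\"assistant\",\"content\":\"hello\"}],"
                + "\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":25,\"total_tokens\":35}}}";

        ChatCompletionSyncResponse response = JSON.parseObject(body, ChatCompletionSyncResponse.class);
        System.out.println(response.getData().getChoices().get(0).getContent()); // prints "hello"
    }
}
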
@@ -2,10 +2,12 @@ package cn.bugstack.chatglm.session;
import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
/**
@@ -20,4 +22,6 @@ public interface OpenAiSession {
CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException;
}
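
With this change the session exposes both transports side by side. A minimal usage sketch, assuming an already initialised OpenAiSession and request (their construction is outside this diff):

import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import cn.bugstack.chatglm.session.OpenAiSession;

import java.io.IOException;
import java.util.concurrent.CompletableFuture;

// Hypothetical helper, not part of the commit.
public class SessionUsageSketch {

    static void demo(OpenAiSession session, ChatCompletionRequest request) throws IOException, InterruptedException {
        // SSE path: non-blocking, aggregated answer delivered through a future.
        CompletableFuture<String> streamed = session.completions(request);

        // Sync path: one blocking HTTP round trip, parsed response returned directly.
        ChatCompletionSyncResponse sync = session.completionsSync(request);

        System.out.println(streamed.join());
        System.out.println(sync.getData().getChoices().get(0).getContent());
    }
}
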
@@ -3,20 +3,19 @@ package cn.bugstack.chatglm.session.defaults;
import cn.bugstack.chatglm.IOpenAiApi;
import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import cn.bugstack.chatglm.model.EventType;
import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession;
import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j;
-import okhttp3.MediaType;
-import okhttp3.Request;
-import okhttp3.RequestBody;
-import okhttp3.Response;
+import okhttp3.*;
import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.Nullable;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
@@ -94,4 +93,20 @@ public class DefaultOpenAiSession implements OpenAiSession {
return future;
}
@Override
public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException {
// Build the request: sync endpoint URL, JSON Accept header, JSON body from the request object
Request request = new Request.Builder()
.url(configuration.getApiHost().concat(IOpenAiApi.v3_completions_sync).replace("{model}", chatCompletionRequest.getModel().getCode()))
.header("Accept", Configuration.APPLICATION_JSON)
.post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
.build();
OkHttpClient okHttpClient = configuration.getOkHttpClient();
Response response = okHttpClient.newCall(request).execute();
if (!response.isSuccessful()) {
// Fail fast on non-2xx responses instead of silently parsing the body
throw new RuntimeException("ChatGLM sync request failed, HTTP status: " + response.code());
}
return JSON.parseObject(response.body().string(), ChatCompletionSyncResponse.class);
}
}
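
The blocking path leaves result checking to the caller. A small defensive-handling sketch that uses only fields already defined on ChatCompletionSyncResponse (the error message format is illustrative):

import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;

// Hypothetical helper, not part of the commit.
public class SyncResultHandlingSketch {

    // Returns the first answer, or fails fast with the server-side code/msg.
    static String firstAnswer(ChatCompletionSyncResponse response) {
        if (response == null || !Boolean.TRUE.equals(response.getSuccess())
                || response.getData() == null
                || response.getData().getChoices() == null
                || response.getData().getChoices().isEmpty()) {
            throw new IllegalStateException("ChatGLM sync call failed: code="
                    + (response == null ? null : response.getCode())
                    + ", msg=" + (response == null ? null : response.getMsg()));
        }
        return response.getData().getChoices().get(0).getContent();
    }
}
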
@@ -17,6 +17,7 @@ import org.jetbrains.annotations.Nullable;
import org.junit.Before;
import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
@@ -119,7 +120,7 @@ public class ApiTest {
public void test_completions_future() throws ExecutionException, InterruptedException {
// 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest();
request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
private static final long serialVersionUID = -7988151926241837899L;
@@ -137,6 +138,30 @@ public class ApiTest {
log.info("测试结果:{}", response);
}
/**
* Synchronous request
*/
@Test
public void test_completions_sync() throws IOException {
// Input parameters: model and prompt messages
ChatCompletionRequest request = new ChatCompletionRequest();
request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
private static final long serialVersionUID = -7988151926241837899L;
{
add(ChatCompletionRequest.Prompt.builder()
.role(Role.user.getCode())
.content("写个java冒泡排序")
.build());
}
});
ChatCompletionSyncResponse response = openAiSession.completionsSync(request);
log.info("测试结果:{}", JSON.toJSONString(response));
}
@Test
public void test_curl() {
......