提交 399606ee 编写于 作者: M maxuan

feat:add chatglm sync method

上级 1b69e40d
...@@ -2,6 +2,7 @@ package cn.bugstack.chatglm; ...@@ -2,6 +2,7 @@ package cn.bugstack.chatglm;
import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse; import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import io.reactivex.Single; import io.reactivex.Single;
import retrofit2.http.Body; import retrofit2.http.Body;
import retrofit2.http.POST; import retrofit2.http.POST;
...@@ -16,8 +17,12 @@ import retrofit2.http.Path; ...@@ -16,8 +17,12 @@ import retrofit2.http.Path;
public interface IOpenAiApi { public interface IOpenAiApi {
String v3_completions = "api/paas/v3/model-api/{model}/sse-invoke"; String v3_completions = "api/paas/v3/model-api/{model}/sse-invoke";
String v3_completions_sync = "api/paas/v3/model-api/{model}/invoke";
@POST(v3_completions) @POST(v3_completions)
Single<ChatCompletionResponse> completions(@Path("model") String model, @Body ChatCompletionRequest chatCompletionRequest); Single<ChatCompletionResponse> completions(@Path("model") String model, @Body ChatCompletionRequest chatCompletionRequest);
// Retrofit declaration for the synchronous (non-SSE) invoke endpoint; the result is
// delivered as a ChatCompletionSyncResponse.
// NOTE(review): v3_completions_sync contains a "{model}" placeholder, but this method
// declares no @Path("model") parameter to fill it — confirm whether this declaration is
// usable as-is. DefaultOpenAiSession#completionsSync builds the URL by hand with OkHttp
// instead of calling this method.
@POST(v3_completions_sync)
Single<ChatCompletionSyncResponse> completions(@Body ChatCompletionRequest chatCompletionRequest);
} }
...@@ -30,14 +30,13 @@ public class OpenAiHTTPInterceptor implements Interceptor { ...@@ -30,14 +30,13 @@ public class OpenAiHTTPInterceptor implements Interceptor {
public @NotNull Response intercept(Chain chain) throws IOException { public @NotNull Response intercept(Chain chain) throws IOException {
// 1. 获取原始 Request // 1. 获取原始 Request
Request original = chain.request(); Request original = chain.request();
// 2. 构建请求 // 2. 构建请求
Request request = original.newBuilder() Request request = original.newBuilder()
.url(original.url()) .url(original.url())
.header("Authorization", "Bearer " + BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret())) .header("Authorization", "Bearer " + BearerTokenUtils.getToken(configuration.getApiKey(), configuration.getApiSecret()))
.header("Content-Type", Configuration.JSON_CONTENT_TYPE) .header("Content-Type", Configuration.JSON_CONTENT_TYPE)
.header("User-Agent", Configuration.DEFAULT_USER_AGENT) .header("User-Agent", Configuration.DEFAULT_USER_AGENT)
.header("Accept", Configuration.SSE_CONTENT_TYPE) .header("Accept", null != original.header("Accept") ? original.header("Accept") : Configuration.SSE_CONTENT_TYPE)
.method(original.method(), original.body()) .method(original.method(), original.body())
.build(); .build();
......
package cn.bugstack.chatglm.model;
import lombok.Data;
import java.util.List;
/**
 * Response of a synchronous invocation.
 *
 * <p>Plain data holder; getters, setters, {@code equals}/{@code hashCode} and
 * {@code toString} are generated by Lombok's {@code @Data}. The snake_case field names
 * presumably mirror the JSON keys of the API response — confirm against the API docs
 * before renaming any of them.
 *
 * @author max
 * @date 2023/12/14 15:41
 */
@Data
public class ChatCompletionSyncResponse {

    /** Result code carried in the response body. */
    private Integer code;

    /** Message accompanying {@link #code}. */
    private String msg;

    /** Success flag carried in the response body. */
    private Boolean success;

    /** Payload: choices, task/request identifiers and token usage. */
    private ChatGLMData data;

    /**
     * One returned message: the role it is attributed to and its text.
     */
    @Data
    public static class Choice {
        private String role;
        private String content;
    }

    /**
     * Token counters reported for the call.
     */
    @Data
    public static class Usage {
        private int completion_tokens;
        private int prompt_tokens;
        private int total_tokens;
    }

    /**
     * Payload section of the response.
     */
    @Data
    public static class ChatGLMData {
        private List<Choice> choices;
        private String task_status;
        private Usage usage;
        private String task_id;
        private String request_id;
    }
}
...@@ -2,10 +2,12 @@ package cn.bugstack.chatglm.session; ...@@ -2,10 +2,12 @@ package cn.bugstack.chatglm.session;
import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse; import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import okhttp3.sse.EventSource; import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSourceListener;
import java.io.IOException;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
/** /**
...@@ -20,4 +22,6 @@ public interface OpenAiSession { ...@@ -20,4 +22,6 @@ public interface OpenAiSession {
CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException; CompletableFuture<String> completions(ChatCompletionRequest chatCompletionRequest) throws InterruptedException;
/**
 * Synchronous completion: returns the parsed response object directly instead of
 * a future.
 *
 * @param chatCompletionRequest request carrying the model and the prompts
 * @return the response parsed into a {@link ChatCompletionSyncResponse}
 * @throws IOException if the underlying HTTP call fails
 */
ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException;
} }
...@@ -3,20 +3,19 @@ package cn.bugstack.chatglm.session.defaults; ...@@ -3,20 +3,19 @@ package cn.bugstack.chatglm.session.defaults;
import cn.bugstack.chatglm.IOpenAiApi; import cn.bugstack.chatglm.IOpenAiApi;
import cn.bugstack.chatglm.model.ChatCompletionRequest; import cn.bugstack.chatglm.model.ChatCompletionRequest;
import cn.bugstack.chatglm.model.ChatCompletionResponse; import cn.bugstack.chatglm.model.ChatCompletionResponse;
import cn.bugstack.chatglm.model.ChatCompletionSyncResponse;
import cn.bugstack.chatglm.model.EventType; import cn.bugstack.chatglm.model.EventType;
import cn.bugstack.chatglm.session.Configuration; import cn.bugstack.chatglm.session.Configuration;
import cn.bugstack.chatglm.session.OpenAiSession; import cn.bugstack.chatglm.session.OpenAiSession;
import com.alibaba.fastjson.JSON; import com.alibaba.fastjson.JSON;
import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.JsonProcessingException;
import lombok.extern.slf4j.Slf4j; import lombok.extern.slf4j.Slf4j;
import okhttp3.MediaType; import okhttp3.*;
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okhttp3.sse.EventSource; import okhttp3.sse.EventSource;
import okhttp3.sse.EventSourceListener; import okhttp3.sse.EventSourceListener;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import java.io.IOException;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
...@@ -94,4 +93,20 @@ public class DefaultOpenAiSession implements OpenAiSession { ...@@ -94,4 +93,20 @@ public class DefaultOpenAiSession implements OpenAiSession {
return future; return future;
} }
/**
 * Sends the request to the non-SSE invoke endpoint and blocks until the response arrives.
 *
 * @param chatCompletionRequest request whose model code selects the endpoint and whose
 *                              {@code toString()} output is sent as the request body
 * @return the response body parsed into a {@link ChatCompletionSyncResponse}
 * @throws IOException      if executing the call or reading the body fails
 * @throws RuntimeException if the server answers with a non-2xx status or without a body
 */
@Override
public ChatCompletionSyncResponse completionsSync(ChatCompletionRequest chatCompletionRequest) throws IOException {
    // Fill the {model} placeholder of the endpoint template with the requested model's code.
    String url = configuration.getApiHost()
            .concat(IOpenAiApi.v3_completions_sync)
            .replace("{model}", chatCompletionRequest.getModel().getCode());

    // Build the request. The Accept header is set explicitly because the HTTP interceptor
    // only falls back to the SSE content type when the request carries no Accept header.
    // NOTE(review): the payload relies on ChatCompletionRequest#toString() producing JSON — confirm.
    Request request = new Request.Builder()
            .url(url)
            .header("Accept", Configuration.APPLICATION_JSON)
            .post(RequestBody.create(MediaType.parse("application/json"), chatCompletionRequest.toString()))
            .build();

    OkHttpClient okHttpClient = configuration.getOkHttpClient();
    // try-with-resources releases the connection on every path, including the failure ones.
    try (Response response = okHttpClient.newCall(request).execute()) {
        if (!response.isSuccessful()) {
            // The original built this exception but never threw it.
            throw new RuntimeException("Request failed, HTTP status: " + response.code());
        }
        ResponseBody body = response.body();
        if (body == null) {
            throw new RuntimeException("Request failed, empty response body");
        }
        return JSON.parseObject(body.string(), ChatCompletionSyncResponse.class);
    }
}
} }
...@@ -17,6 +17,7 @@ import org.jetbrains.annotations.Nullable; ...@@ -17,6 +17,7 @@ import org.jetbrains.annotations.Nullable;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import java.io.IOException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
...@@ -119,7 +120,7 @@ public class ApiTest { ...@@ -119,7 +120,7 @@ public class ApiTest {
public void test_completions_future() throws ExecutionException, InterruptedException { public void test_completions_future() throws ExecutionException, InterruptedException {
// 入参;模型、请求信息 // 入参;模型、请求信息
ChatCompletionRequest request = new ChatCompletionRequest(); ChatCompletionRequest request = new ChatCompletionRequest();
request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro request.setModel(Model.CHATGLM_TURBO); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro
request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() { request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
private static final long serialVersionUID = -7988151926241837899L; private static final long serialVersionUID = -7988151926241837899L;
...@@ -137,6 +138,30 @@ public class ApiTest { ...@@ -137,6 +138,30 @@ public class ApiTest {
log.info("测试结果:{}", response); log.info("测试结果:{}", response);
} }
/**
 * Synchronous request: sends a single user prompt and logs the response as JSON.
 */
@Test
public void test_completions_sync() throws IOException {
    // The one user message carried by this request
    ChatCompletionRequest.Prompt userPrompt = ChatCompletionRequest.Prompt.builder()
            .role(Role.user.getCode())
            .content("写个java冒泡排序")
            .build();

    // Input: model and request messages
    ChatCompletionRequest request = new ChatCompletionRequest();
    // Model names listed by the original author: chatGLM_6b_SSE, chatglm_lite, chatglm_lite_32k, chatglm_std, chatglm_pro
    request.setModel(Model.CHATGLM_TURBO);
    request.setPrompt(new ArrayList<ChatCompletionRequest.Prompt>() {
        private static final long serialVersionUID = -7988151926241837899L;

        {
            add(userPrompt);
        }
    });

    // Blocking call; the parsed response is logged as a JSON string
    ChatCompletionSyncResponse syncResponse = openAiSession.completionsSync(request);
    log.info("测试结果:{}", JSON.toJSONString(syncResponse));
}
@Test @Test
public void test_curl() { public void test_curl() {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册