diff --git a/src/test/java/cn/bugstack/chatglm/test/ApiTest.java b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java new file mode 100644 index 0000000000000000000000000000000000000000..562617c70cb549a2aca5ca572635eb2943867e0a --- /dev/null +++ b/src/test/java/cn/bugstack/chatglm/test/ApiTest.java @@ -0,0 +1,86 @@ +package cn.bugstack.chatglm.test; + +import cn.bugstack.chatglm.model.*; +import cn.bugstack.chatglm.session.Configuration; +import cn.bugstack.chatglm.session.OpenAiSession; +import cn.bugstack.chatglm.session.OpenAiSessionFactory; +import cn.bugstack.chatglm.session.defaults.DefaultOpenAiSessionFactory; +import com.alibaba.fastjson.JSON; +import com.fasterxml.jackson.core.JsonProcessingException; +import lombok.extern.slf4j.Slf4j; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import org.apache.commons.lang3.StringUtils; +import org.jetbrains.annotations.Nullable; +import org.junit.Before; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.concurrent.CountDownLatch; + +/** + * @author 小傅哥,微信:fustack + * @description 在官网申请 ApiSecretKey ApiSecretKey + * @github https://github.com/fuzhengwei + * @Copyright 公众号:bugstack虫洞栈 | 博客:https://bugstack.cn - 沉淀、分享、成长,让自己和他人都能有所收获! + */ +@Slf4j +public class ApiTest { + + private OpenAiSession openAiSession; + + @Before + public void test_OpenAiSessionFactory() { + // 1. 配置文件 + Configuration configuration = new Configuration(); + configuration.setApiHost("https://open.bigmodel.cn/"); + configuration.setApiSecretKey("4e087e4135306ef4a676f0cce3cee560.sgP2DUsWEVPxk0UI"); + // 2. 会话工厂 + OpenAiSessionFactory factory = new DefaultOpenAiSessionFactory(configuration); + // 3. 开启会话 + this.openAiSession = factory.openSession(); + } + + /** + * 流式对话 + */ + @Test + public void test_completions() throws JsonProcessingException, InterruptedException { + // 入参;模型、请求信息 + ChatCompletionRequest request = new ChatCompletionRequest(); + request.setModel(Model.CHATGLM_LITE); // chatGLM_6b_SSE、chatglm_lite、chatglm_lite_32k、chatglm_std、chatglm_pro + request.setPrompt(new ArrayList() { + private static final long serialVersionUID = -7988151926241837899L; + + { + add(ChatCompletionRequest.Prompt.builder() + .role(Role.user.getCode()) + .content("写个java冒泡排序") + .build()); + } + }); + + // 请求 + openAiSession.completions(request, new EventSourceListener() { + @Override + public void onEvent(EventSource eventSource, @Nullable String id, @Nullable String type, String data) { + ChatCompletionResponse response = JSON.parseObject(data, ChatCompletionResponse.class); + log.info("测试结果 onEvent:{}", response.getData()); + // type 消息类型,add 增量,finish 结束,error 错误,interrupted 中断 + if (EventType.finish.getCode().equals(type)) { + ChatCompletionResponse.Meta meta = JSON.parseObject(response.getMeta(), ChatCompletionResponse.Meta.class); + log.info("[输出结束] Tokens {}", JSON.toJSONString(meta)); + } + } + + @Override + public void onClosed(EventSource eventSource) { + log.info("对话完成"); + } + }); + + // 等待 + new CountDownLatch(1).await(); + } + +}