fix:java调用open ai

上级 4d5ed5aa
...@@ -237,7 +237,17 @@ ...@@ -237,7 +237,17 @@
<dependency> <dependency>
<groupId>com.github.ben-manes.caffeine</groupId> <groupId>com.github.ben-manes.caffeine</groupId>
<artifactId>caffeine</artifactId> <artifactId>caffeine</artifactId>
<!-- <version>3.1.2</version>--> </dependency>
<dependency>
<groupId>org.asynchttpclient</groupId>
<artifactId>async-http-client</artifactId>
<version>2.12.3</version>
<type>jar</type>
</dependency>
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.10.1</version>
</dependency> </dependency>
</dependencies> </dependencies>
<build> <build>
......
package com.kwan.springbootkwan.entity.openai;
import com.google.gson.annotations.SerializedName;
import lombok.Data;
import java.util.List;
@Data
public class AzureAIChatRequest {
private List<AzureAIMessage> messages;
private Double temperature;
@SerializedName("n")
private Integer choices;
private boolean stream;
private String stop;
@SerializedName("max_tokens")
private Integer maxTokens;
@SerializedName("presence_penalty")
private Integer presencePenalty;
@SerializedName("frequency_penalty")
private Integer frequencyPenalty;
private String user;
}
\ No newline at end of file
package com.kwan.springbootkwan.entity.openai;
import lombok.Data;
import java.util.List;
@Data
public class AzureAIChatResponse {
private String id;
private String object;
private String created;
private String model;
private AzureAIUsage usage;
private List<AzureAIChoice> choices;
}
\ No newline at end of file
package com.kwan.springbootkwan.entity.openai;
import lombok.Data;
@Data
public class AzureAIChoice {
private Object message;
}
package com.kwan.springbootkwan.entity.openai;
import cn.hutool.core.date.BetweenFormatter;
import cn.hutool.core.date.DateUnit;
import cn.hutool.core.date.DateUtil;
import cn.hutool.json.JSONUtil;
import com.google.gson.Gson;
import lombok.extern.slf4j.Slf4j;
import org.asynchttpclient.AsyncHttpClient;
import org.asynchttpclient.DefaultAsyncHttpClient;
import org.asynchttpclient.Request;
import org.asynchttpclient.RequestBuilder;
import org.asynchttpclient.Response;
import java.io.Closeable;
import java.io.IOException;
import java.util.Date;
import java.util.concurrent.Future;
@Slf4j
public class AzureAIClient implements Closeable {
private static final String JSON = "application/json; charset=UTF-8";
private final boolean closeClient;
private final AsyncHttpClient client;
private final String deploymentName;
private final String url;
private final String token;
private final String apiVersion;
private boolean closed = false;
Gson gson = new Gson();
public AzureAIClient(String url, String apiKey, String deploymentName, String apiVersion) throws Exception {
this.client = new DefaultAsyncHttpClient();
this.url = url + "/openai/deployments/" + deploymentName + "/";
this.token = apiKey;
this.deploymentName = deploymentName;
this.apiVersion = apiVersion;
closeClient = true;
}
public boolean isClosed() {
return closed || client.isClosed();
}
@Override
public void close() {
if (closeClient && !client.isClosed()) {
try {
client.close();
} catch (IOException ex) {
}
}
closed = true;
}
public AzureAIChatResponse sendChatRequest(AzureAIChatRequest chatRequest) throws Exception {
Date startDateOne = DateUtil.date();
Future<Response> f = client.executeRequest(buildRequest("POST", "chat/completions?api-version=" + apiVersion, gson.toJson(chatRequest)));
Response r = f.get();
Date endDateOne = DateUtil.date();
// 获取开始时间和结束时间的时间差
long betweenDateOne = DateUtil.between(startDateOne, endDateOne, DateUnit.MS);
// 格式化时间
String formatBetweenOne = DateUtil.formatBetween(betweenDateOne, BetweenFormatter.Level.MILLISECOND);
log.info(String.format("请求数据耗时(毫秒):%s", formatBetweenOne));
if (r.getStatusCode() != 200) {
log.info("Could not create chat request - server resposne was " + r.getStatusCode() + " to url: " + url + "chat/completions?api-version=2023-03-15-preview");
return null;
} else {
Date startDate = DateUtil.date();
AzureAIChatResponse azureAIChatResponse = JSONUtil.toBean(r.getResponseBody(), AzureAIChatResponse.class);
Date endDate = DateUtil.date();
// 获取开始时间和结束时间的时间差
long betweenDate = DateUtil.between(startDate, endDate, DateUnit.MS);
// 格式化时间
String formatBetween = DateUtil.formatBetween(betweenDate, BetweenFormatter.Level.MILLISECOND);
log.info(String.format("格式化数据耗时(毫秒):%s", formatBetween));
return azureAIChatResponse;
}
}
private Request buildRequest(String type, String subUrl, String requestBody) {
RequestBuilder builder = new RequestBuilder(type);
Request request = builder.setUrl(this.url + subUrl)
.addHeader("Accept", JSON)
.addHeader("Content-Type", JSON)
.addHeader("api-key", this.token)
.setBody(requestBody)
.build();
return request;
}
}
\ No newline at end of file
package com.kwan.springbootkwan.entity.openai;
import lombok.Data;
@Data
public class AzureAIMessage {
private String role;
private String content;
}
package com.kwan.springbootkwan.entity.openai;
import com.google.gson.annotations.SerializedName;
public enum AzureAIRole {
@SerializedName("assistant")
ASSISTANT("assistant"),
@SerializedName("system")
SYSTEM("system"),
@SerializedName("user")
USER("user"),
;
private final String text;
private AzureAIRole(final String text) {
this.text = text;
}
@Override
public String toString() {
return text;
}
}
\ No newline at end of file
package com.kwan.springbootkwan.entity.openai;
import com.google.gson.annotations.SerializedName;
import lombok.Data;
@Data
public class AzureAIUsage {
@SerializedName("prompt_tokens")
private int promptTokens;
@SerializedName("completion_tokens")
private int completionTokens;
@SerializedName("total_tokens")
private int totalTokens;
}
package com.kwan.springbootkwan.entity.openai;
import lombok.extern.slf4j.Slf4j;
import java.util.ArrayList;
import java.util.List;
@Slf4j
public class Main {
/**
* # 公司的key
* os.environ["OPENAI_API_KEY"] = ''
* os.environ["OPENAI_API_BASE"] = 'https://opencatgpt.openai.azure.com/'
* os.environ["OPENAI_API_TYPE"] = 'azure'
* os.environ["OPENAI_API_VERSION"] = '2023-05-15'
*/
public static void main(String[] args) throws Exception {
// 装配请求集合
List<AzureAIMessage> azureAiMessageList = new ArrayList<>();
AzureAIChatRequest azureAiChatRequest = new AzureAIChatRequest();
AzureAIMessage azureAIMessage0 = new AzureAIMessage();
azureAIMessage0.setRole(AzureAIRole.SYSTEM.toString());
azureAIMessage0.setContent("你是一个AI机器人,请根据提问进行回答");
azureAiMessageList.add(azureAIMessage0);
execute(azureAiMessageList, azureAiChatRequest, "请解释一下java的多态");
}
private static void execute(List<AzureAIMessage> azureAiMessageList, AzureAIChatRequest azureAiChatRequest
, String question) throws Exception {
AzureAIMessage azureAIMessage1 = new AzureAIMessage();
azureAIMessage1.setRole(AzureAIRole.USER.toString());
azureAIMessage1.setContent(question);
azureAiMessageList.add(azureAIMessage1);
azureAiChatRequest.setMessages(azureAiMessageList);
azureAiChatRequest.setMaxTokens(1024);
azureAiChatRequest.setTemperature(0.0);
// 是否进行留式返回
azureAiChatRequest.setPresencePenalty(0);
azureAiChatRequest.setFrequencyPenalty(0);
azureAiChatRequest.setStop(null);
AzureAIClient azureAIClient = new AzureAIClient("https://opencatgpt.openai.azure.com/", "",
"gpt-35-turbo", "2023-05-15");
AzureAIChatResponse azureAIChatResponse = azureAIClient.sendChatRequest(azureAiChatRequest);
System.out.println(azureAIChatResponse.getChoices().get(0).getMessage());
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册