版本
0.28.0
源码
使用langchain4j,可以通过AiServices来封装聊天模型API,实现会话记忆,工具调用,搜索增强,内容审查等功能,并提供简单灵活的用户接口 DefaultAiServices是其默认实现类型,通过动态代理的方式实现用户定义的服务接口
代码语言:javascript复制class DefaultAiServices<T> extends AiServices<T> {
private static final int MAX_SEQUENTIAL_TOOL_EXECUTIONS = 10;
DefaultAiServices(AiServiceContext context) {
super(context);
}
// 校验使用提示词模板发送消息的方法参数
static void validateParameters(Method method) {
// 如果只有一个参数或者没有参数跳过检查(参数直接作为内容发送的方法/其他非发送内容的方法)
Parameter[] parameters = method.getParameters();
if (parameters == null || parameters.length < 2) {
return;
}
for (Parameter parameter : parameters) {
// 获取应用于提示词模板的参数(带有V注解)
V v = parameter.getAnnotation(V.class);
// 获取用户消息模板参数
dev.langchain4j.service.UserMessage userMessage = parameter.getAnnotation(dev.langchain4j.service.UserMessage.class);
// 获取记忆ID参数
MemoryId memoryId = parameter.getAnnotation(MemoryId.class);
// 获取用户名参数
UserName userName = parameter.getAnnotation(UserName.class);
// 如果没有任何模板参数则报错
if (v == null && userMessage == null && memoryId == null && userName == null) {
throw illegalConfiguration(
"Parameter '%s' of method '%s' should be annotated with @V or @UserMessage or @UserName or @MemoryId",
parameter.getName(), method.getName()
);
}
}
}
public T build() {
// 基本校验
// 1. 校验chatModel/streamingChatModel是否有值
// 2. 校验toolSpecifications有值时上下文是否启用记忆(使用工具调用至少需要在记忆中保存3个消息)
performBasicValidation();
// 校验方法使用了Moderate时是否同时指定了审查模型(moderationModel)
for (Method method : context.aiServiceClass.getMethods()) {
if (method.isAnnotationPresent(Moderate.class) && context.moderationModel == null) {
throw illegalConfiguration("The @Moderate annotation is present, but the moderationModel is not set up. "
"Please ensure a valid moderationModel is configured before using the @Moderate annotation.");
}
}
// 构造动态代理
Object proxyInstance = Proxy.newProxyInstance(
context.aiServiceClass.getClassLoader(),
new Class<?>[]{context.aiServiceClass},
new InvocationHandler() {
private final ExecutorService executor = Executors.newCachedThreadPool();
@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Exception {
// 直接执行Object类定义的方法
if (method.getDeclaringClass() == Object.class) {
// methods like equals(), hashCode() and toString() should not be handled by this proxy
return method.invoke(this, args);
}
// 校验提示词模板参数
validateParameters(method);
// 获取系统消息
Optional<SystemMessage> systemMessage = prepareSystemMessage(method, args);
// 获取用户消息
UserMessage userMessage = prepareUserMessage(method, args);
// 获取记忆ID参数值,如果没有记忆ID参数则使用默认值“default”
Object memoryId = memoryId(method, args).orElse(DEFAULT);
// 使用检索增强生成(RAG),将检索结果内容与用户原始消息文本整合作为用户消息
if (context.retrievalAugmentor != null) {
List<ChatMessage> chatMemory = context.hasChatMemory()
? context.chatMemory(memoryId).messages()
: null;
Metadata metadata = Metadata.from(userMessage, memoryId, chatMemory);
userMessage = context.retrievalAugmentor.augment(userMessage, metadata);
}
// 用于提供客制化的输出解析,根据函数返回类型生成需要返回的消息格式的相关提示词,追加到用户消息里面
// 如果返回类型为String,AiMessage,TokenStream,Response则不追加格式提示词
// 如果返回类型为void则报错
// 如果返回类型为enum枚举类型,则追加提示词“nYou must answer strictly in the following format: one of value1,value2,value3...,valueN”
// 如果返回类型是 boolean/byte/short/int/long/BigInteger/float/double/BigDecimal/Date/LocalDate/LocalTime/LocalDateTime 或其对应包装类型,则追加对应值类型提示词,例如“nYou must answer strictly in the following format: one of [true, false]” ,“...format: integer number in range [-128, 127]”
// 如果返回类型是List/Set,则追加提示词“You must put every item on a separate line.”
// 否则追加提示词,以json形式返回 “You must answer strictly in the following JSON format: {...}”
String outputFormatInstructions = outputFormatInstructions(method.getReturnType());
userMessage = UserMessage.from(userMessage.text() outputFormatInstructions);
// 如果包含聊天记忆,则在聊天记忆中追加系统消息和用户消息
if (context.hasChatMemory()) {
ChatMemory chatMemory = context.chatMemory(memoryId);
systemMessage.ifPresent(chatMemory::add);
chatMemory.add(userMessage);
}
// 从记忆中获取消息清单或构建新的消息清单
List<ChatMessage> messages;
if (context.hasChatMemory()) {
messages = context.chatMemory(memoryId).messages();
} else {
messages = new ArrayList<>();
systemMessage.ifPresent(messages::add);
messages.add(userMessage);
}
// 执行审查
Future<Moderation> moderationFuture = triggerModerationIfNeeded(method, messages);
// 以流式处理消息
if (method.getReturnType() == TokenStream.class) {
return new AiServiceTokenStream(messages, context, memoryId); // 尚未实现响应内容审查,也不支持工具调用
}
// 调用chatModel生成响应
Response<AiMessage> response = context.toolSpecifications == null
? context.chatModel.generate(messages)
: context.chatModel.generate(messages, context.toolSpecifications);
// 获取token用量
TokenUsage tokenUsageAccumulator = response.tokenUsage();
// 校验审查结果
verifyModerationIfNeeded(moderationFuture);
// 执行工具调用
// 工具调用的最大执行次数(10)
int executionsLeft = MAX_SEQUENTIAL_TOOL_EXECUTIONS;
while (true) {
if (executionsLeft-- == 0) {
throw runtime("Something is wrong, exceeded %s sequential tool executions",
MAX_SEQUENTIAL_TOOL_EXECUTIONS);
}
// 获取AI响应消息,添加到记忆中
AiMessage aiMessage = response.content();
if (context.hasChatMemory()) {
context.chatMemory(memoryId).add(aiMessage);
}
// 如果不存在工具调用请求则中断
if (!aiMessage.hasToolExecutionRequests()) {
break;
}
// 根据工具调用请求,依次调用工具,并将工具执行结果消息添加到记忆中
ChatMemory chatMemory = context.chatMemory(memoryId);
for (ToolExecutionRequest toolExecutionRequest : aiMessage.toolExecutionRequests()) {
ToolExecutor toolExecutor = context.toolExecutors.get(toolExecutionRequest.name());
String toolExecutionResult = toolExecutor.execute(toolExecutionRequest, memoryId);
ToolExecutionResultMessage toolExecutionResultMessage = ToolExecutionResultMessage.from(
toolExecutionRequest,
toolExecutionResult
);
chatMemory.add(toolExecutionResultMessage);
}
// 根据添加了工具执行结果的记忆再次调用模型生成
response = context.chatModel.generate(chatMemory.messages(), context.toolSpecifications);
// 累计token用量
tokenUsageAccumulator = tokenUsageAccumulator.add(response.tokenUsage());
}
// 返回最终的响应
response = Response.from(response.content(), tokenUsageAccumulator, response.finishReason());
// 将响应解析为方法对应的返回类型对象
return parse(response, method.getReturnType());
}
private Future<Moderation> triggerModerationIfNeeded(Method method, List<ChatMessage> messages) {
if (method.isAnnotationPresent(Moderate.class)) {
return executor.submit(() -> {
List<ChatMessage> messagesToModerate = removeToolMessages(messages);
return context.moderationModel.moderate(messagesToModerate).content();
});
}
return null;
}
});
return (T) proxyInstance;
}
// 准备系统消息
private Optional<SystemMessage> prepareSystemMessage(Method method, Object[] args) {
// 获取提示词模板变量
Parameter[] parameters = method.getParameters();
Map<String, Object> variables = getPromptTemplateVariables(args, parameters);
dev.langchain4j.service.SystemMessage annotation = method.getAnnotation(dev.langchain4j.service.SystemMessage.class);
if (annotation != null) {
// 获取 SystemMessage 注解的系统消息提示词模板
String systemMessageTemplate = getPromptText(
method,
"System",
annotation.fromResource(), // 提示词资源文件,如果没有则取value值
annotation.value(), // 提示词文本
annotation.delimiter() // 换行符
);
// 根据模板和变量获取提示词
Prompt prompt = PromptTemplate.from(systemMessageTemplate).apply(variables);
return Optional.of(prompt.toSystemMessage());
}
return Optional.empty();
}
// 准备用户消息
private static UserMessage prepareUserMessage(Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
Map<String, Object> variables = getPromptTemplateVariables(args, parameters);
// 获取用户名参数
String userName = getUserName(parameters, args);
dev.langchain4j.service.UserMessage annotation = method.getAnnotation(dev.langchain4j.service.UserMessage.class);
if (annotation != null) {
String userMessageTemplate = getPromptText(
method,
"User",
annotation.fromResource(),
annotation.value(),
annotation.delimiter()
);
// 如果模板中使用了{{it}}占位符,则只允许使用一个模板参数
if (userMessageTemplate.contains("{{it}}")) {
if (parameters.length != 1) {
throw illegalConfiguration("Error: The {{it}} placeholder is present but the method does not have exactly one parameter. "
"Please ensure that methods using the {{it}} placeholder have exactly one parameter.");
}
variables = singletonMap("it", toString(args[0]));
}
Prompt prompt = PromptTemplate.from(userMessageTemplate).apply(variables);
if (userName != null) {
// 使用用户名构造用户消息
return userMessage(userName, prompt.text());
} else {
return prompt.toUserMessage();
}
}
// 方法如果没有UserMessage注解,查找使用UserMessage注解的参数,作为消息内容
for (int i = 0; i < parameters.length; i ) {
if (parameters[i].isAnnotationPresent(dev.langchain4j.service.UserMessage.class)) {
String text = toString(args[i]);
if (userName != null) {
return userMessage(userName, text);
} else {
return userMessage(text);
}
}
}
// 如果完全没有参数则报错
if (args == null || args.length == 0) {
throw illegalConfiguration("Method should have at least one argument");
}
// 如果只有一个没有注解的参数,则作为消息内容
if (args.length == 1) {
String text = toString(args[0]);
if (userName != null) {
return userMessage(userName, text);
} else {
return userMessage(text);
}
}
throw illegalConfiguration("For methods with multiple parameters, each parameter must be annotated with @V, @UserMessage, @UserName or @MemoryId");
}
// 根据方法提示词注解获取提示词文本
// resource 提示词资源文件,如果没有则取value值
// value 提示词文本
// delimiter 分隔符(换行符)
private static String getPromptText(Method method, String type, String resource, String[] value, String delimiter) {
String messageTemplate;
if (!resource.trim().isEmpty()) {
messageTemplate = getResourceText(method.getDeclaringClass(), resource);
if (messageTemplate == null) {
throw illegalConfiguration("@%sMessage's resource '%s' not found", type, resource);
}
} else {
messageTemplate = String.join(delimiter, value);
}
if (messageTemplate.trim().isEmpty()) {
throw illegalConfiguration("@%sMessage's template cannot be empty", type);
}
return messageTemplate;
}
private static String getResourceText(Class<?> clazz, String name) {
return getText(clazz.getResourceAsStream(name));
}
private static String getText(InputStream inputStream) {
if (inputStream == null) {
return null;
}
try (Scanner scanner = new Scanner(inputStream);
Scanner s = scanner.useDelimiter("\A")) {
return s.hasNext() ? s.next() : "";
}
}
private Optional<Object> memoryId(Method method, Object[] args) {
Parameter[] parameters = method.getParameters();
for (int i = 0; i < parameters.length; i ) {
if (parameters[i].isAnnotationPresent(MemoryId.class)) {
Object memoryId = args[i];
if (memoryId == null) {
throw illegalArgument("The value of parameter %s annotated with @MemoryId in method %s must not be null",
parameters[i].getName(), method.getName());
}
return Optional.of(memoryId);
}
}
return Optional.empty();
}
// 获取用户名参数
private static String getUserName(Parameter[] parameters, Object[] args) {
for (int i = 0; i < parameters.length; i ) {
if (parameters[i].isAnnotationPresent(UserName.class)) {
return args[i].toString();
}
}
return null;
}
// 获取提示词模板变量
// 遍历V注解的变量,返回变量名和变量值映射
private static Map<String, Object> getPromptTemplateVariables(Object[] args, Parameter[] parameters) {
Map<String, Object> variables = new HashMap<>();
for (int i = 0; i < parameters.length; i ) {
V varAnnotation = parameters[i].getAnnotation(V.class);
if (varAnnotation != null) {
String variableName = varAnnotation.value();
Object variableValue = args[i];
variables.put(variableName, variableValue);
}
}
return variables;
}
private static String toString(Object arg) {
if (arg.getClass().isArray()) {
return arrayToString(arg);
} else if (arg.getClass().isAnnotationPresent(StructuredPrompt.class)) {
return StructuredPromptProcessor.toPrompt(arg).text();
} else {
return arg.toString();
}
}
private static String arrayToString(Object arg) {
StringBuilder sb = new StringBuilder("[");
int length = Array.getLength(arg);
for (int i = 0; i < length; i ) {
sb.append(toString(Array.get(arg, i)));
if (i < length - 1) {
sb.append(", ");
}
}
sb.append("]");
return sb.toString();
}
}