This commit is contained in:
YunaiV 2025-05-04 09:53:43 +08:00
commit 3bcb9890e8
20 changed files with 404 additions and 99 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 32 KiB

View File

@ -585,6 +585,13 @@
<groupId>com.xkcoding.justauth</groupId>
<artifactId>justauth-spring-boot-starter</artifactId>
<version>${justauth-starter.version}</version>
<exclusions>
<!-- 移除,避免和项目里的 hutool-all 冲突 -->
<exclusion>
<groupId>cn.hutool</groupId>
<artifactId>hutool-core</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>

View File

@ -63,6 +63,15 @@ public class AiKnowledgeController {
knowledgeService.updateKnowledge(updateReqVO);
return success(true);
}
@DeleteMapping("/delete")
@Operation(summary = "删除知识库")
@Parameter(name = "id", description = "编号", required = true, example = "1024")
@PreAuthorize("@ss.hasPermission('ai:knowledge:delete')")
public CommonResult<Boolean> deleteKnowledge(@RequestParam("id") Long id) {
knowledgeService.deleteKnowledge(id);
return success(true);
}
@GetMapping("/simple-list")
@Operation(summary = "获得知识库的精简列表")

View File

@ -0,0 +1,12 @@
### 测试 AI 工作流
POST {{baseUrl}}/ai/workflow/test
Content-Type: application/json
Authorization: {{token}}
tenant-id: {{adminTenantId}}
{
"id": 4,
"params": {
"message": "1 + 1 = ?"
}
}

View File

@ -1,7 +1,8 @@
package cn.iocoder.yudao.module.ai.controller.admin.workflow.vo;
import cn.hutool.core.util.StrUtil;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.AssertTrue;
import lombok.Data;
import java.util.Map;
@ -10,11 +11,18 @@ import java.util.Map;
@Data
public class AiWorkflowTestReqVO {
@Schema(description = "工作流模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "{}")
@NotEmpty(message = "工作流模型不能为空")
@Schema(description = "工作流编号", example = "1024")
private Long id;
@Schema(description = "工作流模型", example = "{}")
private String graph;
@Schema(description = "参数", requiredMode = Schema.RequiredMode.REQUIRED, example = "{}")
private Map<String, Object> params;
@AssertTrue(message = "工作流或模型,必须传递一个")
public boolean isGraphValid() {
return id != null || StrUtil.isNotEmpty(graph);
}
}

View File

@ -36,4 +36,8 @@ public interface AiKnowledgeDocumentMapper extends BaseMapperX<AiKnowledgeDocume
return selectList(AiKnowledgeDocumentDO::getStatus, status);
}
default List<AiKnowledgeDocumentDO> selectListByKnowledgeId(Long knowledgeId) {
return selectList(AiKnowledgeDocumentDO::getKnowledgeId, knowledgeId);
}
}

View File

@ -30,9 +30,9 @@ public interface AiKnowledgeSegmentMapper extends BaseMapperX<AiKnowledgeSegment
.orderByDesc(AiKnowledgeSegmentDO::getId));
}
default List<AiKnowledgeSegmentDO> selectListByVectorIds(List<String> vectorIdList) {
default List<AiKnowledgeSegmentDO> selectListByVectorIds(List<String> vectorIds) {
return selectList(new LambdaQueryWrapperX<AiKnowledgeSegmentDO>()
.in(AiKnowledgeSegmentDO::getVectorId, vectorIdList)
.in(AiKnowledgeSegmentDO::getVectorId, vectorIds)
.orderByDesc(AiKnowledgeSegmentDO::getId));
}

View File

@ -101,8 +101,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
ChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 知识库找回
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(),
conversation);
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments = recallKnowledgeSegment(sendReqVO.getContent(), conversation);
// 3. 插入 user 发送消息
AiChatMessageDO userMessage = createChatMessage(conversation.getId(), null, model,
@ -122,11 +121,11 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
String newContent = chatResponse.getResult().getOutput().getText();
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(newContent));
// 3.4 响应结果
Map<Long, AiKnowledgeDocumentDO> documentMap = knowledgeDocumentService.getKnowledgeDocumentMap(
convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId));
List<AiChatMessageRespVO.KnowledgeSegment> segments = BeanUtils.toBean(knowledgeSegments,
AiChatMessageRespVO.KnowledgeSegment.class,
segment -> {
AiKnowledgeDocumentDO document = knowledgeDocumentService
.getKnowledgeDocument(segment.getDocumentId());
AiChatMessageRespVO.KnowledgeSegment.class, segment -> {
AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
segment.setDocumentName(document != null ? document.getName() : null);
});
return new AiChatMessageSendRespVO()
@ -173,12 +172,13 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 处理知识库的返回只有首次才有
List<AiChatMessageRespVO.KnowledgeSegment> segments = null;
if (StrUtil.isEmpty(contentBuffer)) {
segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class,
segment -> TenantUtils.executeIgnore(() -> {
AiKnowledgeDocumentDO document = knowledgeDocumentService
.getKnowledgeDocument(segment.getDocumentId());
segment.setDocumentName(document != null ? document.getName() : null);
}));
Map<Long, AiKnowledgeDocumentDO> documentMap = TenantUtils.executeIgnore(() ->
knowledgeDocumentService.getKnowledgeDocumentMap(
convertSet(knowledgeSegments, AiKnowledgeSegmentSearchRespBO::getDocumentId)));
segments = BeanUtils.toBean(knowledgeSegments, AiChatMessageRespVO.KnowledgeSegment.class, segment -> {
AiKnowledgeDocumentDO document = documentMap.get(segment.getDocumentId());
segment.setDocumentName(document != null ? document.getName() : null);
});
}
// 响应结果
String newContent = chunk.getResult() != null ? chunk.getResult().getOutput().getText() : null;
@ -221,8 +221,8 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
private Prompt buildPrompt(AiChatConversationDO conversation, List<AiChatMessageDO> messages,
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
List<AiKnowledgeSegmentSearchRespBO> knowledgeSegments,
AiModelDO model, AiChatMessageSendReqVO sendReqVO) {
List<Message> chatMessages = new ArrayList<>();
// 1.1 System Context 角色设定
if (StrUtil.isNotBlank(conversation.getSystemMessage())) {
@ -247,16 +247,18 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 2.1 查询 tool 工具
Set<String> toolNames = null;
Map<String,Object> toolContext = Map.of();
if (conversation.getRoleId() != null) {
AiChatRoleDO chatRole = chatRoleService.getChatRole(conversation.getRoleId());
if (chatRole != null && CollUtil.isNotEmpty(chatRole.getToolIds())) {
toolNames = convertSet(toolService.getToolList(chatRole.getToolIds()), AiToolDO::getName);
toolContext = AiUtils.buildCommonToolContext();
}
}
// 2.2 构建 ChatOptions 对象
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
ChatOptions chatOptions = AiUtils.buildChatOptions(platform, model.getModel(),
conversation.getTemperature(), conversation.getMaxTokens(), toolNames);
conversation.getTemperature(), conversation.getMaxTokens(), toolNames, toolContext);
return new Prompt(chatMessages, chatOptions);
}

View File

@ -67,13 +67,6 @@ public interface AiKnowledgeDocumentService {
*/
void updateKnowledgeDocumentStatus(AiKnowledgeDocumentUpdateStatusReqVO reqVO);
/**
* 更新文档检索次数增加 +1
*
* @param ids 文档编号列表
*/
void updateKnowledgeDocumentRetrievalCountIncr(Collection<Long> ids);
/**
* 删除文档
*
@ -81,6 +74,13 @@ public interface AiKnowledgeDocumentService {
*/
void deleteKnowledgeDocument(Long id);
/**
* 根据知识库编号批量删除文档
*
* @param knowledgeId 知识库编号
*/
void deleteKnowledgeDocumentByKnowledgeId(Long knowledgeId);
/**
* 校验文档是否存在
*
@ -105,6 +105,14 @@ public interface AiKnowledgeDocumentService {
*/
List<AiKnowledgeDocumentDO> getKnowledgeDocumentList(Collection<Long> ids);
/**
* 根据知识库编号获取文档列表
*
* @param knowledgeId 知识库编号
* @return 文档列表
*/
List<AiKnowledgeDocumentDO> getKnowledgeDocumentListByKnowledgeId(Long knowledgeId);
/**
* 获取文档 Map
*

View File

@ -161,14 +161,6 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
knowledgeSegmentService.deleteKnowledgeSegmentByDocumentId(id);
}
@Override
public void updateKnowledgeDocumentRetrievalCountIncr(Collection<Long> ids) {
if (CollUtil.isEmpty(ids)) {
return;
}
knowledgeDocumentMapper.updateRetrievalCountIncr(ids);
}
@Override
public AiKnowledgeDocumentDO validateKnowledgeDocumentExists(Long id) {
AiKnowledgeDocumentDO knowledgeDocument = knowledgeDocumentMapper.selectById(id);
@ -211,4 +203,24 @@ public class AiKnowledgeDocumentServiceImpl implements AiKnowledgeDocumentServic
return knowledgeDocumentMapper.selectBatchIds(ids);
}
@Override
public List<AiKnowledgeDocumentDO> getKnowledgeDocumentListByKnowledgeId(Long knowledgeId) {
return knowledgeDocumentMapper.selectListByKnowledgeId(knowledgeId);
}
@Override
@Transactional(rollbackFor = Exception.class)
public void deleteKnowledgeDocumentByKnowledgeId(Long knowledgeId) {
// 1. 获取该知识库下的所有文档
List<AiKnowledgeDocumentDO> documents = knowledgeDocumentMapper.selectListByKnowledgeId(knowledgeId);
if (CollUtil.isEmpty(documents)) {
return;
}
// 2. 逐个删除文档及其对应的段落
for (AiKnowledgeDocumentDO document : documents) {
deleteKnowledgeDocument(document.getId());
}
}
}

View File

@ -29,6 +29,13 @@ public interface AiKnowledgeService {
*/
void updateKnowledge(AiKnowledgeSaveReqVO updateReqVO);
/**
* 删除知识库
*
* @param id 知识库编号
*/
void deleteKnowledge(Long id);
/**
* 获得知识库
*

View File

@ -1,19 +1,18 @@
package cn.iocoder.yudao.module.ai.service.knowledge;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.knowledge.vo.knowledge.AiKnowledgeSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.knowledge.AiKnowledgeDocumentDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.knowledge.AiKnowledgeMapper;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.List;
@ -36,6 +35,8 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
private AiModelService modelService;
@Resource
private AiKnowledgeSegmentService knowledgeSegmentService;
@Resource
private AiKnowledgeDocumentService knowledgeDocumentService;
@Override
public Long createKnowledge(AiKnowledgeSaveReqVO createReqVO) {
@ -67,6 +68,20 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
}
}
@Override
@Transactional(rollbackFor = Exception.class)
public void deleteKnowledge(Long id) {
// 1. 校验存在
validateKnowledgeExists(id);
// 2. 删除知识库下的所有文档及段落
knowledgeDocumentService.deleteKnowledgeDocumentByKnowledgeId(id);
// 3. 删除知识库
// 特殊知识库需要最后删除不然相关的配置会找不到
knowledgeMapper.deleteById(id);
}
@Override
public AiKnowledgeDO getKnowledge(Long id) {
return knowledgeMapper.selectById(id);
@ -74,11 +89,11 @@ public class AiKnowledgeServiceImpl implements AiKnowledgeService {
@Override
public AiKnowledgeDO validateKnowledgeExists(Long id) {
AiKnowledgeDO knowledgeBase = knowledgeMapper.selectById(id);
if (knowledgeBase == null) {
AiKnowledgeDO knowledge = knowledgeMapper.selectById(id);
if (knowledge == null) {
throw exception(KNOWLEDGE_NOT_EXISTS);
}
return knowledgeBase;
return knowledge;
}
@Override

View File

@ -6,6 +6,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import dev.tinyflow.core.Tinyflow;
import jakarta.validation.Valid;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.image.ImageModel;
@ -131,4 +132,12 @@ public interface AiModelService {
*/
VectorStore getOrCreateVectorStore(Long id, Map<String, Class<?>> metadataFields);
/**
* 获取 TinyFlow 所需 LLm Provider
*
* @param tinyflow tinyflow
* @param modelId AI 模型 ID
*/
void getLLmProvider4Tinyflow(Tinyflow tinyflow, Long modelId);
}

View File

@ -12,6 +12,11 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.model.AiModelSaveReq
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiModelDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiChatMapper;
import com.agentsflex.llm.ollama.OllamaLlm;
import com.agentsflex.llm.ollama.OllamaLlmConfig;
import com.agentsflex.llm.qwen.QwenLlm;
import com.agentsflex.llm.qwen.QwenLlmConfig;
import dev.tinyflow.core.Tinyflow;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.embedding.EmbeddingModel;
@ -168,4 +173,29 @@ public class AiModelServiceImpl implements AiModelService {
// return modelFactory.getOrCreateVectorStore(MilvusVectorStore.class, embeddingModel, metadataFields);
}
// TODO @lesan是不是返回 Llm 对象会好点哈
@Override
public void getLLmProvider4Tinyflow(Tinyflow tinyflow, Long modelId) {
AiModelDO model = validateModel(modelId);
AiApiKeyDO apiKey = apiKeyService.validateApiKey(model.getKeyId());
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
switch (platform) {
// TODO @lesan 考虑到未来不需要使用agents-flex 现在仅测试通义千问
// TODO @lesan重要是不是可以实现一个 SpringAiLlm这样的话内部全部用它就好了只实现 chat 部分这样就把 flex 作为一个 agent 框架内部调用还是 spring ai 相关的成本可能低一点
case TONG_YI:
QwenLlmConfig qwenLlmConfig = new QwenLlmConfig();
qwenLlmConfig.setApiKey(apiKey.getApiKey());
qwenLlmConfig.setModel(model.getModel());
// TODO @lesan这个有点奇怪如果一个链式里有多个模型咋整呀
tinyflow.setLlmProvider(id -> new QwenLlm(qwenLlmConfig));
break;
case OLLAMA:
OllamaLlmConfig ollamaLlmConfig = new OllamaLlmConfig();
ollamaLlmConfig.setEndpoint(apiKey.getUrl());
ollamaLlmConfig.setModel(model.getModel());
tinyflow.setLlmProvider(id -> new OllamaLlm(ollamaLlmConfig));
break;
}
}
}

View File

@ -0,0 +1,75 @@
package cn.iocoder.yudao.module.ai.service.model.tool;
import cn.iocoder.yudao.framework.ai.core.util.AiUtils;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.framework.security.core.LoginUser;
import cn.iocoder.yudao.framework.tenant.core.util.TenantUtils;
import cn.iocoder.yudao.module.system.api.user.AdminUserApi;
import cn.iocoder.yudao.module.system.api.user.dto.AdminUserRespDTO;
import com.fasterxml.jackson.annotation.JsonClassDescription;
import jakarta.annotation.Resource;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.springframework.ai.chat.model.ToolContext;
import org.springframework.stereotype.Component;
import java.util.function.BiFunction;
/**
* 工具当前用户信息查询
*
* 同时也是展示 ToolContext 上下文的使用
*
* @author Ren
*/
@Component("user_profile_query")
public class UserProfileQueryToolFunction
implements BiFunction<UserProfileQueryToolFunction.Request, ToolContext, UserProfileQueryToolFunction.Response> {
@Resource
private AdminUserApi adminUserApi;
@Data
@JsonClassDescription("当前用户信息查询")
public static class Request { }
@Data
@AllArgsConstructor
@NoArgsConstructor
public static class Response {
/**
* 用户ID
*/
private Long id;
/**
* 用户昵称
*/
private String nickname;
/**
* 手机号码
*/
private String mobile;
/**
* 用户头像
*/
private String avatar;
}
@Override
public UserProfileQueryToolFunction.Response apply(UserProfileQueryToolFunction.Request request, ToolContext toolContext) {
LoginUser loginUser = (LoginUser) toolContext.getContext().get(AiUtils.TOOL_CONTEXT_LOGIN_USER);
Long tenantId = (Long) toolContext.getContext().get(AiUtils.TOOL_CONTEXT_TENANT_ID);
if (loginUser == null | tenantId == null) {
return null;
}
return TenantUtils.execute(tenantId, () -> {
AdminUserRespDTO user = adminUserApi.getUser(loginUser.getId());
return BeanUtils.toBean(user, Response.class);
});
}
}

View File

@ -7,10 +7,9 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowSaveReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.workflow.vo.AiWorkflowTestReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.dataobject.workflow.AiWorkflowDO;
import cn.iocoder.yudao.module.ai.dal.mysql.workflow.AiWorkflowMapper;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiModelService;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import dev.tinyflow.core.Tinyflow;
@ -37,11 +36,14 @@ public class AiWorkflowServiceImpl implements AiWorkflowService {
private AiWorkflowMapper workflowMapper;
@Resource
private AiApiKeyService apiKeyService;
private AiModelService apiModelService;
@Override
public Long createWorkflow(AiWorkflowSaveReqVO createReqVO) {
validateWorkflowForCreateOrUpdate(null, createReqVO.getCode());
// 1. 参数校验
validateCodeUnique(null, createReqVO.getCode());
// 2. 插入工作流配置
AiWorkflowDO workflow = BeanUtils.toBean(createReqVO, AiWorkflowDO.class);
workflowMapper.insert(workflow);
return workflow.getId();
@ -49,47 +51,33 @@ public class AiWorkflowServiceImpl implements AiWorkflowService {
@Override
public void updateWorkflow(AiWorkflowSaveReqVO updateReqVO) {
validateWorkflowForCreateOrUpdate(updateReqVO.getId(), updateReqVO.getCode());
// 1. 参数校验
validateWorkflowExists(updateReqVO.getId());
validateCodeUnique(updateReqVO.getId(), updateReqVO.getCode());
// 2. 更新工作流配置
AiWorkflowDO workflow = BeanUtils.toBean(updateReqVO, AiWorkflowDO.class);
workflowMapper.updateById(workflow);
}
@Override
public void deleteWorkflow(Long id) {
// 1. 校验存在
validateWorkflowExists(id);
// 2. 删除工作流配置
workflowMapper.deleteById(id);
}
@Override
public AiWorkflowDO getWorkflow(Long id) {
return workflowMapper.selectById(id);
}
@Override
public PageResult<AiWorkflowDO> getWorkflowPage(AiWorkflowPageReqVO pageReqVO) {
return workflowMapper.selectPage(pageReqVO);
}
@Override
public Object testWorkflow(AiWorkflowTestReqVO testReqVO) {
Map<String, Object> variables = testReqVO.getParams();
Tinyflow tinyflow = parseFlowParam(testReqVO.getGraph());
return tinyflow.toChain().executeForResult(variables);
}
private void validateWorkflowForCreateOrUpdate(Long id, String code) {
validateWorkflowExists(id);
validateCodeUnique(id, code);
}
private void validateWorkflowExists(Long id) {
private AiWorkflowDO validateWorkflowExists(Long id) {
if (ObjUtil.isNull(id)) {
return;
throw exception(WORKFLOW_NOT_EXISTS);
}
AiWorkflowDO workflow = workflowMapper.selectById(id);
if (ObjUtil.isNull(workflow)) {
throw exception(WORKFLOW_NOT_EXISTS);
}
return workflow;
}
private void validateCodeUnique(Long id, String code) {
@ -108,6 +96,30 @@ public class AiWorkflowServiceImpl implements AiWorkflowService {
}
}
@Override
public AiWorkflowDO getWorkflow(Long id) {
return workflowMapper.selectById(id);
}
@Override
public PageResult<AiWorkflowDO> getWorkflowPage(AiWorkflowPageReqVO pageReqVO) {
return workflowMapper.selectPage(pageReqVO);
}
@Override
public Object testWorkflow(AiWorkflowTestReqVO testReqVO) {
// 加载 graph
String graph = testReqVO.getGraph() != null ? testReqVO.getGraph()
: validateWorkflowExists(testReqVO.getId()).getGraph();
// 构建 TinyFlow 执行链
Tinyflow tinyflow = parseFlowParam(graph);
// 执行
Map<String, Object> variables = testReqVO.getParams();
return tinyflow.toChain().executeForResult(variables);
}
private Tinyflow parseFlowParam(String graph) {
// TODO @lesan可以使用 jackson
JSONObject json = JSONObject.parseObject(graph);
@ -118,25 +130,7 @@ public class AiWorkflowServiceImpl implements AiWorkflowService {
switch (node.getString("type")) {
case "llmNode":
JSONObject data = node.getJSONObject("data");
AiApiKeyDO apiKey = apiKeyService.getApiKey(data.getLong("llmId"));
switch (apiKey.getPlatform()) {
// TODO @lesan 需要讨论一下这里怎么弄
// TODO @lesan llmId 对应 model 的编号如何这样的话就是 apiModelService 提供一个获取 LLM 的方法然后创建的方法也在 AiModelFactory 提供可以先接个 deepseek deepseek yyds
case "OpenAI":
break;
case "Ollama":
break;
case "YiYan":
break;
case "XingHuo":
break;
case "TongYi":
break;
case "DeepSeek":
break;
case "ZhiPu":
break;
}
apiModelService.getLLmProvider4Tinyflow(tinyflow, data.getLong("llmId"));
break;
case "internalNode":
break;

View File

@ -76,7 +76,7 @@ public class AiWriteServiceImpl implements AiWriteService {
? writeRole.getSystemMessage() : AiChatRoleEnum.AI_WRITE_ROLE.getSystemMessage();
// 1.3 校验平台
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(model.getPlatform());
StreamingChatModel chatModel = modalService.getChatModel(model.getKeyId());
StreamingChatModel chatModel = modalService.getChatModel(model.getId());
// 2. 插入写作信息
AiWriteDO writeDO = BeanUtils.toBean(generateReqVO, AiWriteDO.class, write -> write.setUserId(userId)

View File

@ -15,7 +15,7 @@
<description>AI 大模型拓展,接入国内外大模型</description>
<properties>
<spring-ai.version>1.0.0-M6</spring-ai.version>
<tinyflow.version>1.0.0-rc.3</tinyflow.version>
<tinyflow.version>1.0.2</tinyflow.version>
</properties>
<dependencies>
@ -24,6 +24,18 @@
<artifactId>yudao-common</artifactId>
</dependency>
<!-- 业务组件 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-biz-tenant</artifactId>
</dependency>
<!-- Web 相关 -->
<dependency>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-spring-boot-starter-security</artifactId>
</dependency>
<!-- Spring AI Model 模型接入 -->
<dependency>
<groupId>org.springframework.ai</groupId>
@ -98,6 +110,13 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-milvus-store</artifactId>
<version>${spring-ai.version}</version>
<exclusions>
<!-- 解决和 logback 的日志冲突 -->
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-reload4j</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
@ -124,6 +143,10 @@
<artifactId>tinyflow-java-core</artifactId>
<version>${tinyflow.version}</version>
<exclusions>
<exclusion>
<groupId>com.jfinal</groupId>
<artifactId>enjoy</artifactId>
</exclusion>
<exclusion>
<!-- 解决 https://gitee.com/zhijiantianya/ruoyi-vue-pro/pulls/1318/ 问题 -->
<groupId>com.agentsflex</groupId>
@ -134,6 +157,19 @@
<groupId>org.codehaus.groovy</groupId>
<artifactId>groovy-all</artifactId>
</exclusion>
<!-- 解决和 logback 的日志冲突 -->
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-simple</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.logging.log4j</groupId>
<artifactId>log4j-slf4j-impl</artifactId>
</exclusion>
<exclusion>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-reload4j</artifactId>
</exclusion>
</exclusions>
</dependency>

View File

@ -3,6 +3,8 @@ package cn.iocoder.yudao.framework.ai.core.util;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils;
import cn.iocoder.yudao.framework.tenant.core.context.TenantContextHolder;
import com.alibaba.cloud.ai.dashscope.chat.DashScopeChatOptions;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.messages.*;
@ -15,6 +17,8 @@ import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.ai.zhipuai.ZhiPuAiChatOptions;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
/**
@ -24,29 +28,32 @@ import java.util.Set;
*/
public class AiUtils {
public static final String TOOL_CONTEXT_LOGIN_USER = "LOGIN_USER";
public static final String TOOL_CONTEXT_TENANT_ID = "TENANT_ID";
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens) {
return buildChatOptions(platform, model, temperature, maxTokens, null);
return buildChatOptions(platform, model, temperature, maxTokens, null, null);
}
public static ChatOptions buildChatOptions(AiPlatformEnum platform, String model, Double temperature, Integer maxTokens,
Set<String> toolNames) {
Set<String> toolNames, Map<String, Object> toolContext) {
toolNames = ObjUtil.defaultIfNull(toolNames, Collections.emptySet());
// noinspection EnhancedSwitchMigration
switch (platform) {
case TONG_YI:
return DashScopeChatOptions.builder().withModel(model).withTemperature(temperature).withMaxToken(maxTokens)
.withFunctions(toolNames).build();
.withFunctions(toolNames).withToolContext(toolContext).build();
case YI_YAN:
return QianFanChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens).build();
case ZHI_PU:
return ZhiPuAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.functions(toolNames).build();
.functions(toolNames).toolContext(toolContext).build();
case MINI_MAX:
return MiniMaxChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.functions(toolNames).build();
.functions(toolNames).toolContext(toolContext).build();
case MOONSHOT:
return MoonshotChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.functions(toolNames).build();
.functions(toolNames).toolContext(toolContext).build();
case OPENAI:
case DEEP_SEEK: // 复用 OpenAI 客户端
case DOU_BAO: // 复用 OpenAI 客户端
@ -55,14 +62,14 @@ public class AiUtils {
case SILICON_FLOW: // 复用 OpenAI 客户端
case BAI_CHUAN: // 复用 OpenAI 客户端
return OpenAiChatOptions.builder().model(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).build();
.toolNames(toolNames).toolContext(toolContext).build();
case AZURE_OPENAI:
// TODO 芋艿貌似没 model 字段
return AzureOpenAiChatOptions.builder().deploymentName(model).temperature(temperature).maxTokens(maxTokens)
.toolNames(toolNames).build();
.toolNames(toolNames).toolContext(toolContext).build();
case OLLAMA:
return OllamaOptions.builder().model(model).temperature(temperature).numPredict(maxTokens)
.toolNames(toolNames).build();
.toolNames(toolNames).toolContext(toolContext).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -84,4 +91,11 @@ public class AiUtils {
throw new IllegalArgumentException(StrUtil.format("未知消息类型({})", type));
}
public static Map<String, Object> buildCommonToolContext() {
Map<String, Object> context = new HashMap<>();
context.put(TOOL_CONTEXT_LOGIN_USER, SecurityFrameworkUtils.getLoginUser());
context.put(TOOL_CONTEXT_TENANT_ID, TenantContextHolder.getTenantId());
return context;
}
}

View File

@ -0,0 +1,63 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.api.OpenAiApi;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
/**
* 基于 {@link OpenAiChatModel} 集成 Coze 测试
*
* @author 芋道源码
*/
public class CozeChatModelTests {
private final OpenAiChatModel chatModel = OpenAiChatModel.builder()
.openAiApi(OpenAiApi.builder()
.baseUrl("http://127.0.0.1:3000")
.apiKey("app-4hy2d7fJauSbrKbzTKX1afuP") // apiKey
.build())
.build();
@Test
@Disabled
public void testCall() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
ChatResponse response = chatModel.call(new Prompt(messages));
// 打印结果
System.out.println(response);
System.out.println(response.getResult().getOutput());
}
@Test
@Disabled
public void testStream() {
// 准备参数
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("1 + 1 = "));
// 调用
Flux<ChatResponse> flux = chatModel.stream(new Prompt(messages));
// 打印结果
flux.doOnNext(response -> {
// System.out.println(response);
System.out.println(response.getResult().getOutput());
}).then().block();
}
}