SpringBoot教程(十八) | SpringBoot集成Milvus(全网最全)

2024-05-26 15:58:43 浏览数 (1)

一、 milvus概述

Milvus是一款向量数据库,主要用于在大模型领域做向量查询的相关操作。milvus在之前的版本中其实是存在一些弊端的,尤其是在一些类似于mysql的查询方面,有一些缺点,这里简单唠叨几句。

首先milvus不支持多个向量字段,其次milvus的模糊匹配只支持前缀匹配,再次milvus不支持排序。

不过这些功能在最新版的milvus中都已经解决了。但是我还没来得及体验最新的版本,所以不知道支持情况如何。

说这么多的目的,其实主要就是提醒你们,如果对上述我提出来的几个点特别在意的,一定要选择好版本进行安装,别装完了以后发现不支持。

milvus的官网地址是: https://milvus.io/

milvus支持的语言比较多,这也是我选择的主要原因,python, Java, go,node, c#和 restful。 所以可以根据自己的需要做选择。

同时向量数据库的产品还是比较多的,尤其是python玩家,如果想要找一款玩玩,不建议用这个,因为这个还是比较大的。我用这款主要是之前在python 里用过,现在java也要用,懒得重装了,并且这个玩意正好支持java, 我也正好有java需求,所以继续用这个了。

milvus的安装如果用docker-compose的话还是比较简单的,建议用这种方式,具体安装不说了。

二. SpringBoot集成

最新我需要用springBoot把数据存到milvus向量库里,就搜了一下,发现相关文档太少了,要到了一个csds的,居然还要收费。就这点破玩意,有必要么,于是就自己深入研究了下,结合了网络上的一些文章,搞通了,免费分享给大家,也增加有下我博客的kpi, 后续csdn最多会搞个粉丝可见。

好开始。 首先我得milvus版本是: v2.3.4。 所以我上面说的那几个问题,在这个版本里还有,我的集成也是匹配的这个版本,如果你的版本高于我的,那么下面的代码可能不适用或不兼容。

好,开始,看过我SpringBoot系列文章的,肯定都比较了解集成步骤了,第一步,肯定是引入依赖。

代码语言:javascript复制
<!--milvus 向量数据库 client sdk -->
<dependency>
	<groupId>io.milvus</groupId>
	<artifactId>milvus-sdk-java</artifactId>
	<version>2.3.3</version>
</dependency>

这里用的是2.3.3 之前用2.3.4来着。 但是有一些jar包冲突比较严重,包了一个什么Google的 isEmpty找不到,还是啥来着,记不清了,主要是跟我的mysql-connector冲突了,剔除了一下,如果你们有别的问题,自己想办法解决一下。

代码语言:javascript复制
<!-- MySQL连接 -->
        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.19</version>
            <scope>runtime</scope>
            <exclusions>
                <exclusion>
                    <artifactId>protobuf-java</artifactId>
                    <groupId>com.google.protobuf</groupId>
                </exclusion>
            </exclusions>
        </dependency>

好了,接下来,引入配置,这个是参考的网上的一篇文档,也就是大家一搜,搜到最多的文档,按照他这个写的,他这边文档写的不错,就是只写了创建Collection和查询的用法,比较少。在config包下,添加两个类:

​MilvusConfig​

代码语言:javascript复制
import io.milvus.client.MilvusServiceClient;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Scope;

/**
 * @author qingf
 * @title MilvusConfig
 * @description
 * @create 2024/4/23
 */
@Configuration
public class MilvusConfig {
    /**
     *  milvus ip addr
     */
    @Value("${milvus.config.ipAddr}")
    private String ipAddr;

    /**
     * milvus   port
     */
    @Value("${milvus.config.port}")
    private Integer  port;

    @Bean
    @Scope("singleton")
    public MilvusServiceClient getMilvusClient() {
        return getMilvusFactory().getMilvusClient();
    }

    @Bean(initMethod = "init", destroyMethod = "close")
    public MilvusRestClientFactory getMilvusFactory() {
        return  MilvusRestClientFactory.build(ipAddr, port);
    }

}

这里注意,配了两个变量, 就是milvus的ip和端口, 把自己的地址,在配置文件里边加一下,别上来就抄,抄完跑不了,还不知道咋回事。

第二个类:

​MilvusRestClientFactory​

代码语言:javascript复制
import io.milvus.client.MilvusServiceClient;
import io.milvus.param.ConnectParam;

/**
 * @author qingf
 * @title MilvusRestClientFactory
 * @description
 * @create 2024/4/23
 */
public class MilvusRestClientFactory {

    private static String  IP_ADDR;

    private static Integer PORT ;

    private MilvusServiceClient milvusServiceClient;

    private ConnectParam.Builder  connectParamBuilder;


    private static MilvusRestClientFactory milvusRestClientFactory = new MilvusRestClientFactory();

    private MilvusRestClientFactory(){

    }

    public static MilvusRestClientFactory build(String ipAddr, Integer  port) {
        IP_ADDR = ipAddr;
        PORT = port;
        return milvusRestClientFactory;
    }

    private ConnectParam.Builder connectParamBuilder(String host, int port) {
        return  ConnectParam.newBuilder().withHost(host).withPort(port);
    }



    public void init() {
        connectParamBuilder =  connectParamBuilder(IP_ADDR,PORT);
        ConnectParam connectParam = connectParamBuilder.build();
        milvusServiceClient =new MilvusServiceClient(connectParam);
    }


    public MilvusServiceClient getMilvusClient() {
        return milvusServiceClient;
    }


    public void close() {
        if (milvusServiceClient != null) {
            try {
                milvusServiceClient.close();
            } catch (Exception e) {
                e.printStackTrace();
            }
        }
    }
}

初始化的时候,会构建一个MilvusServiceClient对象,这个就是操作milvus的主要客户端。有了这个类就可以用它连接milvus并进行相关操作了

三、 工具类封装

接下来是重点,就是如何操作,也就是调用相关api,实现增删改查,这一部分的代码就比较少了。在开始之前,还是得介绍一些milvus中的一些常用概念。

首先是DataBase: 就是相当于mysql中的库,先有库才能建表,这个东西在milvus中可以不建, 会默认用一个叫default的库。

然后Collection: 类似与mysql中的表,必须创建,自己起名,创建的时候,还要指定有哪些字段,都是什么类型。然后就是milvus里的主键必须用int64,对应Java中的Long, 所以如果你java里设计的是int类型,插入的时候,记得搞成Long

index: 就是索引,跟mysql里的类似,便于快速查询。 由于我们用的是向量数据库,所以必须要把向量字段设置成索引,不设置的话没法load, 没法load就没法使用

Load: 加载Collection. 只有被加载过的Collection, 才能进行增删改查的操作。

介绍这么多,目的只有一个,就是我在做milvus初始化的时候,会先判断这个Collection存不存在,不存在则进行创建,创建之后要添加索引,然后要进行Load, 初始化的时候为啥干这些事,上面就是答案。

好了封装工具类了,增删改查都有,这里要注意,生成向量我用了一个工具类,这个工具类是我们内部封装的,调用了内部的一个向量模型,你们应该是用不了,自己研究自己的向量模型去。

代码语言:javascript复制
import com.alibaba.fastjson.JSON;
import com.google.common.collect.Lists;
import io.milvus.client.MilvusServiceClient;
import io.milvus.common.clientenum.ConsistencyLevelEnum;
import io.milvus.grpc.DataType;
import io.milvus.grpc.MutationResult;
import io.milvus.grpc.SearchResults;
import io.milvus.param.IndexType;
import io.milvus.param.MetricType;
import io.milvus.param.R;
import io.milvus.param.RpcStatus;
import io.milvus.param.collection.CreateCollectionParam;
import io.milvus.param.collection.FieldType;
import io.milvus.param.collection.HasCollectionParam;
import io.milvus.param.collection.LoadCollectionParam;
import io.milvus.param.dml.InsertParam;
import io.milvus.param.dml.SearchParam;
import io.milvus.param.dml.UpsertParam;
import io.milvus.param.highlevel.dml.DeleteIdsParam;
import io.milvus.param.highlevel.dml.QuerySimpleParam;
import io.milvus.param.highlevel.dml.response.DeleteResponse;
import io.milvus.param.highlevel.dml.response.QueryResponse;
import io.milvus.param.index.CreateIndexParam;
import io.milvus.response.QueryResultsWrapper;
import io.milvus.response.SearchResultsWrapper;
import org.springframework.stereotype.Component;

import java.util.*;
@Component
public class MilvusClient {
    private static final String COLLECTION_NAME = "FAQ";
    private static final int VECTOR_DIM = 1024;
    private static final String ID_FIELD = "id";
    private static final String VECTOR_FIELD = "title_vector";
    private static final String TITLE = "title";
    private static final String CONTENT = "content";

    private final MilvusServiceClient client;

    private final EmbeddingClient embeddingClient;


    public MilvusClient(MilvusServiceClient client, EmbeddingClient embeddingClient) {
        this.client = client;
        this.embeddingClient = embeddingClient;
    }

    public R<RpcStatus> createCollection() {
        FieldType id = FieldType.newBuilder()
                .withName(ID_FIELD)
                .withDataType(DataType.Int64)
                .withPrimaryKey(true)
                .withAutoID(false)
                .withDescription(ID_FIELD)
                .build();
        FieldType type_id  = FieldType.newBuilder()
                .withName("type_id")
                .withDataType(DataType.Int64)
                .withDescription("type_id")
                .build();
        FieldType title  = FieldType.newBuilder()
                .withName("title")
                .withDataType(DataType.VarChar)
                .withMaxLength(10000)
                .withDescription("title")
                .build();
        FieldType content  = FieldType.newBuilder()
                .withName("content")
                .withDataType(DataType.VarChar)
                .withMaxLength(10000)
                .withDescription("content")
                .build();
        FieldType title_vector = FieldType.newBuilder()
                .withName(VECTOR_FIELD)
                .withDescription(VECTOR_FIELD)
                .withDataType(DataType.FloatVector)
                .withDimension(VECTOR_DIM)
                .build();

        CreateCollectionParam param = CreateCollectionParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .addFieldType(id)
                .addFieldType(type_id)
                .addFieldType(title)
                .addFieldType(content)
                .addFieldType(title_vector)
                .build();

        R<RpcStatus> response = client.createCollection(param);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }
        return response;
    }

    public Boolean isExitCollection(){
        R<Boolean> response = client.hasCollection(
                HasCollectionParam.newBuilder()
                        .withCollectionName(COLLECTION_NAME)
                        .build());
        return response.getData();
    }


    public R<RpcStatus> loadCollection() {
        LoadCollectionParam param = LoadCollectionParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withReplicaNumber(1)
                .withSyncLoad(Boolean.TRUE)
                .withSyncLoadWaitingInterval(500L)
                .withSyncLoadWaitingTimeout(30L)
                .build();
        R<RpcStatus> response = client.loadCollection(param);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }
        return response;
    }

    public R<RpcStatus> createIndex() {
        CreateIndexParam param = CreateIndexParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withFieldName(VECTOR_FIELD)
                .withIndexType(IndexType.GPU_IVF_FLAT)
                .withMetricType(MetricType.L2)
                .withExtraParam("{"nlist":2048}")
                .build();
        R<RpcStatus> response = client.createIndex(param);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }
        return response;
    }

    public void insertMilvus(FaqRecord record) {
        List<InsertParam.Field> fields = new ArrayList<>();
        fields.add(new InsertParam.Field("id", Collections.singletonList(Long.valueOf(record.getId()))));
        fields.add(new InsertParam.Field("type_id", Collections.singletonList(Long.valueOf(record.getTypeId()))));
        fields.add(new InsertParam.Field("title", Collections.singletonList(record.getTitle())));
        fields.add(new InsertParam.Field("content", Collections.singletonList(record.getContent())));
        fields.add(new InsertParam.Field("title_vector", embeddingClient.getEmbedding(record.getTitle())));
        InsertParam insertParam = InsertParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withFields(fields)
                .build();
        R<MutationResult> response = client.insert(insertParam);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }
    }

    public void updateMilvus(FaqRecord record) {
        List<InsertParam.Field> fields = new ArrayList<>();
        fields.add(new InsertParam.Field("id", Collections.singletonList(Long.valueOf(record.getId()))));
        fields.add(new InsertParam.Field("type_id", Collections.singletonList(Long.valueOf(record.getTypeId()))));
        fields.add(new InsertParam.Field("title", Collections.singletonList(record.getTitle())));
        fields.add(new InsertParam.Field("content", Collections.singletonList(record.getContent())));
        fields.add(new InsertParam.Field("title_vector", embeddingClient.getEmbedding(record.getTitle())));
        UpsertParam insertParam = UpsertParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withFields(fields)
                .build();
        R<MutationResult> response = client.upsert(insertParam);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }
    }

    public List list(){
        List<QueryResultsWrapper.RowRecord> rowRecords = new ArrayList<>();
        QuerySimpleParam querySimpleParam = QuerySimpleParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withOutputFields(Lists.newArrayList("*"))
                .withFilter("id > 0")
                .withLimit(100L)
                .withOffset(0L)
                .build();
        R<QueryResponse> response = client.query(querySimpleParam);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }

        for (QueryResultsWrapper.RowRecord rowRecord : response.getData().getRowRecords()) {
            System.out.println(rowRecord);
            rowRecords.add(rowRecord);
        }
        return rowRecords;
    }

    public void delete(Integer id) {
        List<Integer> ids = Lists.newArrayList(id);
        DeleteIdsParam param = DeleteIdsParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withPrimaryIds(ids)
                .build();

        R<DeleteResponse> response = client.delete(param);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }

        for (Object deleteId : response.getData().getDeleteIds()) {
            System.out.println(deleteId);
        }
    }


    public List search(String keyword, Integer topK) {
        List<FaqRecordMilvusVO> result = new ArrayList<>();
        List<List<Float>> targetVectors = embeddingClient.getEmbedding(keyword);
        SearchParam param = SearchParam.newBuilder()
                .withCollectionName(COLLECTION_NAME)
                .withMetricType(MetricType.L2)
                .withTopK(topK)
                .withVectors(targetVectors)
                .withVectorFieldName(VECTOR_FIELD)
                .withConsistencyLevel(ConsistencyLevelEnum.EVENTUALLY)
                .withOutFields(Arrays.asList(ID_FIELD, TITLE, CONTENT ))
                .withParams("{"nprobe":10}")
                .build();
        R<SearchResults> response = client.search(param);
        if (response.getStatus() != R.Status.Success.getCode()) {
            System.out.println(response.getMessage());
        }

        SearchResultsWrapper wrapper = new SearchResultsWrapper(response.getData().getResults());
        System.out.println("Search results:");
        for (int i = 0; i < targetVectors.size();   i) {
            List<SearchResultsWrapper.IDScore> scores = wrapper.getIDScore(i);
            for (SearchResultsWrapper.IDScore score:scores) {
                System.out.println(score);
                Float scoreValue = score.getScore();
                Map<String, Object> fieldValues = score.getFieldValues();
                FaqRecordMilvusVO vo = JSON.parseObject(JSON.toJSONString(fieldValues), FaqRecordMilvusVO.class);
                vo.setScore(scoreValue);
                result.add(vo);
            }
        }
        return result;
    }
}

这里有两个类,FaqRecordMilvusVO , FaqRecord 是我自己定义的两个pojo,方便存数据的,你可能没用,那就用你自己的。上述所有代码都经过测试过。 这里唯一提醒的就是我的向量模型返回的是List<List<Float>> 这里传向量的时候,注意点,得嵌套一层。

最后再把我做初始化的代码也发一下吧,初始话的目的就是,我项目启动的时候,判断有没有Collection,没有的话自己创建一个,并完成加载。 还有就是还可以创建Database, Partition,但是都不是必须得,不干这两件事没事,会用默认的,所以我这里没有。

项目初始化代码如下:

代码语言:javascript复制
@Component
@Slf4j
public class MilvusInitRunner implements ApplicationRunner {

    @Resource
    private MilvusClient milvusClient;

    @Override
    public void run(ApplicationArguments args) throws Exception {
        log.info("Milvus init start...");
        Boolean exitCollection = milvusClient.isExitCollection();
        log.info("Milvus check collection [FAQ] exits Result: "   exitCollection);
        if (!exitCollection) {
            log.info("Milvus create collection [FAQ] start...");
            R<RpcStatus> createResult = milvusClient.createCollection();
            log.info("Milvus create collection [FAQ] result: {}", createResult);

            log.info("Milvus create index start...");
            R<RpcStatus> index = milvusClient.createIndex();
            log.info("Milvus create index result: {}", index);

            log.info("Milvus load collection start...");
            R<RpcStatus> loadResult = milvusClient.loadCollection();
            log.info("Milvus load collection result: {}", loadResult);

        }
        log.info("Milvus init end...");
    }
}

好了,写完了,再会!

0 人点赞