java实现Word2Vec计算语义相似度,AI入门,附源码,分步骤详细注释版

2024-08-12 20:53:15 浏览数 (3)

1.准备工作,源码以及预训练文件

源码以及预训练文件比较大 下载地址https://pan.quark.cn/s/aeb85eaf95e2

2.核心代码Main函数

代码语言:java复制
public class Main {

    public static void main(String[] args) throws IOException {

        // 输入的两个句子

        String input1 = "一寸光阴一寸金,寸金难买寸光阴。";

        String input2 = "光阴似箭";

        // 词向量模型文件路径

        String modelFile = "baike_26g_news_13g_novel_229g.bin";

        // 读取词向量模型文件

        InputStream is = ClassLoader.getSystemResourceAsStream(modelFile);

        BufferedInputStream bufferedInputStream = new BufferedInputStream(Objects.requireNonNull(is), 1024 * 1024);

        Word2Vec word2VecModel = WordVectorSerializer.readBinaryModel(bufferedInputStream, false, true);

        // 计算并输出两个句子的相似度

        System.out.println(sentenceSimilarity(input1, input2, word2VecModel));

    }

3.根据文本内容获取对应的词向量列表

代码语言:java复制
    /**

     * 根据文本内容获取对应的词向量列表

     * @param text 文本内容

     * @param model 词向量模型

     * @return 词向量列表

     */

    private static List<INDArray> getWordVectors(String text, Word2Vec model) {

        // 将文本分词

        List<String> words = segmentWords(text.toLowerCase(Locale.getDefault()));

        // 创建一个列表来存储词向量

        List<INDArray> wordVectors = new ArrayList<>(words.size());

        for (String word : words) {

            if (model.hasWord(word)) {

                wordVectors.add(model.getWordVectorMatrix(word));

            } else {

                // 如果单词不在词汇表中,使用默认向量(这里使用零向量)

                int vectorSize = model.getLayerSize(); // 获取词向量的大小

                INDArray defaultVector = Nd4j.zeros(1, vectorSize); // 创建零向量

                wordVectors.add(defaultVector);

            }

        }

        return wordVectors;

    }

4.对句子进行分词处理

代码语言:java复制
    /**

     * 对句子进行分词处理

     * @param sentence 待分词的句子

     * @return 分词后的词语列表

     */

    private static List<String> segmentWords(String sentence) {

        JiebaSegmenter segmenter = new JiebaSegmenter();

        return segmenter.sentenceProcess(sentence).stream()

                .filter(e -> !" ".equals(e) && !e.isEmpty())

                .collect(Collectors.toList());

    }

5.计算两个向量的余弦相似度

代码语言:java复制
    /**

     * 计算两个向量的余弦相似度

     * @param vec1 第一个向量

     * @param vec2 第二个向量

     * @return 余弦相似度值

     */

    private static double cosineSimilarity(INDArray vec1, INDArray vec2) {

        // 计算两个向量的点积

        double dotProduct = vec1.mulRowVector(vec2).sumNumber().doubleValue();

        // 计算两个向量的模长

        double norm1 = vec1.norm2Number().doubleValue();

        double norm2 = vec2.norm2Number().doubleValue();

        // 计算余弦相似度

        return dotProduct / (norm1 * norm2);

    }

6.计算两个句子的相似度

代码语言:java复制
    /**

     * 计算两个句子的相似度

     * @param sentence1 第一个句子

     * @param sentence2 第二个句子

     * @param model 词向量模型

     * @return 句子相似度值

     */

    private static double sentenceSimilarity(String sentence1, String sentence2, Word2Vec model) {

        List<INDArray> vectors1 = getWordVectors(sentence1, model);

        List<INDArray> vectors2 = getWordVectors(sentence2, model);

        INDArray avgVector1 = getAverageVector(vectors1, model.getLayerSize());

        INDArray avgVector2 = getAverageVector(vectors2, model.getLayerSize());

        return cosineSimilarity(avgVector1, avgVector2);

    }

7.计算一组向量的平均值向量

代码语言:java复制
    /**

     * 计算一组向量的平均值向量

     * @param vectors 向量列表

     * @param modelSize 向量维度大小

     * @return 平均向量

     */

    private static INDArray getAverageVector(List<INDArray> vectors, int modelSize) {

        INDArray sumVector = Nd4j.zeros(1, modelSize); // 创建一个与第一个向量形状相同的零向量

        for (INDArray vector : vectors) {

            sumVector.addiRowVector(vector); // 使用addi进行原地操作

        }

        INDArray indArray = sumVector.div(vectors.size());

        sumVector.close();

        return indArray; // 将总和除以向量数量以获得平均值

    }

}

8.依赖

代码语言:xml复制
<?xml version="1.0" encoding="UTF-8"?>

<project xmlns="http://maven.apache.org/POM/4.0.0"

         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"

         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>



    <groupId>org.example</groupId>

    <artifactId>java-nlp</artifactId>

    <version>1.0-SNAPSHOT</version>



    <properties>

        <maven.compiler.source>17</maven.compiler.source>

        <maven.compiler.target>17</maven.compiler.target>

        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>

    </properties>

    <dependencies>

        <dependency>

            <groupId>org.deeplearning4j</groupId>

            <artifactId>deeplearning4j-nlp</artifactId>

            <version>1.0.0-M2.1</version>

        </dependency>

        <dependency>

            <groupId>nz.ac.waikato.cms.weka</groupId>

            <artifactId>weka-stable</artifactId>

            <version>3.8.6</version>

        </dependency>

        <dependency>

            <groupId>com.huaban</groupId>

            <artifactId>jieba-analysis</artifactId>

            <version>1.0.2</version>

        </dependency>

        <dependency>

            <groupId>ch.qos.logback</groupId>

            <artifactId>logback-classic</artifactId>

            <version>1.5.6</version>

        </dependency>

        <dependency>

            <groupId>org.nd4j</groupId>

            <artifactId>nd4j-native-platform</artifactId>

            <version>1.0.0-M2.1</version>

        </dependency>

    </dependencies>



</project>

9. 预训练文件

预训练文件预训练文件

运行结果

运行结果运行结果

觉得有用请点赞,有问题请在评论区留言

0 人点赞