1.准备工作,源码以及预训练文件
源码以及预训练文件比较大 下载地址https://pan.quark.cn/s/aeb85eaf95e2
2.核心代码Main函数
代码语言:java复制public class Main {
public static void main(String[] args) throws IOException {
// 输入的两个句子
String input1 = "一寸光阴一寸金,寸金难买寸光阴。";
String input2 = "光阴似箭";
// 词向量模型文件路径
String modelFile = "baike_26g_news_13g_novel_229g.bin";
// 读取词向量模型文件
InputStream is = ClassLoader.getSystemResourceAsStream(modelFile);
BufferedInputStream bufferedInputStream = new BufferedInputStream(Objects.requireNonNull(is), 1024 * 1024);
Word2Vec word2VecModel = WordVectorSerializer.readBinaryModel(bufferedInputStream, false, true);
// 计算并输出两个句子的相似度
System.out.println(sentenceSimilarity(input1, input2, word2VecModel));
}
3.根据文本内容获取对应的词向量列表
代码语言:java复制 /**
* 根据文本内容获取对应的词向量列表
* @param text 文本内容
* @param model 词向量模型
* @return 词向量列表
*/
private static List<INDArray> getWordVectors(String text, Word2Vec model) {
// 将文本分词
List<String> words = segmentWords(text.toLowerCase(Locale.getDefault()));
// 创建一个列表来存储词向量
List<INDArray> wordVectors = new ArrayList<>(words.size());
for (String word : words) {
if (model.hasWord(word)) {
wordVectors.add(model.getWordVectorMatrix(word));
} else {
// 如果单词不在词汇表中,使用默认向量(这里使用零向量)
int vectorSize = model.getLayerSize(); // 获取词向量的大小
INDArray defaultVector = Nd4j.zeros(1, vectorSize); // 创建零向量
wordVectors.add(defaultVector);
}
}
return wordVectors;
}
4.对句子进行分词处理
代码语言:java复制 /**
* 对句子进行分词处理
* @param sentence 待分词的句子
* @return 分词后的词语列表
*/
private static List<String> segmentWords(String sentence) {
JiebaSegmenter segmenter = new JiebaSegmenter();
return segmenter.sentenceProcess(sentence).stream()
.filter(e -> !" ".equals(e) && !e.isEmpty())
.collect(Collectors.toList());
}
5.计算两个向量的余弦相似度
代码语言:java复制 /**
* 计算两个向量的余弦相似度
* @param vec1 第一个向量
* @param vec2 第二个向量
* @return 余弦相似度值
*/
private static double cosineSimilarity(INDArray vec1, INDArray vec2) {
// 计算两个向量的点积
double dotProduct = vec1.mulRowVector(vec2).sumNumber().doubleValue();
// 计算两个向量的模长
double norm1 = vec1.norm2Number().doubleValue();
double norm2 = vec2.norm2Number().doubleValue();
// 计算余弦相似度
return dotProduct / (norm1 * norm2);
}
6.计算两个句子的相似度
代码语言:java复制 /**
* 计算两个句子的相似度
* @param sentence1 第一个句子
* @param sentence2 第二个句子
* @param model 词向量模型
* @return 句子相似度值
*/
private static double sentenceSimilarity(String sentence1, String sentence2, Word2Vec model) {
List<INDArray> vectors1 = getWordVectors(sentence1, model);
List<INDArray> vectors2 = getWordVectors(sentence2, model);
INDArray avgVector1 = getAverageVector(vectors1, model.getLayerSize());
INDArray avgVector2 = getAverageVector(vectors2, model.getLayerSize());
return cosineSimilarity(avgVector1, avgVector2);
}
7.计算一组向量的平均值向量
代码语言:java复制 /**
* 计算一组向量的平均值向量
* @param vectors 向量列表
* @param modelSize 向量维度大小
* @return 平均向量
*/
private static INDArray getAverageVector(List<INDArray> vectors, int modelSize) {
INDArray sumVector = Nd4j.zeros(1, modelSize); // 创建一个与第一个向量形状相同的零向量
for (INDArray vector : vectors) {
sumVector.addiRowVector(vector); // 使用addi进行原地操作
}
INDArray indArray = sumVector.div(vectors.size());
sumVector.close();
return indArray; // 将总和除以向量数量以获得平均值
}
}
8.依赖
代码语言:xml复制<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<groupId>org.example</groupId>
<artifactId>java-nlp</artifactId>
<version>1.0-SNAPSHOT</version>
<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nlp</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
<dependency>
<groupId>nz.ac.waikato.cms.weka</groupId>
<artifactId>weka-stable</artifactId>
<version>3.8.6</version>
</dependency>
<dependency>
<groupId>com.huaban</groupId>
<artifactId>jieba-analysis</artifactId>
<version>1.0.2</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<version>1.5.6</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native-platform</artifactId>
<version>1.0.0-M2.1</version>
</dependency>
</dependencies>
</project>
9. 预训练文件
运行结果
觉得有用请点赞,有问题请在评论区留言