当前位置: 首页 > news >正文

【基于ALS模型的教育视频推荐系统(Java实现)】

【基于ALS模型的教育视频推荐系统(Java实现)】

下面是一个基于交替最小二乘法(ALS)的教育视频推荐系统的完整实现,包含数据模型、模型训练、推荐生成和评估模块。

1. 系统架构

edu-recommender/
├── src/
│   ├── main/
│   │   ├── java/
│   │   │   ├── model/          # 数据模型
│   │   │   ├── algorithm/      # ALS算法实现
│   │   │   ├── service/        # 业务逻辑
│   │   │   ├── util/           # 工具类
│   │   │   └── Main.java       # 入口类
│   │   └── resources/          # 配置文件
├── pom.xml                     # Maven依赖
└── data/                       # 示例数据集

2.1 数据模型类

// Video.java
package model;public class Video {private int id;private String title;private String category;private double duration; // 分钟// 构造函数、getter和setter
}// User.java
package model;public class User {private int id;private String username;private String educationLevel;// 构造函数、getter和setter
}// Rating.java
package model;public class Rating {private int userId;private int videoId;private double score; // 1-5分// 构造函数、getter和setter
}

2.2 ALS算法实现

package algorithm;import java.util.*;public class ALS {private int numFeatures; // 特征维度private double lambda; // 正则化参数private int maxIter; // 最大迭代次数private double[][] userFeatures; // 用户特征矩阵private double[][] itemFeatures; // 物品特征矩阵public ALS(int numFeatures, double lambda, int maxIter) {this.numFeatures = numFeatures;this.lambda = lambda;this.maxIter = maxIter;}// 训练模型public void train(List<Rating> ratings, int numUsers, int numItems) {// 初始化特征矩阵Random rand = new Random();userFeatures = new double[numUsers][numFeatures];itemFeatures = new double[numItems][numFeatures];for (int i = 0; i < numUsers; i++) {for (int j = 0; j < numFeatures; j++) {userFeatures[i][j] = rand.nextDouble();}}for (int i = 0; i < numItems; i++) {for (int j = 0; j < numFeatures; j++) {itemFeatures[i][j] = rand.nextDouble();}}// 交替优化for (int iter = 0; iter < maxIter; iter++) {// 固定物品特征,优化用户特征updateFeatures(ratings, userFeatures, itemFeatures, true);// 固定用户特征,优化物品特征updateFeatures(ratings, itemFeatures, userFeatures, false);double error = calculateRMSE(ratings);System.out.printf("Iteration %d, RMSE: %.4f\n", iter, error);}}// 更新特征矩阵private void updateFeatures(List<Rating> ratings, double[][] mainFeatures, double[][] fixedFeatures, boolean isUser) {Map<Integer, List<Rating>> groupedRatings = groupRatings(ratings, isUser);for (Map.Entry<Integer, List<Rating>> entry : groupedRatings.entrySet()) {int id = entry.getKey();List<Rating> group = entry.getValue();// 构建矩阵A和向量bdouble[][] A = new double[numFeatures][numFeatures];double[] b = new double[numFeatures];for (Rating r : group) {int otherId = isUser ? 
r.getVideoId() : r.getUserId();double[] otherVec = fixedFeatures[otherId];// A += otherVec * otherVec^Tfor (int i = 0; i < numFeatures; i++) {for (int j = 0; j < numFeatures; j++) {A[i][j] += otherVec[i] * otherVec[j];}b[i] += otherVec[i] * r.getScore();}}// 添加正则化项: A += lambda * Ifor (int i = 0; i < numFeatures; i++) {A[i][i] += lambda;}// 解线性方程组: A * x = bdouble[] newFeatures = solveLinearSystem(A, b);System.arraycopy(newFeatures, 0, mainFeatures[id], 0, numFeatures);}}// 分组评分数据private Map<Integer, List<Rating>> groupRatings(List<Rating> ratings, boolean byUser) {Map<Integer, List<Rating>> map = new HashMap<>();for (Rating r : ratings) {int key = byUser ? r.getUserId() : r.getVideoId();map.computeIfAbsent(key, k -> new ArrayList<>()).add(r);}return map;}// 解线性方程组(使用高斯消元法)private double[] solveLinearSystem(double[][] A, double[] b) {int n = b.length;double[] x = new double[n];// 前向消元for (int i = 0; i < n; i++) {// 找主元int maxRow = i;for (int k = i + 1; k < n; k++) {if (Math.abs(A[k][i]) > Math.abs(A[maxRow][i])) {maxRow = k;}}// 交换行double[] tempRow = A[i];A[i] = A[maxRow];A[maxRow] = tempRow;double tempVal = b[i];b[i] = b[maxRow];b[maxRow] = tempVal;// 消元for (int k = i + 1; k < n; k++) {double factor = A[k][i] / A[i][i];b[k] -= factor * b[i];for (int j = i; j < n; j++) {A[k][j] -= factor * A[i][j];}}}// 回代for (int i = n - 1; i >= 0; i--) {double sum = 0;for (int j = i + 1; j < n; j++) {sum += A[i][j] * x[j];}x[i] = (b[i] - sum) / A[i][i];}return x;}// 计算RMSEpublic double calculateRMSE(List<Rating> ratings) {double sumSquaredError = 0;for (Rating r : ratings) {double predicted = predict(r.getUserId(), r.getVideoId());sumSquaredError += Math.pow(predicted - r.getScore(), 2);}return Math.sqrt(sumSquaredError / ratings.size());}// 预测评分public double predict(int userId, int videoId) {double score = 0;for (int i = 0; i < numFeatures; i++) {score += userFeatures[userId][i] * itemFeatures[videoId][i];}return Math.max(1, Math.min(5, score)); // 限制在1-5分}// 为用户推荐视频public 
List<Integer> recommendVideos(int userId, int numRecommendations, int numVideos) {PriorityQueue<VideoScore> pq = new PriorityQueue<>();for (int videoId = 0; videoId < numVideos; videoId++) {double score = predict(userId, videoId);pq.offer(new VideoScore(videoId, score));if (pq.size() > numRecommendations) {pq.poll();}}List<Integer> recommendations = new ArrayList<>();while (!pq.isEmpty()) {recommendations.add(0, pq.poll().videoId);}return recommendations;}// 辅助类private static class VideoScore implements Comparable<VideoScore> {int videoId;double score;VideoScore(int videoId, double score) {this.videoId = videoId;this.score = score;}@Overridepublic int compareTo(VideoScore other) {return Double.compare(this.score, other.score);}}
}

2.3 推荐服务类

package service;import model.*;
import algorithm.ALS;
import java.util.*;public class RecommendationService {private List<User> users;private List<Video> videos;private List<Rating> ratings;private ALS alsModel;public RecommendationService(List<User> users, List<Video> videos, List<Rating> ratings) {this.users = users;this.videos = videos;this.ratings = ratings;}// 训练推荐模型public void trainModel() {int numUsers = users.size();int numVideos = videos.size();alsModel = new ALS(10, 0.01, 20); // 10个特征,lambda=0.01,20次迭代alsModel.train(ratings, numUsers, numVideos);}// 为用户生成推荐public List<Video> getRecommendations(int userId, int numRecs) {List<Integer> videoIds = alsModel.recommendVideos(userId, numRecs, videos.size());List<Video> recommendations = new ArrayList<>();for (int videoId : videoIds) {recommendations.add(videos.get(videoId));}return recommendations;}// 评估推荐系统public void evaluate() {// 划分训练集和测试集Collections.shuffle(ratings);int split = (int) (ratings.size() * 0.8);List<Rating> trainSet = ratings.subList(0, split);List<Rating> testSet = ratings.subList(split, ratings.size());// 在训练集上训练ALS tempModel = new ALS(10, 0.01, 20);tempModel.train(trainSet, users.size(), videos.size());// 在测试集上评估double rmse = tempModel.calculateRMSE(testSet);System.out.printf("Test RMSE: %.4f\n", rmse);}
}

2.4 主程序

import model.*;
import service.RecommendationService;
import java.util.*;public class Main {public static void main(String[] args) {// 1. 准备示例数据List<User> users = createUsers();List<Video> videos = createVideos();List<Rating> ratings = createRatings(users, videos);// 2. 创建推荐服务RecommendationService service = new RecommendationService(users, videos, ratings);// 3. 训练模型service.trainModel();// 4. 生成推荐int targetUserId = 0; // 为第一个用户推荐List<Video> recommendations = service.getRecommendations(targetUserId, 5);System.out.println("为用户 " + users.get(targetUserId).getUsername() + " 推荐的视频:");for (Video video : recommendations) {System.out.println("- " + video.getTitle() + " (" + video.getCategory() + ")");}// 5. 评估模型service.evaluate();}private static List<User> createUsers() {List<User> users = new ArrayList<>();users.add(new User(0, "张三", "本科"));users.add(new User(1, "李四", "硕士"));users.add(new User(2, "王五", "博士"));return users;}private static List<Video> createVideos() {List<Video> videos = new ArrayList<>();videos.add(new Video(0, "Java入门", "编程", 120));videos.add(new Video(1, "机器学习基础", "AI", 180));videos.add(new Video(2, "高等数学", "数学", 240));videos.add(new Video(3, "英语写作", "语言", 90));videos.add(new Video(4, "数据结构", "编程", 150));videos.add(new Video(5, "深度学习", "AI", 210));videos.add(new Video(6, "线性代数", "数学", 160));videos.add(new Video(7, "商务英语", "语言", 95));return videos;}private static List<Rating> createRatings(List<User> users, List<Video> videos) {List<Rating> ratings = new ArrayList<>();Random rand = new Random();// 为每个用户随机评分一些视频for (User user : users) {int numRatings = 3 + rand.nextInt(3); // 每个用户3-5个评分Set<Integer> ratedVideos = new HashSet<>();for (int i = 0; i < numRatings; i++) {int videoId;do {videoId = rand.nextInt(videos.size());} while (ratedVideos.contains(videoId));ratedVideos.add(videoId);double score = 1 + rand.nextInt(5); // 1-5分ratings.add(new Rating(user.getId(), videoId, score));}}return ratings;}
}

3. 系统优化建议

冷启动问题解决方案:

// Hybrid recommendation methods to add to RecommendationService.
// Collectors is written fully qualified because RecommendationService imports only java.util.*,
// which does not include java.util.stream.

/**
 * Returns recommendations for {@code userId}. Users without any rating (cold start) get
 * content-based recommendations; everyone else gets the ALS-based {@code getRecommendations}.
 */
public List<Video> getHybridRecommendations(int userId, int numRecs) {
    if (isNewUser(userId)) {
        return getContentBasedRecommendations(userId, numRecs);
    }
    return getRecommendations(userId, numRecs);
}

/** Returns true when no rating in {@code ratings} belongs to {@code userId}. */
private boolean isNewUser(int userId) {
    return ratings.stream().noneMatch(r -> r.getUserId() == userId);
}

/**
 * Returns up to {@code numRecs} videos from the category mapped to the user's education level,
 * longest first. It assumes user ids equal positions in {@code users}.
 */
private List<Video> getContentBasedRecommendations(int userId, int numRecs) {
    User user = users.get(userId);
    // Computed once instead of once per video inside the filter.
    String preferredCategory = getPreferredCategory(user);
    return videos.stream()
            .filter(v -> preferredCategory.equals(v.getCategory()))
            .sorted(Comparator.comparingDouble(Video::getDuration).reversed())
            .limit(Math.max(0, numRecs)) // Stream.limit rejects negative sizes
            .collect(java.util.stream.Collectors.toList());
}

/** Maps an education level to a preferred video category; unknown or missing levels map to "语言". */
private String getPreferredCategory(User user) {
    String level = user.getEducationLevel();
    if (level == null) {
        return "语言"; // switch on a null String would throw NullPointerException
    }
    switch (level) {
        case "本科": return "编程";
        case "硕士": return "AI";
        case "博士": return "数学";
        default: return "语言";
    }
}

实时更新模型:

// Incremental-update methods to add to the ALS class.
// NOTE(review): the original snippet read a `ratings` field that the ALS class shown earlier does
// not declare, so it is declared here. train() must also record its input in this list (e.g.
// `this.ratings.clear(); this.ratings.addAll(ratings);`). Otherwise the re-solves below only see
// ratings added through updateModel().
private final List<Rating> ratings = new ArrayList<>();

/**
 * Folds one new rating into a trained model. The rating is appended to {@code ratings}, and the
 * regularized least-squares problem is re-solved for the affected user and video only. All other
 * factor rows stay unchanged.
 *
 * @throws IllegalStateException    if the model has not been trained
 * @throws IllegalArgumentException if an id lies outside the trained factor matrices
 */
public void updateModel(Rating newRating) {
    int userId = newRating.getUserId();
    int videoId = newRating.getVideoId();
    if (userFeatures == null || itemFeatures == null) {
        throw new IllegalStateException("ALS model has not been trained; call train() first");
    }
    if (userId < 0 || userId >= userFeatures.length
            || videoId < 0 || videoId >= itemFeatures.length) {
        throw new IllegalArgumentException("rating (user " + userId + ", video " + videoId
                + ") is outside the trained model");
    }
    ratings.add(newRating);
    updateUserFeatures(userId);
    updateVideoFeatures(videoId);
}

/** Re-solves {@code (sum v v^T + lambda I) u = sum score * v} over this user's ratings. */
private void updateUserFeatures(int userId) {
    List<Rating> userRatings = new ArrayList<>();
    for (Rating r : ratings) {
        if (r.getUserId() == userId) {
            userRatings.add(r);
        }
    }
    userFeatures[userId] = solveRegularized(userRatings, itemFeatures, true);
}

/** Re-solves {@code (sum u u^T + lambda I) v = sum score * u} over this video's ratings. */
private void updateVideoFeatures(int videoId) {
    List<Rating> videoRatings = new ArrayList<>();
    for (Rating r : ratings) {
        if (r.getVideoId() == videoId) {
            videoRatings.add(r);
        }
    }
    itemFeatures[videoId] = solveRegularized(videoRatings, userFeatures, false);
}

/**
 * Builds and solves the regularized normal equations for a single factor row.
 *
 * @param group         the ratings that touch this row
 * @param fixedFeatures the factor matrix held constant (video factors when {@code byUser})
 * @param byUser        true when solving for a user row, false for a video row
 */
private double[] solveRegularized(List<Rating> group, double[][] fixedFeatures, boolean byUser) {
    double[][] A = new double[numFeatures][numFeatures];
    double[] b = new double[numFeatures];
    for (Rating r : group) {
        double[] other = fixedFeatures[byUser ? r.getVideoId() : r.getUserId()];
        for (int i = 0; i < numFeatures; i++) {
            for (int j = 0; j < numFeatures; j++) {
                A[i][j] += other[i] * other[j];
            }
            b[i] += other[i] * r.getScore();
        }
    }
    for (int i = 0; i < numFeatures; i++) {
        A[i][i] += lambda;
    }
    return solveLinearSystem(A, b);
}

性能优化:

// Replace the hand-written solver with a matrix library (Apache Commons Math 3)
import org.apache.commons.math3.linear.*;
// Drop-in replacement for ALS.solveLinearSystem

/**
 * Solves {@code A x = b} through an LU decomposition from Apache Commons Math.
 *
 * @param A square coefficient matrix
 * @param b right-hand side
 * @return the solution vector {@code x}
 */
private double[] solveLinearSystem(double[][] A, double[] b) {
    RealMatrix coefficients = MatrixUtils.createRealMatrix(A);
    RealVector rhs = MatrixUtils.createRealVector(b);
    LUDecomposition lu = new LUDecomposition(coefficients);
    DecompositionSolver luSolver = lu.getSolver();
    RealVector solution = luSolver.solve(rhs);
    return solution.toArray();
}
评估指标扩展:
// 在RecommendationService中添加更多评估指标
public void fullEvaluation() {// 1. 划分训练测试集Collections.shuffle(ratings);int split = (int) (ratings.size() * 0.8);List<Rating> trainSet = ratings.subList(0, split);List<Rating> testSet = ratings.subList(split, ratings.size());// 2. 训练模型ALS tempModel = new ALS(10, 0.01, 20);tempModel.train(trainSet, users.size(), videos.size());// 3. 计算各项指标double rmse = calculateRMSE(tempModel, testSet);double precision = calculatePrecision(tempModel, testSet);double recall = calculateRecall(tempModel, testSet);System.out.println("=== 评估结果 ===");System.out.printf("RMSE: %.4f\n", rmse);System.out.printf("Precision@5: %.4f\n", precision);System.out.printf("Recall@5: %.4f\n", recall);
}private double calculateRMSE(ALS model, List<Rating> testRatings) {return model.calculateRMSE(testRatings);
}private double calculatePrecision(ALS model, List<Rating> testRatings) {int hits = 0;int total = 0;for (User user : users) {// 获取用户实际高评分视频(4分以上)Set<Integer> actualHighRated = testRatings.stream().filter(r -> r.getUserId() == user.getId() && r.getScore() >= 4).map(Rating::getVideoId).collect(Collectors.toSet());if (!actualHighRated.isEmpty()) {// 获取推荐视频List<Integer> recommended = model.recommendVideos(user.getId(), 5, videos.size());// 计算命中数for (int videoId : recommended) {if (actualHighRated.contains(videoId)) {hits++;}}total += recommended.size();}}return total > 0 ? (double) hits / total : 0;
}private double calculateRecall(ALS model, List<Rating> testRatings) {int hits = 0;int totalHighRated = 0;for (User user : users) {// 获取用户实际高评分视频(4分以上)Set<Integer> actualHighRated = testRatings.stream().filter(r -> r.getUserId() == user.getId() && r.getScore() >= 4).map(Rating::getVideoId).collect(Collectors.toSet());totalHighRated += actualHighRated.size();if (!actualHighRated.isEmpty()) {// 获取推荐视频List<Integer> recommended = model.recommendVideos(user.getId(), 5, videos.size());// 计算命中数for (int videoId : recommended) {if (actualHighRated.contains(videoId)) {hits++;}}}}return totalHighRated > 0 ? (double) hits / totalHighRated : 0;
}

这个实现提供了基于ALS的教育视频推荐系统完整框架,可以根据实际需求进一步扩展和优化。系统包含核心推荐算法、业务逻辑和评估模块,适合作为学术研究或中小型教育平台的推荐系统基础。

相关文章:

  • hashCode()和equals(),为什么使用Map要重写这两个,为什么重写了hashCode,equals也需要重写
  • csdn博客打赏功能
  • 小刚说C语言刷题—1149 - 回文数个数
  • 什么是IP专线?企业数字化转型的关键网络基础设施
  • 大小端的判断方法
  • cursor对话关键词技巧
  • spring boot3.0自定义校验注解:文章状态校验示例
  • PH热榜 | 2025-05-12
  • 前端vue+elementplus实现上传通用组件
  • SHAP分析!Transformer-GRU组合模型SHAP分析,模型可解释不在发愁!
  • HDFS客户端操作
  • 排查服务器内存空间预警思路
  • AI日报 - 2024年05月13日
  • 航电系统之电传飞行控制系统篇
  • Excel VBA 与 AcroForm 文档级脚本对比
  • MCU开启浮点计算FPU
  • [springboot]SSM日期数据转换易见问题
  • Linux电源管理(五),发热管理(thermal),温度控制
  • C 语 言 - - - 简 易 通 讯 录
  • Python 字符串
  • 习近平出席中国-拉美和加勒比国家共同体论坛第四届部长级会议开幕式
  • “影像上海”中的自媒体影像特展:无论何时,影像都需要空间
  • 习近平将出席中国—拉美和加勒比国家共同体论坛第四届部长级会议开幕式并发表重要讲话
  • 国内大模型人才大战打响!大厂各出奇招
  • “拼好假”的年轻人,今年有哪些旅游新玩法?
  • 央行设立服务消费与养老再贷款,额度5000亿元