
1. 环境准备和数据加载
% Put the Bayes Net Toolbox folder 'bnt' and all of its subfolders on the path.
addpath(genpath('bnt'));

% Problem dimensions.
n_samples  = 1000;
n_features = 3;
n_classes  = 3;

% Synthetic design matrix: independent standard-normal features.
X = randn(n_samples, n_features);

% Noisy linear score per sample (noise standard deviation 0.5) ...
true_weights = [2, -1, 0.5];
y_scores = X * true_weights.' + 0.5 * randn(n_samples, 1);

% ... cut into three ordered classes: (-Inf,-0.5) -> 1, [-0.5,0.5) -> 2,
% [0.5,Inf] -> 3.
y = discretize(y_scores, [-Inf, -0.5, 0.5, Inf]);

% Sequential hold-out split: the first 70% of the rows train the model and
% the remaining rows are kept for testing.
train_ratio = 0.7;
n_train = floor(train_ratio * n_samples);

X_train = X(1:n_train, :);
X_test  = X(n_train+1:end, :);
y_train = y(1:n_train);
y_test  = y(n_train+1:end);
2. 构建贝叶斯网络结构
function dag = create_classification_bnet(n_features, n_classes)
% CREATE_CLASSIFICATION_BNET  Adjacency matrix for a class -> features network.
%
%   dag = create_classification_bnet(n_features, n_classes) returns an
%   (n_features+1)-by-(n_features+1) matrix of zeros and ones. Node 1 is the
%   class node; nodes 2..n_features+1 are the feature nodes. dag(1, j) = 1
%   for every feature node j, and every other entry is 0.
%
%   n_classes is kept in the signature for existing callers; it does not
%   influence the returned structure.
%
%   Side effect: prints a one-line message with the number of features.

    n_nodes = n_features + 1;
    class_node = 1;

    dag = zeros(n_nodes, n_nodes);
    % One edge from the class node to each feature node (an empty range when
    % n_features is 0, which leaves the matrix all zeros).
    dag(class_node, 2:n_nodes) = 1;

    fprintf('构建了包含%d个特征的分类网络\n', n_features);
end
3. 参数学习和模型训练
function bnet = train_bayesian_classifier(X_train, y_train, n_classes)
% TRAIN_BAYESIAN_CLASSIFIER  Fit a discrete class -> features Bayesian network.
%
%   bnet = train_bayesian_classifier(X_train, y_train, n_classes)
%
%   Inputs
%     X_train   n_samples-by-n_features numeric matrix.
%     y_train   n_samples class labels, used directly as states of node 1
%               (so they are expected to be integers in 1..n_classes).
%     n_classes number of states of the class node.
%
%   Output
%     bnet      BNT network whose CPDs were fitted with learn_params. The
%               returned struct additionally carries
%               bnet.feature_edges{i}, the bin edges used to discretise
%               feature i, so that the same binning can be applied to new
%               data at prediction time.
%
%   Each feature is cut into n_bins = 3 equal-width bins spanning the
%   training range of that feature.

    n_bins = 3;
    [n_samples, n_features] = size(X_train);

    dag = create_classification_bnet(n_features, n_classes);

    % Node sizes: class node first, then one n_bins-state node per feature.
    ns = [n_classes, repmat(n_bins, 1, n_features)];
    bnet = mk_bnet(dag, ns);
    for node = 1:(n_features + 1)
        bnet.CPD{node} = tabular_CPD(bnet, node);
    end

    % Fully observed training cases: one row per node, one column per sample.
    data = cell(n_features + 1, n_samples);
    data(1, :) = num2cell(y_train(:)');

    feature_edges = cell(1, n_features);
    for i = 1:n_features
        feature_data = X_train(:, i);
        edges = linspace(min(feature_data), max(feature_data), n_bins + 1);
        % Open the outer bins. Training values get the same bin indices as
        % with the finite edges, but values outside the training range seen
        % later no longer discretise to NaN.
        edges([1 end]) = [-Inf, Inf];
        feature_edges{i} = edges;

        disc_data = discretize(feature_data, edges);
        data(i + 1, :) = num2cell(disc_data');
    end

    bnet = learn_params(bnet, data);

    % Attach the binning after learning so the fitted network and the
    % discretisation it was trained with travel together.
    bnet.feature_edges = feature_edges;

    fprintf('贝叶斯分类器训练完成\n');
end
4. 预测函数
function [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes)
% BAYESIAN_PREDICT  Classify samples with a trained discrete Bayesian network.
%
%   [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes)
%
%   Inputs
%     bnet      trained BNT network; node 1 is the class node and nodes
%               2..n_features+1 are the discretised features.
%     X_test    n_test-by-n_features numeric matrix.
%     n_classes number of class states (width of the probability matrix).
%
%   Outputs
%     predictions    n_test-by-1 vector, index of the most probable class.
%     probabilities  n_test-by-n_classes matrix, posterior marginal of the
%                    class node for each sample.
%
%   Discretisation: if bnet has a field feature_edges with one entry per
%   feature, those edges are used, so bin indices mean the same thing as
%   during training. Otherwise the edges are derived from the min/max of
%   X_test itself (the previous behaviour), which only matches training when
%   both sets cover the same range.

    [n_test, n_features] = size(X_test);

    use_stored_edges = isfield(bnet, 'feature_edges') && ...
        numel(bnet.feature_edges) == n_features;

    % Evidence layout: one column per sample; an empty cell marks a hidden
    % node, so row 1 (the class) is left empty.
    test_data = cell(n_features + 1, n_test);
    test_data(1, :) = {[]};

    for i = 1:n_features
        feature_data = X_test(:, i);
        if use_stored_edges
            edges = bnet.feature_edges{i};
        else
            edges = linspace(min(feature_data), max(feature_data), 4);
        end
        disc_data = discretize(feature_data, edges);
        test_data(i + 1, :) = num2cell(disc_data');
    end

    engine = jtree_inf_engine(bnet);
    predictions = zeros(n_test, 1);
    probabilities = zeros(n_test, n_classes);

    for i = 1:n_test
        evidence = test_data(:, i);
        engine = enter_evidence(engine, evidence);
        marg = marginal_nodes(engine, 1);

        % Store the class marginal as a row, whatever orientation marg.T has.
        class_probs = marg.T(:)';
        probabilities(i, :) = class_probs;
        [~, predictions(i)] = max(class_probs);
    end
end
5. 完整的工作流程
function main_multiple_input_classification()
% MAIN_MULTIPLE_INPUT_CLASSIFICATION  Run the full demo pipeline.
%
%   Generates synthetic data, trains the discrete Bayesian classifier,
%   predicts on the held-out rows, prints evaluation metrics and opens a
%   figure with diagnostic scatter plots.

    fprintf('准备数据...\n');
    [~, y, X_train, y_train, X_test, y_test] = prepare_data();

    fprintf('训练贝叶斯分类器...\n');
    % Labels are 1-based bin indices from discretize, so the class node needs
    % max(y) states. Counting distinct labels would be too small whenever a
    % lower-numbered class happens to be absent from the sample.
    n_classes = max(y);
    bnet = train_bayesian_classifier(X_train, y_train, n_classes);

    fprintf('进行预测...\n');
    [predictions, probabilities] = bayesian_predict(bnet, X_test, n_classes);

    fprintf('评估模型性能...\n');
    evaluate_model(y_test, predictions, probabilities);
    visualize_results(X_test, y_test, predictions, probabilities);
end

function [X, y, X_train, y_train, X_test, y_test] = prepare_data()
% PREPARE_DATA  Build a 2000-by-4 synthetic data set and a 70/30 split.
%
%   Columns of X:
%     1  standard normal
%     2  uniform on (0, 10)
%     3  0.5 * column 1 plus normal noise (std 0.5), i.e. correlated with 1
%     4  integer drawn uniformly from {1, 2, 3}
%
%   y is the class (1, 2 or 3) obtained by cutting a noisy linear score at
%   -2 and 2. The first 70% of the rows form the training set and the rest
%   the test set.

    n_samples = 2000;
    n_features = 4;

    X = zeros(n_samples, n_features);
    X(:, 1) = randn(n_samples, 1);
    X(:, 2) = rand(n_samples, 1) * 10;
    X(:, 3) = 0.5 * X(:, 1) + randn(n_samples, 1) * 0.5;
    X(:, 4) = randi(3, n_samples, 1);

    % Column 4 is centred on its middle value (2) before weighting.
    scores = 2 * X(:, 1) - 1.5 * X(:, 2) + 0.8 * X(:, 3) + 0.5 * (X(:, 4) - 2);
    noise = randn(n_samples, 1) * 0.3;
    y_scores = scores + noise;
    y = discretize(y_scores, [-inf, -2, 2, inf]);

    train_ratio = 0.7;
    n_train = floor(n_samples * train_ratio);
    X_train = X(1:n_train, :);
    y_train = y(1:n_train);
    X_test = X(n_train+1:end, :);
    y_test = y(n_train+1:end);
end

function evaluate_model(y_true, y_pred, probabilities)
% EVALUATE_MODEL  Print accuracy, confusion matrix and confidence summary.
%
%   y_true, y_pred  label vectors of equal length (any orientation).
%   probabilities   one row of class probabilities per sample.
%
%   Prints overall accuracy, the confusion matrix, the accuracy within each
%   true class 1..max(y_true), and the mean of the per-sample maximum
%   probability.

    % Force a common orientation; comparing a row with a column would
    % otherwise expand to a full matrix instead of an element-wise result.
    y_true = y_true(:);
    y_pred = y_pred(:);
    is_hit = (y_true == y_pred);

    accuracy = sum(is_hit) / numel(y_true);
    fprintf('准确率: %.4f\n', accuracy);

    cm = confusionmat(y_true, y_pred);
    fprintf('混淆矩阵:\n');
    disp(cm);

    for i = 1:max(y_true)
        class_idx = (y_true == i);
        % A class with no samples yields 0/0 and is reported as NaN.
        class_acc = sum(is_hit(class_idx)) / sum(class_idx);
        fprintf('类别 %d 准确率: %.4f\n', i, class_acc);
    end

    mean_prob = mean(max(probabilities, [], 2));
    fprintf('平均预测置信度: %.4f\n', mean_prob);
end

function visualize_results(X_test, y_test, predictions, probabilities)
% VISUALIZE_RESULTS  Four scatter plots over the first two features.
%
%   Opens a new figure. When X_test has at least two columns it draws:
%     (1) points grouped by true class,
%     (2) points grouped by predicted class,
%     (3) points grouped by correct / incorrect prediction,
%     (4) points coloured by the maximum class probability.
%   With fewer than two columns the figure is left empty.

    figure;
    if size(X_test, 2) >= 2
        x1 = X_test(:, 1);
        x2 = X_test(:, 2);

        subplot(2, 2, 1);
        gscatter(x1, x2, y_test, 'rgb', 'o');
        title('真实类别');
        xlabel('特征1'); ylabel('特征2');

        subplot(2, 2, 2);
        gscatter(x1, x2, predictions, 'rgb', 'x');
        title('预测类别');
        xlabel('特征1'); ylabel('特征2');

        subplot(2, 2, 3);
        correct = (y_test == predictions);
        % gscatter hands out colours and markers to groups in sorted order,
        % so the "false" (misclassified) group comes first. To match the
        % title (blue = correct, red = wrong) red must be listed first,
        % unless every point is correct and "true" is the only group.
        if all(correct)
            gscatter(x1, x2, correct, 'b', 'o');
        else
            gscatter(x1, x2, correct, 'rb', '*o');
        end
        title('分类结果 (蓝色:正确, 红色:错误)');
        xlabel('特征1'); ylabel('特征2');

        subplot(2, 2, 4);
        confidence = max(probabilities, [], 2);
        scatter(x1, x2, 50, confidence, 'filled');
        colorbar;
        title('预测置信度');
        xlabel('特征1'); ylabel('特征2');
    end
end
6. 高级功能:连续特征处理
function bnet = train_with_continuous_features(X_train, y_train, n_classes)
% TRAIN_WITH_CONTINUOUS_FEATURES  Fit a class -> features network with
% Gaussian feature nodes (no discretisation).
%
%   bnet = train_with_continuous_features(X_train, y_train, n_classes)
%
%   Inputs
%     X_train   n_samples-by-n_features numeric matrix, used as-is.
%     y_train   n_samples class labels, used directly as states of node 1
%               (so they are expected to be integers in 1..n_classes).
%     n_classes number of states of the class node.
%
%   Output
%     bnet      BNT network with a tabular CPD on node 1 and a Gaussian CPD
%               (diagonal covariance) on every feature node, fitted with
%               learn_params.

    [n_samples, n_features] = size(X_train);
    dag = create_classification_bnet(n_features, n_classes);

    discrete_nodes = 1;
    continuous_nodes = 2:(n_features + 1);

    % Node sizes: n_classes states for the class node, dimension 1 for each
    % (scalar) continuous feature node.
    ns = ones(1, n_features + 1);
    ns(discrete_nodes) = n_classes;

    bnet = mk_bnet(dag, ns, 'discrete', discrete_nodes);
    bnet.CPD{1} = tabular_CPD(bnet, 1);
    for i = continuous_nodes
        bnet.CPD{i} = gaussian_CPD(bnet, i, 'cov_type', 'diag');
    end

    % Fully observed cases: one row per node, one column per sample.
    % Paren indexing is required here: a brace assignment to a whole row
    % (data{i+1, :} = ...) is not a valid way to fill several cells.
    data = cell(n_features + 1, n_samples);
    data(1, :) = num2cell(y_train(:)');
    data(2:end, :) = num2cell(X_train');

    bnet = learn_params(bnet, data);
end