训练带交叉验证的高斯核支持向量机，以选择使错误概率最小的超参数。

2022-05-28 15:41:05 浏览数 (1)

Q2_final.m

代码语言:matlab复制
%% Take Home Exam 4: Question 2
% Anja Deric | April 13, 2020
%
% Trains a Gaussian (RBF) kernel SVM with 10-fold cross-validation to
% select the box constraint C and kernel scale sigma that minimize the
% estimated probability of error, then evaluates the selected model on
% an independent test set and plots the classification results.

% Clear all variables and generate new training and testing data
clc; clear;
[train_data,train_labels] =  generateMultiringDataset(2,1000,1);
title('Training Data');
[test_data,test_labels] =  generateMultiringDataset(2,10000,2);
title('Validation Data');

%% Initial Training
% Train a Gaussian kernel SVM with cross-validation to select 
% hyperparameters that minimize probability of error

% Log-spaced hyperparameter grids and storage for CV loss per pair
CList = 10.^linspace(-3,7,11);
sigmaList = 10.^linspace(-2,3,13);
lossVal = zeros(length(sigmaList),length(CList));

% Grid search: cross-validated loss for every (sigma, C) combination
for sigmaCounter = 1:length(sigmaList)
    sigma = sigmaList(sigmaCounter);
    for CCounter = 1:length(CList)
        C = CList(CCounter);
        % Train SVM model for hyperparameter combination
        SVMModel = fitcsvm(train_data',train_labels,'BoxConstraint',C,...
            'KernelFunction','rbf','KernelScale',sigma);
        % Get cross-validated model (10-fold cross validation)
        CVSVMModel = crossval(SVMModel);
        % Store CV loss (estimate of probability of error) for this pair
        lossVal(sigmaCounter,CCounter) = kfoldLoss(CVSVMModel); 
    end
end

%% Final Training

% Find best sigma and C combination. Use min + ind2sub so that a tie in
% the minimum loss still yields a single (first) grid cell; the original
% find(lossVal == minLoss) could return index VECTORS on ties, which
% would break the fitcsvm call below.
[minLoss, minIdx] = min(lossVal(:));
[sigmaBest_ind, CBest_ind] = ind2sub(size(lossVal), minIdx);

% Train final SVM on the full training set and predict test labels
best_SVMModel = fitcsvm(train_data',train_labels,'BoxConstraint',...
    CList(CBest_ind),'KernelFunction','rbf','KernelScale',...
    sigmaList(sigmaBest_ind));
prediction = best_SVMModel.predict(test_data');

% Calculate accuracy of model (percent of test samples classified correctly;
% length() on the 2xN data matrix returns N, the sample count)
incorrect = sum(prediction' ~= test_labels);
accuracy = ((length(test_data) - incorrect)/length(test_data))*100

%% Plot Classified Labels

% Find and plot all correctly and incorrectly classified test samples
indINCORRECT = find(prediction' ~= test_labels); % incorrectly classified
indCORRECT = find(prediction' == test_labels);   % correctly classified
plot(test_data(1,indCORRECT),test_data(2,indCORRECT),'g.'); hold on;
plot(test_data(1,indINCORRECT),test_data(2,indINCORRECT),'r.'); axis equal;
title('Test Data (RED: Incorrectly Classified)');
xlabel('X_1'); ylabel('X_2');

%% Functions

function [data,labels] = generateMultiringDataset(numberOfClasses,numberOfSamples,fig)
    % Generates N samples from C ring-shaped class-conditional pdfs with
    % equal priors, and scatter-plots them in figure number fig.
    %
    % Inputs:
    %   numberOfClasses - number of classes C (one ring per class)
    %   numberOfSamples - total number of samples N to generate
    %   fig             - figure number to draw the scatter plot into
    % Outputs:
    %   data   - 2xN matrix of Cartesian sample coordinates [x1; x2]
    %   labels - 1xN vector of class labels in 1..C

    % Randomly determine class labels for each sample:
    % split [0,1] into C equal-length intervals (equal priors).
    % NOTE: original code had 'numberOfClasses 1' / 'thr(l 1)' with the
    % '+' dropped (copy/paste garbling), which is a syntax error in MATLAB.
    thr = linspace(0,1,numberOfClasses+1);
    u = rand(1,numberOfSamples); % generate N samples uniformly random in [0,1]
    labels = zeros(1,numberOfSamples);
    for l = 1:numberOfClasses
        ind_l = find(thr(l)<u & u<=thr(l+1));
        labels(ind_l) = repmat(l,1,length(ind_l));
    end

    % Gamma pdf parameters: shape grows cubically with class index so
    % higher classes form rings of larger mean radius; scale fixed at 2
    a = [1:numberOfClasses].^3; b = repmat(2,1,numberOfClasses);
    % Generate data from appropriate rings
    % radius is drawn from Gamma(a,b), angle is uniform in [0,2pi]
    angle = 2*pi*rand(1,numberOfSamples);
    radius = zeros(1,numberOfSamples); % reserve space
    for l = 1:numberOfClasses
        ind_l = find(labels==l);
        radius(ind_l) = gamrnd(a(l),b(l),1,length(ind_l));
    end

    % Convert polar (radius, angle) samples to Cartesian coordinates
    data = [radius.*cos(angle);radius.*sin(angle)];

    if 1
        colors = rand(numberOfClasses,3);
        figure(fig); clf;
        for l = 1:numberOfClasses
            ind_l = find(labels==l);
            % Use 'Color' rather than 'MarkerFaceColor': '.' markers are
            % drawn with the line/edge color, so MarkerFaceColor had no
            % effect and every class got the default color cycle instead
            plot(data(1,ind_l),data(2,ind_l),'.','Color',colors(l,:)); 
            axis equal; hold on;
            xlabel('X_1'); ylabel('X_2');
        end
    end
end

0 人点赞