【Matlab代码】共轭梯度下降和阻尼牛顿下降两种算法解得极值

2022-05-28 15:05:43 浏览数 (1)

代码语言:javascript复制
clc;
clear;
syms x y r;
f = (x y)^2   (x 1)^2   (y 3)^2;
syms x_tmp y_tmp;
f_tmp = (x_tmp y_tmp)^2   (x_tmp 1)^2   (y_tmp 3)^2;
fx = diff(f,x);
fy = diff(f,y);
grad_f1 = [fx,fy]';  
x = -20:0.1:20;
y = -15:0.1:15;
[X,Y] = meshgrid(x,y); 
Z = (X Y).^2   (X 1).^2   (Y 3).^2;
figure(1);
mesh(X,Y,Z);
xlabel('横坐标x'); ylabel('纵坐标y'); zlabel('空间坐标z');
hold on;
x0 = 10; y0 = -1.5;
z0 = (x0 y0)^2   (x0 1)^2   (y0 3)^2;
plot3(x0,y0,z0,'r*')
hold on
acc = 0.0001;     
x = 10; 
y = -1.5;  
k = 1;     
fprintf('共轭梯度下降开始:n');
d = -eval(grad_f1);   
while 1
    grad_f1_down = norm(eval(grad_f1))^2;   
    x_tmp = x   r*d(1);
    y_tmp = y   r*d(2);
    r_result = solve(diff(eval(f_tmp)));
    x = x   r_result*d(1);
    y = y   r_result*d(2);
    grad_f1_up = norm(eval(grad_f1))^2;     
    plot3(x,y,eval(f),'r*');  
    hold on 
    if norm(eval(grad_f1)) <= acc
        fprintf('极值坐标为:(%.5f,%.5f,%.5f)n',x,y,eval(f))
        fprintf('迭代次数:%dn',k)
        break;
    end
    miu = grad_f1_up/grad_f1_down;
    d = -eval(grad_f1)   miu*d;
    k = k 1;
end
hold off;

代码语言:javascript复制
clc;
clear;
syms x y;
f =(x y)^2   (x 1)^2   (y 3)^2;
syms x1 y1 a; 
f1 = (x1 y1)^2   (x1 1)^2   (y1 3)^2;
fx = diff(f,x);
fy = diff(f,y);
fxx = diff(fx,x);
fyy = diff(fy,y);
fxy = diff(fx,y);
fyx = diff(fy,x);
grad_f1 = [fx;fy];    
grad_H2 = [fxx fxy;fyx fyy]; 
x = -20:0.1:20;
y = -15:0.1:15;
[X,Y] = meshgrid(x,y); 
Z = (X Y).^2   (X 1).^2   (Y 3).^2;
figure(1);
mesh(X,Y,Z);
xlabel('横坐标x'); ylabel('纵坐标y'); zlabel('空间坐标z');
hold on;
x0 = 10; y0 = -1.5;
z0 = (x0 y0)^2   (x0 1)^2   (y0 3)^2;
plot3(x0,y0,z0,'r*');
hold on       
acc = 0.00001;  
x = 10; 
y = -1.5;      
k = 0;        
fprintf('阻尼牛顿下降开始:n')
while 1
 ans_tmp = [x;y];
 S = -eval(inv(grad_H2))*eval(grad_f1);
 x1 = x   a*S(1)
 y1 = y   a*S(2)
    if diff(eval(f1),a) == 0
        a_result = 0;
    else
     a_result = solve(diff(eval(f1)));
        fprintf('第%d次迭代,当前阻尼步长为:%.5fn', k, a_result);
    end
    x = x   a_result*S(1);
 y = y   a_result*S(2);
    result_tmp = [x;y];
    acc_tmp = sqrt( (result_tmp(1)-ans_tmp(1))^2   (result_tmp(2) - ans_tmp(2))^2 );
 if acc_tmp < acc
 fprintf('极值坐标为:(%.5f,%.5f,%.5f)n', x, y, eval(f));
        fprintf('迭代次数:%dn',k);
        plot3(x,y,eval(f),'r*');
        hold off;
        break;
    end
    plot3(x,y,eval(f),'r*');
    hold on;
    k = k   1;
    if k >= 100
        fprintf('自动结束!n');
        break;
    end
end

0 人点赞