Training a CartPole Reinforcement Learning Model in MATLAB


Reinforcement learning training from MATLAB in a Gym environment.

First, a quick recap of our goal.

This time the environment is CartPole, the classic cart-mounted inverted pendulum. We first build the environment model; the main work is defining the reward.
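The shaping used in the class below combines a position term rp (largest when the cart is centered) with an angle term ra (largest when the pole is upright), plus a flat -10 when the episode terminates. As a quick numerical illustration, with sample state values chosen only for demonstration:

% Illustration of the shaped reward for one sample state
% (x and theta here are hypothetical values, not from a real episode)
x = 0.5;                          % cart position (m)
theta = 0.05;                     % pole angle (rad)
x_threshold = 2.4;                % position limit used by CartPole
theta_threshold = 12*2*pi/360;    % angle limit, 12 degrees in radians
rp = (x_threshold - abs(x))/x_threshold - 0.8;              % ~ -0.01
ra = (theta_threshold - abs(theta))/theta_threshold - 0.5;  % ~  0.26
reward = rp + ra                                            % ~  0.25

Keeping the per-step reward small relative to the -10 termination penalty nudges the agent to avoid falling first, and only then to center the cart.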

classdef CartPoleEnv < rl.env.MATLABEnvironment
    %http://gym.openai.com/envs/CartPole-v1
    %% Properties
    properties
        show = true;
        % Python gym environment object
        p
        % Current state
        State
    end
    properties(Access = protected)
        % Episode-finished flag
        IsDone = false
    end
    %% Required methods
    methods
        % Constructor
        function this = CartPoleEnv()
            % Observation spec: 1x4 state vector
            ObservationInfo = rlNumericSpec([1 4]);
            % Action spec: two discrete actions (1 = push left, 2 = push right)
            ActionInfo = rlFiniteSetSpec(1:2);
            % Call the superclass constructor
            this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);
            % Initialize the state and create the Python gym environment
            this.State = [0 0 0 0];
            this.p = py.gym.make('CartPole-v0');
            this.p.reset();
            notifyEnvUpdated(this);
        end
        % Apply one action and observe the result
        function [Observation,Reward,IsDone,LoggedSignals] = step(this,action)
            LoggedSignals = [];
            % Convert the MATLAB action (1 or 2) to gym's 0-based action
            act = py.int(action-1);
            % Step the gym environment and unpack the returned tuple
            temp = cell(this.p.step(act));
            Observation = double(temp{1,1});
            IsDone = temp{1,3};
            % Shaped reward: penalize distance from center and pole tilt
            x = Observation(1);
            theta = Observation(3);
            x_threshold = 2.4;
            theta_threshold = 12*2*pi/360;
            rp = (x_threshold - abs(x))/x_threshold - 0.8;
            ra = (theta_threshold - abs(theta))/theta_threshold - 0.5;
            Reward = rp + ra;
            % Large penalty when the episode terminates (pole fell or cart out of bounds)
            if IsDone
                Reward = -10;
            end
            this.State = Observation;
            this.IsDone = IsDone;
            notifyEnvUpdated(this);
        end
        % Reset the environment to its initial state
        function InitialObservation = reset(this)
            this.p.reset();
            InitialObservation = [0 0 0 0];
            this.State = InitialObservation;
            notifyEnvUpdated(this);
        end
    end
    %% Optional helper methods, added for convenience
    methods
        % Query the episode-finished flag
        function isDone = is_done(this)
            isDone = this.IsDone;
        end
    end
    methods (Access = protected)
        % Called whenever the environment is updated; renders if enabled
        function envUpdatedCallback(this)
            if this.show
                this.p.render();
            end
        end
    end
end
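Before building the agent, it is worth a quick sanity check that the Python bridge works and that step/reset outputs match the specs. A minimal sketch, assuming gym is importable from the Python interpreter MATLAB is using (see pyenv):

% Sanity-check the custom environment
env = CartPoleEnv;
env.show = false;          % skip rendering during the check
validateEnvironment(env);  % errors if outputs do not match the specs

% Smoke test: a few random actions
obs = reset(env);
for k = 1:5
    [obs,r,done] = step(env,randi(2));   % random action, 1 or 2
    fprintf('step %d: reward %.3f, done %d\n',k,r,done);
    if done, break; end
end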

Next we build the reinforcement learning network model. Unlike the setup in "MATLAB借助openai gym环境训练强化学习模型" (training an RL model from MATLAB via the OpenAI Gym environment), the CartPole environment accepts only 2 discrete actions — push left and push right — and returns a 4-dimensional observation: cart position, cart velocity, pole angle, and pole angular velocity. The network is sized to match these inputs and outputs.

%% Load the environment
clc; clear; close all;   % (the original used the custom shorthand ccc)
env = CartPoleEnv;
% Observation spec
obsInfo = getObservationInfo(env);
% Observation dimension (the state is a 1x4 vector)
numObservations = obsInfo.Dimension(2);
% Action spec
actInfo = getActionInfo(env);
% Action dimension
numActions = actInfo.Dimension(1);
rng(0)
%% Initialize the agent
% Critic network: 4-dimensional state in, one Q-value per discrete action out
dnn = [
    featureInputLayer(obsInfo.Dimension(2),'Normalization','none','Name','state')
    fullyConnectedLayer(24,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(24, 'Name','CriticStateFC2')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(length(actInfo.Elements),'Name','output')];
% figure
% plot(layerGraph(dnn))
criticOpts = rlRepresentationOptions('LearnRate',0.001,'GradientThreshold',1);
critic = rlQValueRepresentation(dnn,obsInfo,actInfo,'Observation',{'state'},criticOpts);
agentOpts = rlDQNAgentOptions(...
    'UseDoubleDQN',false, ...    
    'TargetSmoothFactor',1, ...
    'TargetUpdateFrequency',4, ...   
    'ExperienceBufferLength',100000, ...
    'DiscountFactor',0.99, ...
    'MiniBatchSize',256);
agent = rlDQNAgent(critic,agentOpts);
%% Training options
trainOpts = rlTrainingOptions(...
    'MaxEpisodes',1000, ...
    'MaxStepsPerEpisode',500, ...
    'Verbose',false, ...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',480); 
%% Train
% Turn off rendering to speed up training
env.show = false;
trainingStats = train(agent,env,trainOpts);
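Training stops once the average reward reaches 480 or after 1000 episodes. To inspect the learning curve afterwards, the statistics returned by train can be plotted directly; a minimal sketch, assuming the EpisodeReward and AverageReward fields returned by Reinforcement Learning Toolbox:

% Optional: plot the learning curve from the returned statistics
figure
plot(trainingStats.EpisodeReward)
hold on
plot(trainingStats.AverageReward)
legend('Episode reward','Average reward','Location','southeast')
xlabel('Episode'); ylabel('Reward')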
%% Show the results
% Turn rendering back on to watch the trained agent
env.show = true;
simOptions = rlSimulationOptions('MaxSteps',5000);
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);
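The trained agent is an ordinary MATLAB variable, so it can be saved and reloaded later without retraining (the file name here is arbitrary):

% Optional: persist the trained agent for later reuse
save('cartpole_dqn_agent.mat','agent');
% ...and reload it in a new session:
% s = load('cartpole_dqn_agent.mat');
% agent = s.agent;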
