MATLAB 强化学习工具箱(Reinforcement Learning Toolbox)

2019-10-18 14:58:56 浏览数 (1)

新版本的 MATLAB 提供了 Reinforcement Learning Toolbox,可以方便地建立二维基础网格环境,设置起点、目标和障碍,并创建各种 agent 模型。

以下是 Q-learning 训练的简单实现:

%% Q-learning on a 6x6 grid world (Reinforcement Learning Toolbox)
% Builds a 6x6 grid world with a start cell, a goal cell and four obstacle
% cells, trains a tabular Q-learning agent on it, and then simulates the
% trained agent with its path traced on the grid plot.

% `ccc` is not a built-in MATLAB command (it is usually a personal
% shortcut), so the workspace reset is spelled out explicitly.
clear; close all; clc;

%% Build the grid world
GW = createGridWorld(6,6);
GW.CurrentState = '[6,1]';      % start cell
GW.TerminalStates = '[2,5]';    % goal cell
GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% Block moves into obstacle cells
% The spelling "Transtion" is the toolbox's own function name; do not "fix" it.
updateStateTranstionForObstacles(GW)

%% Rewards
% GW.R is indexed (from-state, to-state, action): every move costs -1,
% and any move into the goal state is overwritten with +10.
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% Environment and start state
env = rlMDPEnv(GW);
plot(env)

% Derive the reset index from the start cell instead of hard-coding 6
% (the index of '[6,1]'). It is computed once, outside the anonymous
% function, so every reset returns the start cell.
startIdx = state2idx(GW,GW.CurrentState);
env.ResetFcn = @() startIdx;

%% Q-learning agent and training options
qTable = rlTable(getObservationInfo(env),getActionInfo(env));
% NOTE(review): newer releases deprecate rlRepresentation in favour of
% rlQValueRepresentation / rlQValueFunction - confirm for your release.
tableRep = rlRepresentation(qTable);
tableRep.Options.LearnRate = 1;

agentOpts = rlQAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = .04;   % exploration rate
qAgent = rlQAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
% Stop once the reward averaged over the last 30 episodes reaches the best
% achievable return. The goal [2,5] can only be entered from [1,5] or
% [2,6], each 9 moves from [6,1], so the shortest episode is 10 moves:
% 9 x (-1) + 10 = 1. Any higher threshold can never be met.
% If exploration keeps the average just below 1, training simply runs to
% MaxEpisodes.
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 1;
trainOpts.ScoreAveragingWindowLength = 30;

%% Train
rng(0)   % fixed seed for reproducible training
trainingStats = train(qAgent,env,trainOpts);

%% Show the learned path
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(qAgent,env)

以下是 SARSA 训练的简单实现(网格环境的设置与上面的 Q-learning 部分相同):

%% SARSA on a 6x6 grid world (Reinforcement Learning Toolbox)
% Builds a 6x6 grid world with a start cell, a goal cell and four obstacle
% cells, trains a tabular SARSA agent on it, and then simulates the
% trained agent with its path traced on the grid plot.

% `ccc` is not a built-in MATLAB command (it is usually a personal
% shortcut), so the workspace reset is spelled out explicitly.
clear; close all; clc;

%% Build the grid world
GW = createGridWorld(6,6);
GW.CurrentState = '[6,1]';      % start cell
GW.TerminalStates = '[2,5]';    % goal cell
GW.ObstacleStates = ["[2,3]";"[2,4]";"[3,5]";"[4,5]"];

%% Block moves into obstacle cells
% The spelling "Transtion" is the toolbox's own function name; do not "fix" it.
updateStateTranstionForObstacles(GW)

%% Rewards
% GW.R is indexed (from-state, to-state, action): every move costs -1,
% and any move into the goal state is overwritten with +10.
nS = numel(GW.States);
nA = numel(GW.Actions);
GW.R = -1*ones(nS,nS,nA);
GW.R(:,state2idx(GW,GW.TerminalStates),:) = 10;

%% Environment and start state
env = rlMDPEnv(GW);
plot(env)

% Derive the reset index from the start cell instead of hard-coding 6
% (the index of '[6,1]'). It is computed once, outside the anonymous
% function, so every reset returns the start cell.
startIdx = state2idx(GW,GW.CurrentState);
env.ResetFcn = @() startIdx;

%% SARSA agent and training options
rng(0)   % fixed seed for reproducible training

qTable = rlTable(getObservationInfo(env),getActionInfo(env));
% NOTE(review): newer releases deprecate rlRepresentation in favour of
% rlQValueRepresentation / rlQValueFunction - confirm for your release.
tableRep = rlRepresentation(qTable);
tableRep.Options.LearnRate = 1;

agentOpts = rlSARSAAgentOptions;
agentOpts.EpsilonGreedyExploration.Epsilon = 0.04;  % exploration rate
sarsaAgent = rlSARSAAgent(tableRep,agentOpts);

trainOpts = rlTrainingOptions;
trainOpts.MaxStepsPerEpisode = 50;
trainOpts.MaxEpisodes = 200;
% Stop once the reward averaged over the last 30 episodes reaches the best
% achievable return. The goal [2,5] can only be entered from [1,5] or
% [2,6], each 9 moves from [6,1], so the shortest episode is 10 moves:
% 9 x (-1) + 10 = 1. Any higher threshold can never be met.
% If exploration keeps the average just below 1, training simply runs to
% MaxEpisodes.
trainOpts.StopTrainingCriteria = "AverageReward";
trainOpts.StopTrainingValue = 1;
trainOpts.ScoreAveragingWindowLength = 30;

%% Train
trainingStats = train(sarsaAgent,env,trainOpts);

%% Show the learned path
plot(env)
env.Model.Viewer.ShowTrace = true;
env.Model.Viewer.clearTrace;
sim(sarsaAgent,env)

0 人点赞