Training a Simulink Model with MATLAB Reinforcement Learning

2020-02-11 16:29:11

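Simulink makes it easy to build physical-domain models. The example here is a simple inverted pendulum on a cart (the Simscape cart-pole model rlCartPoleSimscapeModel), and it can be trained with MATLAB's Reinforcement Learning Toolbox just like any other environment. The script below loads the predefined environment, builds a DDPG agent from a critic network and an actor network, trains it, and then simulates the trained agent.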

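%% Load the environment

clc; clear; close all   % replaces ccc, which is not a built-in MATLAB command

mdl = 'rlCartPoleSimscapeModel';
open_system(mdl)

% Predefined Simscape cart-pole environment with a continuous action (the force on the cart)
env = rlPredefinedEnv('CartPoleSimscapeModel-Continuous');
obsInfo = getObservationInfo(env);
numObservations = obsInfo.Dimension(1);
actInfo = getActionInfo(env);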

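%% Sample time and episode length

Ts = 0.02;   % agent sample time (s)
Tf = 25;     % simulation duration per episode (s)
rng(0)       % fix the random seed for reproducibility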

%% Create the DDPG agent

% Critic network Q(s,a): a state path and an action path merged by an addition layer
statePath = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','CriticStateFC1')
    reluLayer('Name','CriticRelu1')
    fullyConnectedLayer(200,'Name','CriticStateFC2')];
actionPath = [
    imageInputLayer([1 1 1],'Normalization','none','Name','action')
    fullyConnectedLayer(200,'Name','CriticActionFC1','BiasLearnRateFactor',0)];
commonPath = [
    additionLayer(2,'Name','add')
    reluLayer('Name','CriticCommonRelu')
    fullyConnectedLayer(1,'Name','CriticOutput')];

criticNetwork = layerGraph(statePath);
criticNetwork = addLayers(criticNetwork,actionPath);
criticNetwork = addLayers(criticNetwork,commonPath);
criticNetwork = connectLayers(criticNetwork,'CriticStateFC2','add/in1');
criticNetwork = connectLayers(criticNetwork,'CriticActionFC1','add/in2');

figure
plot(criticNetwork)   % visualize the critic's layer graph

% rlRepresentation is the R2019b-era API; later releases replace it
% (rlQValueRepresentation in R2020a, rlQValueFunction in R2022a)
criticOptions = rlRepresentationOptions('LearnRate',1e-03,'GradientThreshold',1);
critic = rlRepresentation(criticNetwork,obsInfo,actInfo,...
    'Observation',{'observation'},'Action',{'action'},criticOptions);
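To make the critic's layer graph concrete, here is a minimal sketch of the forward pass it computes, written in base MATLAB with random placeholder weights (not trained values). It assumes 5 observations purely for illustration; the script itself reads the real count from obsInfo. The action path carries no bias term, matching a bias that starts at zero and that BiasLearnRateFactor = 0 keeps from being updated.

```matlab
% Base-MATLAB sketch of the critic forward pass Q(s,a).
% Assumptions: 5 observations and 1 action; all weights are random placeholders.
relu = @(x) max(x, 0);
nObs = 5;                          % assumed; the script uses obsInfo.Dimension(1)
s = randn(nObs, 1);                % one observation
a = randn(1, 1);                   % one action

% State path: FC(128) -> ReLU -> FC(200)
W1 = randn(128, nObs);  b1 = zeros(128, 1);
W2 = randn(200, 128);   b2 = zeros(200, 1);
hState = W2 * relu(W1 * s + b1) + b2;

% Action path: FC(200) with its bias fixed at zero
Wa = randn(200, 1);
hAction = Wa * a;

% Addition layer -> ReLU -> FC(1) gives a scalar Q-value
w3 = randn(1, 200);  b3 = 0;
q = w3 * relu(hState + hAction) + b3;
```

Because the two paths are added element-wise, both must end in layers of the same width (200 here), which is why CriticStateFC2 and CriticActionFC1 share that size.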

% Actor network: maps an observation to a force, bounded by tanh and scaled to the action limit
actorNetwork = [
    imageInputLayer([numObservations 1 1],'Normalization','none','Name','observation')
    fullyConnectedLayer(128,'Name','ActorFC1')
    reluLayer('Name','ActorRelu1')
    fullyConnectedLayer(200,'Name','ActorFC2')
    reluLayer('Name','ActorRelu2')
    fullyConnectedLayer(1,'Name','ActorFC3')
    tanhLayer('Name','ActorTanh1')
    scalingLayer('Name','ActorScaling','Scale',max(actInfo.UpperLimit))];

actorOptions = rlRepresentationOptions('LearnRate',5e-04,'GradientThreshold',1);
% R2019b-era API (rlDeterministicActorRepresentation in R2020a, rlContinuousDeterministicActor in R2022a)
actor = rlRepresentation(actorNetwork,obsInfo,actInfo,...
    'Observation',{'observation'},'Action',{'ActorScaling'},actorOptions);
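The actor's last two layers bound its output: tanh squashes it into (-1, 1) and scalingLayer multiplies it by max(actInfo.UpperLimit), so the commanded force cannot exceed the action limit. Below is a sketch of that forward pass in base MATLAB, with random placeholder weights, an assumed 5 observations, and a hypothetical limit of 15 standing in for max(actInfo.UpperLimit).

```matlab
% Base-MATLAB sketch of the actor forward pass.
% Assumptions: 5 observations; maxForce = 15 is a hypothetical stand-in for max(actInfo.UpperLimit).
relu = @(x) max(x, 0);
nObs = 5;
maxForce = 15;
s = randn(nObs, 1);

W1 = randn(128, nObs);  W2 = randn(200, 128);  W3 = randn(1, 200);
z = W3 * relu(W2 * relu(W1 * s));   % FC -> ReLU -> FC -> ReLU -> FC (biases omitted)
u = maxForce * tanh(z);             % tanhLayer, then scalingLayer with Scale = maxForce
```

However large z gets, u stays within [-maxForce, maxForce], so the network never has to learn the limit itself.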

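% DDPG agent options: soft target updates, a replay buffer of 1e6 experiences, mini-batches of 128
agentOptions = rlDDPGAgentOptions(...
    'SampleTime',Ts,...
    'TargetSmoothFactor',1e-3,...
    'ExperienceBufferLength',1e6,...
    'MiniBatchSize',128);
% Ornstein-Uhlenbeck exploration noise added to the actor's action during training
agentOptions.NoiseOptions.Variance = 0.4;
agentOptions.NoiseOptions.VarianceDecayRate = 1e-5;

agent = rlDDPGAgent(actor,critic,agentOptions);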
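TargetSmoothFactor sets how fast DDPG's target actor and critic track the networks being trained. Assuming the usual soft update θ' ← τθ + (1 − τ)θ' applied once per update, this sketch uses hypothetical, fixed parameter values to show how slowly τ = 1e-3 moves the target.

```matlab
% Soft (Polyak) target update used by DDPG: thetaTarget <- tau*theta + (1 - tau)*thetaTarget
tau = 1e-3;                    % TargetSmoothFactor from agentOptions
theta = [1; -2; 0.5];          % hypothetical trained parameters, held fixed here
thetaTarget = zeros(3, 1);     % hypothetical target parameters, starting from zero
nUpdates = 1000;
for k = 1:nUpdates
    thetaTarget = tau * theta + (1 - tau) * thetaTarget;
end
% With theta fixed, the same result in closed form
closedForm = theta * (1 - (1 - tau)^nUpdates);
disp(thetaTarget ./ theta)     % fraction of the gap covered after nUpdates updates
```

After 1000 updates the target has covered only about 63% of the gap (1 − 0.999^1000 ≈ 0.632). That is the point of soft updates: the targets the critic learns against change slowly.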

%% Training options

maxepisodes = 2000;
maxsteps = ceil(Tf/Ts);   % 1250 steps in a full-length episode

% Stop when the average reward over the last 5 episodes reaches -400;
% save every agent whose single-episode reward reaches -400
trainingOptions = rlTrainingOptions(...
    'MaxEpisodes',maxepisodes,...
    'MaxStepsPerEpisode',maxsteps,...
    'ScoreAveragingWindowLength',5,...
    'Verbose',false,...
    'Plots','training-progress',...
    'StopTrainingCriteria','AverageReward',...
    'StopTrainingValue',-400,...
    'SaveAgentCriteria','EpisodeReward',...
    'SaveAgentValue',-400);
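Exploration noise and training length interact. Assuming the toolbox multiplies the noise variance by (1 − VarianceDecayRate) at every agent step, and that every episode runs the full maxsteps = ceil(Tf/Ts) steps (episodes that end early would stretch these numbers), this base-MATLAB calculation estimates how quickly the initial variance of 0.4 fades.

```matlab
% Assumptions: variance <- variance*(1 - VarianceDecayRate) once per agent step,
% and every episode runs the full maxsteps steps.
Ts = 0.02;  Tf = 25;
maxsteps = ceil(Tf/Ts);                        % 1250 steps per full-length episode
v0 = 0.4;                                      % NoiseOptions.Variance
decay = 1e-5;                                  % NoiseOptions.VarianceDecayRate

stepsToHalf = ceil(log(0.5) / log1p(-decay));  % steps until the variance halves
episodesToHalf = stepsToHalf / maxsteps;       % the same, in full-length episodes
vAfter100 = v0 * (1 - decay)^(100 * maxsteps); % variance after 100 full episodes
fprintf('%d steps (%.1f episodes) to halve; variance after 100 episodes: %.3f\n', ...
    stepsToHalf, episodesToHalf, vAfter100);
```

Under these assumptions the variance halves after 69,315 steps (about 55 full-length episodes) and drops to about 0.115 after 100 episodes, so the noise is largely gone within the first few hundred episodes of the 2000-episode budget.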

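%% Parallel training settings (requires Parallel Computing Toolbox)

% Set these on trainingOptions, the same options object that is passed to train
trainingOptions.UseParallel = true;
trainingOptions.ParallelizationOptions.Mode = "async";
% DDPG uses experience-based parallelization: workers send experiences rather than gradients
trainingOptions.ParallelizationOptions.DataToSendFromWorkers = "Experiences";
% -1: each worker sends its data to the host at the end of every episode
trainingOptions.ParallelizationOptions.StepsUntilDataIsSent = -1;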

%% Train

trainingStats = train(agent,env,trainingOptions);
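The stop and save criteria are easy to misread, so here is a minimal sketch of how they would fire on a sequence of episode rewards. The rewards are made-up numbers, not results from this training run. It assumes the running average covers the last ScoreAveragingWindowLength = 5 episodes (fewer at the very start) and that both criteria trigger when the value equals or exceeds -400.

```matlab
% Hypothetical episode rewards (illustration only)
episodeReward = [-1200 -900 -650 -420 -380 -390 -350 -300 -310 -280];
window = 5;             % ScoreAveragingWindowLength
stopValue = -400;       % StopTrainingValue, compared with the running average
saveValue = -400;       % SaveAgentValue, compared with each episode's own reward

avgReward = zeros(size(episodeReward));
for k = 1:numel(episodeReward)
    first = max(1, k - window + 1);              % last (up to) 5 episodes
    avgReward(k) = mean(episodeReward(first:k));
end

stopEpisode = find(avgReward >= stopValue, 1);   % first episode meeting the stop criterion
savedEpisodes = find(episodeReward(1:stopEpisode) >= saveValue);  % agents saved before stopping
fprintf('stop after episode %d; saved episodes: %s\n', stopEpisode, mat2str(savedEpisodes));
```

With these numbers training would stop after episode 8 (episodes 4 to 8 average -368), while the agents from episodes 5 to 8 would be saved because each of those episodes alone scored at least -400.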

%% Show results

simOptions = rlSimulationOptions('MaxSteps',500);   % up to 500 steps (10 s at Ts = 0.02 s)
experience = sim(env,agent,simOptions);
totalReward = sum(experience.Reward);

% Close the Simulink model when finished (uncomment to use)
% bdclose(mdl)

Download link for the related files:

https://pan.baidu.com/s/1O1O1PaloLpaOFde1PNI_1w

Extraction code: ngou


帮你学MatLab
