讲
MATLAB 强化学习工具箱(Reinforcement Learning Toolbox)提供了创建自定义环境的模板。我们只需运行
rlCreateEnvTemplate("CartPoleEnv")
生成模板文件 CartPoleEnv.m,再按需修改,就可以建立自己需要的学习环境。创建成功之后,执行
env = CartPoleEnv;
即可得到环境对象 env,配合强化学习 agent 就可以进行训练。
训练之前,可以先执行
step(env,10);
检查环境能否正常运行(这里的 10 就是推力 MaxForce,动作只能取 -10 或 10)。
下面是完整的环境类定义:
classdef CartPoleEnv < rl.env.MATLABEnvironment
    % CartPoleEnv: MATLAB cart-pole (inverted pendulum) environment.
    %
    %   State/observation vector: [x; dx; theta; dtheta]
    %     x      - cart position
    %     dx     - cart velocity
    %     theta  - pole angle (rad)
    %     dtheta - pole angular velocity
    %   Action: a horizontal force of -MaxForce or +MaxForce on the cart.
    %   An episode ends when |x| > DisplacementThreshold or
    %   |theta| > AngleThreshold.

    %% Properties
    properties
        % Gravitational acceleration
        Gravity = 9.8
        % Cart mass
        CartMass = 1.0
        % Pole mass
        PoleMass = 0.1
        % Half of the pole length (pivot to the pole's centre of mass)
        HalfPoleLength = 0.5
        % Magnitude of the push force; the action set is MaxForce*[-1 1]
        MaxForce = 10
        % Sample time of one explicit-Euler integration step
        Ts = 0.02
        % Pole angle limit (rad) beyond which the episode terminates
        AngleThreshold = 12 * pi/180
        % Cart position limit beyond which the episode terminates
        DisplacementThreshold = 2.4
        % Reward for each step on which the system stays within the limits
        RewardForNotFalling = 1
        % Reward (penalty) for the step on which a limit is exceeded
        PenaltyForFalling = -10
        % Handle of the figure used for visualization
        h
        % Whether to redraw on every environment update (nonzero = draw)
        show
    end

    properties
        % System state [x; dx; theta; dtheta]
        State = zeros(4,1)
    end

    properties(Access = protected)
        % Termination flag computed by the most recent step
        IsDone = false
    end

    %% Required methods
    methods
        function this = CartPoleEnv()
            % Construct the environment: define observation/action specs,
            % scale the action set to +/-MaxForce and open a figure.

            % Observation: the 4-element state vector
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';

            % Action: placeholder elements, rescaled by updateActionInfo
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';

            % Call the base-class constructor
            this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

            % Initialize the action set and the visualization
            updateActionInfo(this);
            this.h = figure;
            this.show = 1;
            notifyEnvUpdated(this);
        end

        function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
            % Apply one action and advance the dynamics by one sample time.
            %
            %   Action        - one of this.ActionInfo.Elements (+/-MaxForce);
            %                   any other value raises an error (getForce)
            %   Observation   - updated state [x; dx; theta; dtheta]
            %   Reward        - RewardForNotFalling, or PenaltyForFalling
            %                   when the episode terminates on this step
            %   IsDone        - true when |x| or |theta| exceeds its limit
            %   LoggedSignals - unused, always []
            LoggedSignals = [];

            % Force applied to the cart
            Force = getForce(this,Action);

            % Unpack the current state (x itself does not enter the dynamics)
            XDot = this.State(2);
            Theta = this.State(3);
            ThetaDot = this.State(4);

            % Cache frequently used terms
            CosTheta = cos(Theta);
            SinTheta = sin(Theta);
            SystemMass = this.CartMass + this.PoleMass;
            temp = (Force + this.PoleMass*this.HalfPoleLength * ThetaDot^2 * SinTheta) / SystemMass;

            % Cart-pole equations of motion
            ThetaDotDot = (this.Gravity * SinTheta - CosTheta * temp) / ...
                (this.HalfPoleLength * (4.0/3.0 - this.PoleMass * CosTheta * CosTheta / SystemMass));
            XDotDot = temp - this.PoleMass*this.HalfPoleLength * ThetaDotDot * CosTheta / SystemMass;

            % Explicit Euler integration over one sample time
            Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];
            this.State = Observation;

            % Terminate when the cart or the pole leaves its allowed range
            X = Observation(1);
            Theta = Observation(3);
            IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
            this.IsDone = IsDone;

            Reward = getReward(this);

            % Signal the update so envUpdatedCallback can redraw
            notifyEnvUpdated(this);
        end

        function InitialObservation = reset(this)
            % Reset to a random initial pole angle with everything else at rest.

            % Pole angle drawn uniformly from [-0.05, 0.05) rad
            T0 = 2 * 0.05 * rand - 0.05;
            % Pole angular velocity
            Td0 = 0;
            % Cart position
            X0 = 0;
            % Cart velocity
            Xd0 = 0;

            % Order must match State / step indexing: [x; dx; theta; dtheta]
            InitialObservation = [X0;Xd0;T0;Td0];
            this.State = InitialObservation;

            % Signal the update so envUpdatedCallback can redraw
            notifyEnvUpdated(this);
        end
    end

    %% Optional helper methods
    methods
        function force = getForce(this,action)
            % Validate the action and return it as the force on the cart.
            % The action set is MaxForce*[-1 1], so the action already is the force.
            if ~ismember(action,this.ActionInfo.Elements)
                error('Action must be %g for going left and %g for going right.',-this.MaxForce,this.MaxForce);
            end
            force = action;
        end

        function updateActionInfo(this)
            % Rescale the discrete action set to +/-MaxForce.
            this.ActionInfo.Elements = this.MaxForce*[-1 1];
        end

        function Reward = getReward(this)
            % Per-step reward based on the termination flag set by step.
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end

        function plot(this)
            % Open a new figure and draw the current state into it.
            this.h = figure;
            envUpdatedCallback(this)
        end

        % Validating property setters (useful when testing the environment)
        function set.State(this,state)
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end
        function set.HalfPoleLength(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','HalfPoleLength');
            this.HalfPoleLength = val;
            notifyEnvUpdated(this);
        end
        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end
        function set.CartMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','CartMass');
            this.CartMass = val;
        end
        function set.PoleMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','PoleMass');
            this.PoleMass = val;
        end
        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            % Keep the action set consistent with the new force magnitude
            updateActionInfo(this);
        end
        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end
        function set.AngleThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','AngleThreshold');
            this.AngleThreshold = val;
        end
        function set.DisplacementThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','DisplacementThreshold');
            this.DisplacementThreshold = val;
        end
        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end
        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
    end

    methods (Access = protected)
        function envUpdatedCallback(this)
            % Redraw the cart and pole whenever the environment is updated.
            if ~this.show
                return
            end

            % Recreate the figure if it was never created or has been closed,
            % otherwise figure(this.h) would error on a deleted handle.
            if isempty(this.h) || ~isgraphics(this.h)
                this.h = figure;
            end
            figure(this.h)
            clf

            % Cart position and pole angle
            X = this.State(1);
            theta = this.State(3);

            % Cart: a 0.5 x 0.25 rectangle centred at (X, 0)
            cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
            cartpoly = translate(cartpoly,[X 0]);
            plot(cartpoly,'FaceColor',[0.8500 0.3250 0.0980])
            hold on

            % Pole: full length 2*HalfPoleLength, rotated about the pivot (X, 0)
            L = this.HalfPoleLength*2;
            polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
            polepoly = translate(polepoly,[X,0]);
            polepoly = rotate(polepoly,rad2deg(theta),[X,0]);
            plot(polepoly,'FaceColor',[0 0.4470 0.7410])
            hold off

            xlim([-3 3])
            ylim([-1 2])
        end
    end
end