matlab倒立摆环境建模

2019-11-05 16:05:07 浏览数 (1)

matlab强化学习工具箱提供了建立环境的模板对象,我们只要

新建模板rlCreateEnvTemplate("CartPoleEnv")

就可以自行建立需要的学习环境,成功建立之后

env = CartPoleEnv;

就成功得到环境变量,配合强化学习agent就可以进行学习训练。

训练之前可以通过

step(env,10);

查看是否正常

下面就是环境对象

classdef CartPoleEnv < rl.env.MATLABEnvironment
    % CARTPOLEENV: cart-pole (inverted pendulum) environment for MATLAB
    % reinforcement learning.
    %
    % The state vector is [x; dx; theta; dtheta]: cart position, cart
    % velocity, pole angle and pole angular velocity. The action is a
    % horizontal force of either -MaxForce or +MaxForce applied to the cart.

    %% Properties
    properties
        % Gravitational acceleration
        Gravity = 9.8

        % Mass of the cart
        CartMass = 1.0

        % Mass of the pole
        PoleMass = 0.1

        % Half of the pole length (distance to the centre of mass)
        HalfPoleLength = 0.5

        % Magnitude of the force applied to the cart
        MaxForce = 10

        % Sample time of one simulation step
        Ts = 0.02

        % Pole angle beyond which the episode terminates
        AngleThreshold = 12 * pi/180

        % Cart displacement beyond which the episode terminates
        DisplacementThreshold = 2.4

        % Reward for every step the pole stays within the limits
        RewardForNotFalling = 1

        % Reward when the state leaves the allowed range
        PenaltyForFalling = -10

        % Handle of the figure used for visualisation
        h

        % Whether the environment should be drawn
        show
    end

    properties
        % Current state [x,dx,theta,dtheta]'
        State = zeros(4,1)
    end

    properties(Access = protected)
        % Termination flag of the most recent step
        IsDone = false
    end

    %% Required methods
    methods
        function this = CartPoleEnv()
            % Constructor: builds the observation/action specifications,
            % initialises the base environment and opens the figure.

            % Observation specification
            ObservationInfo = rlNumericSpec([4 1]);
            ObservationInfo.Name = 'CartPole States';
            ObservationInfo.Description = 'x, dx, theta, dtheta';

            % Action specification (rescaled to +/-MaxForce below)
            ActionInfo = rlFiniteSetSpec([-1 1]);
            ActionInfo.Name = 'CartPole Action';

            % Initialise the base class
            this = this@rl.env.MATLABEnvironment(ObservationInfo,ActionInfo);

            % Finish initialisation
            updateActionInfo(this);
            this.h = figure;
            this.show = 1;
            notifyEnvUpdated(this);
        end

        function [Observation,Reward,IsDone,LoggedSignals] = step(this,Action)
            % Apply one action and advance the simulation by Ts using a
            % forward-Euler integration step.
            %
            % Action must be one of this.ActionInfo.Elements. Returns the
            % next observation, the reward, the termination flag and an
            % empty LoggedSignals.
            LoggedSignals = [];

            % Force applied to the cart
            Force = getForce(this,Action);

            % Unpack the state
            XDot = this.State(2);
            Theta = this.State(3);
            ThetaDot = this.State(4);

            % Cached intermediate values
            CosTheta = cos(Theta);
            SinTheta = sin(Theta);
            SystemMass = this.CartMass + this.PoleMass;
            temp = (Force + this.PoleMass*this.HalfPoleLength * ThetaDot^2 * SinTheta) / SystemMass;

            % Equations of motion
            ThetaDotDot = (this.Gravity * SinTheta - CosTheta* temp) / ...
                (this.HalfPoleLength * (4.0/3.0 - this.PoleMass * CosTheta * CosTheta / SystemMass));
            XDotDot = temp - this.PoleMass*this.HalfPoleLength * ThetaDotDot * CosTheta / SystemMass;

            % Euler integration: state + Ts * d(state)/dt
            Observation = this.State + this.Ts.*[XDot;XDotDot;ThetaDot;ThetaDotDot];
            this.State = Observation;

            % Check the termination condition
            X = Observation(1);
            Theta = Observation(3);
            IsDone = abs(X) > this.DisplacementThreshold || abs(Theta) > this.AngleThreshold;
            this.IsDone = IsDone;

            % Reward for this step
            Reward = getReward(this);

            % Ask the environment to redraw
            notifyEnvUpdated(this);
        end

        function InitialObservation = reset(this)
            % Reset the environment: cart at rest at the origin and the pole
            % at a small random angle in [-0.05, 0.05] rad with zero
            % angular velocity.

            % Initial pole angle
            T0 = 2 * 0.05 * rand - 0.05;
            % Initial pole angular velocity
            Td0 = 0;
            % Initial cart position
            X0 = 0;
            % Initial cart velocity
            Xd0 = 0;

            % Keep the [x; dx; theta; dtheta] layout that step() relies on
            InitialObservation = [X0;Xd0;T0;Td0];
            this.State = InitialObservation;

            % Ask the environment to redraw
            notifyEnvUpdated(this);
        end
    end

    %% Optional helper methods
    methods
        function force = getForce(this,action)
            % Validate the action against the allowed set and return it as
            % the force applied to the cart.
            if ~ismember(action,this.ActionInfo.Elements)
                error('Action must be %g for going left and %g for going right.',-this.MaxForce,this.MaxForce);
            end
            force = action;
        end

        function updateActionInfo(this)
            % Keep the allowed action set in sync with MaxForce.
            this.ActionInfo.Elements = this.MaxForce*[-1 1];
        end

        function Reward = getReward(this)
            % Reward of the most recent step, based on the termination flag.
            if ~this.IsDone
                Reward = this.RewardForNotFalling;
            else
                Reward = this.PenaltyForFalling;
            end
        end

        function plot(this)
            % Open a new figure and draw the current state.
            this.h = figure;
            envUpdatedCallback(this)
        end

        % Property setters: validate the value; State and HalfPoleLength
        % additionally trigger a redraw, MaxForce refreshes the action set.
        function set.State(this,state)
            validateattributes(state,{'numeric'},{'finite','real','vector','numel',4},'','State');
            this.State = double(state(:));
            notifyEnvUpdated(this);
        end

        function set.HalfPoleLength(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','HalfPoleLength');
            this.HalfPoleLength = val;
            notifyEnvUpdated(this);
        end

        function set.Gravity(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Gravity');
            this.Gravity = val;
        end

        function set.CartMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','CartMass');
            this.CartMass = val;
        end

        function set.PoleMass(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','PoleMass');
            this.PoleMass = val;
        end

        function set.MaxForce(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','MaxForce');
            this.MaxForce = val;
            updateActionInfo(this);
        end

        function set.Ts(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','Ts');
            this.Ts = val;
        end

        function set.AngleThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','AngleThreshold');
            this.AngleThreshold = val;
        end

        function set.DisplacementThreshold(this,val)
            validateattributes(val,{'numeric'},{'finite','real','positive','scalar'},'','DisplacementThreshold');
            this.DisplacementThreshold = val;
        end

        function set.RewardForNotFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','RewardForNotFalling');
            this.RewardForNotFalling = val;
        end

        function set.PenaltyForFalling(this,val)
            validateattributes(val,{'numeric'},{'real','finite','scalar'},'','PenaltyForFalling');
            this.PenaltyForFalling = val;
        end
    end

    methods (Access = protected)
        function envUpdatedCallback(this)
            % Redraw the cart and the pole for the current state.

            % Skip drawing when disabled, or when there is no live figure
            % (not created yet, or closed by the user).
            if isempty(this.show) || ~this.show
                return
            end
            if isempty(this.h) || ~isgraphics(this.h,'figure')
                return
            end

            figure(this.h)
            clf

            % Cart position is State(1); State(2) is the cart velocity
            X = this.State(1);
            theta = this.State(3);

            % Draw the cart
            cartpoly = polyshape([-0.25 -0.25 0.25 0.25],[-0.125 0.125 0.125 -0.125]);
            cartpoly = translate(cartpoly,[X 0]);
            plot(cartpoly,'FaceColor',[0.8500 0.3250 0.0980])
            hold on

            % Draw the pole, rotated about its pivot on the cart
            L = this.HalfPoleLength*2;
            polepoly = polyshape([-0.1 -0.1 0.1 0.1],[0 L L 0]);
            polepoly = translate(polepoly,[X,0]);
            polepoly = rotate(polepoly,rad2deg(theta),[X,0]);
            plot(polepoly,'FaceColor',[0 0.4470 0.7410])
            hold off

            xlim([-3 3])
            ylim([-1 2])
        end
    end
end

0 人点赞