Environment Setup
On Manjaro Linux, first start the Docker service with the command below; once it is running, you can check its state with status:
[dechin-manjaro mindspore]# systemctl start docker
[dechin-manjaro mindspore]# systemctl status docker
● docker.service - Docker Application Container Engine
Loaded: loaded (/usr/lib/systemd/system/docker.service; disabled; vendor preset: disabled)
Active: active (running) since Wed 2021-04-14 16:32:38 CST; 9s ago
TriggeredBy: ● docker.socket
Docs: https://docs.docker.com
Main PID: 298485 (dockerd)
Tasks: 99 (limit: 47875)
Memory: 186.0M
CGroup: /system.slice/docker.service
├─298485 /usr/bin/dockerd -H fd://
└─298496 containerd --config /var/run/docker/containerd/containerd.toml --log-level info
After pulling the MindSpore container image by following the method in this earlier blog post, we can find it in the local image list:
[dechin-root mindspore]# docker images
REPOSITORY TAG IMAGE ID
swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4
The container can be started with a command like the following:
[dechin-root mindspore]# docker run -it 98a3
root@2a6c33894e53:~# python
Python 3.7.5 (default, Feb 8 2021, 02:21:05)
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>>
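The command above passes only 98a3, the first four characters of the image ID 98a3f041e3d4; Docker accepts a prefix as long as it identifies exactly one image. As a rough illustration of that kind of lookup (plain Python, not Docker's actual code; the function name resolve_prefix is made up for this example):

```python
def resolve_prefix(prefix, known_ids):
    """Return the single ID that starts with `prefix`; raise if none or several match."""
    matches = [full_id for full_id in known_ids if full_id.startswith(prefix)]
    if len(matches) != 1:
        raise ValueError(f"{prefix!r} matches {len(matches)} IDs, expected exactly 1")
    return matches[0]

# Image IDs taken from the `docker images` listings in this post
image_ids = ["98a3f041e3d4", "3a6951d9b900"]
print(resolve_prefix("98a3", image_ids))
```

An empty or ambiguous prefix raises an error in this sketch, which is why a few leading characters are usually enough but a single character may not be.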
Here we can see that this container image comes with Python 3.7.5 and MindSpore preinstalled, which can be verified from the Python prompt as follows:
>>> from mindspore import context
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(20:139876984823936,MainProcess):2021-04-14-08:37:40.331.840 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
>>> context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
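The same kind of check can be scripted for several packages at once. The snippet below is a small sketch that uses only the standard library; the helper name missing_modules is our own, and the module list is simply the set of packages mentioned in this post:

```python
import importlib.util

def missing_modules(names):
    """Return the subset of `names` that cannot be found in the current environment."""
    return [name for name in names if importlib.util.find_spec(name) is None]

# Top-level module names used in this post; adjust the list as needed.
print(missing_modules(["mindspore", "matplotlib", "IPython"]))
```

An empty list means every listed module can be imported; anything that is printed still has to be installed.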
Besides MindSpore itself, we often need third-party libraries such as matplotlib, which we can install ourselves:
root@2a6c33894e53:~# python -m pip install matplotlib
Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting matplotlib
Downloading http://mirrors.aliyun.com/pypi/packages/ce/63/74c0b6184b6b169b121bb72458818ee60a7d7c436d7b1907bd5874188c55/matplotlib-3.4.1-cp37-cp37m-manylinux1_x86_64.whl (10.3MB)
|████████████████████████████████| 10.3MB 4.4MB/s
Collecting cycler>=0.10 (from matplotlib)
Downloading http://mirrors.aliyun.com/pypi/packages/f7/d2/e07d3ebb2bd7af696440ce7e754c59dd546ffe1bbe732c8ab68b9c834e61/cycler-0.10.0-py2.py3-none-any.whl
Collecting kiwisolver>=1.0.1 (from matplotlib)
Downloading http://mirrors.aliyun.com/pypi/packages/d2/46/231de802ade4225b76b96cffe419cf3ce52bbe92e3b092cf12db7d11c207/kiwisolver-1.3.1-cp37-cp37m-manylinux1_x86_64.whl (1.1MB)
|████████████████████████████████| 1.1MB 13.9MB/s
Collecting python-dateutil>=2.7 (from matplotlib)
Downloading http://mirrors.aliyun.com/pypi/packages/d4/70/d60450c3dd48ef87586924207ae8907090de0b306af2bce5d134d78615cb/python_dateutil-2.8.1-py2.py3-none-any.whl (227kB)
|████████████████████████████████| 235kB 4.6MB/s
Requirement already satisfied: pyparsing>=2.2.1 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (2.4.7)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (8.1.0)
Requirement already satisfied: numpy>=1.16 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from matplotlib) (1.17.5)
Requirement already satisfied: six in /usr/local/python-3.7.5/lib/python3.7/site-packages (from cycler>=0.10->matplotlib) (1.15.0)
Installing collected packages: cycler, kiwisolver, python-dateutil, matplotlib
Successfully installed cycler-0.10.0 kiwisolver-1.3.1 matplotlib-3.4.1 python-dateutil-2.8.1
WARNING: You are using pip version 19.2.3, however version 21.0.1 is available.
You should consider upgrading via the 'pip install --upgrade pip' command.
root@2a6c33894e53:~# python -m pip install --upgrade pip
Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting pip
Downloading http://mirrors.aliyun.com/pypi/packages/fe/ef/60d7ba03b5c442309ef42e7d69959f73aacccd0d86008362a681c4698e83/pip-21.0.1-py3-none-any.whl (1.5MB)
|████████████████████████████████| 1.5MB 1.3MB/s
Installing collected packages: pip
Found existing installation: pip 19.2.3
Uninstalling pip-19.2.3:
Successfully uninstalled pip-19.2.3
Successfully installed pip-21.0.1
In the same way, we also install IPython:
root@b8955ba28950:/home# python -m pip install IPython
Looking in indexes: http://mirrors.aliyun.com/pypi/simple/
Collecting IPython
Downloading http://mirrors.aliyun.com/pypi/packages/c9/b1/82cbe2b856386f44f37fdae54d9b425813bd86fe33385c9d658d64826098/ipython-7.22.0-py3-none-any.whl (785 kB)
|████████████████████████████████| 785 kB 1.8 MB/s
Collecting prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0
Downloading http://mirrors.aliyun.com/pypi/packages/eb/e6/4b4ca4fa94462d4560ba2f4e62e62108ab07be2e16a92e594e43b12d3300/prompt_toolkit-3.0.18-py3-none-any.whl (367 kB)
|████████████████████████████████| 367 kB 818 kB/s
Collecting pickleshare
Downloading http://mirrors.aliyun.com/pypi/packages/9a/41/220f49aaea88bc6fa6cba8d05ecf24676326156c23b991e80b3f2fc24c77/pickleshare-0.7.5-py2.py3-none-any.whl (6.9 kB)
Collecting pygments
Downloading http://mirrors.aliyun.com/pypi/packages/3a/80/a52c0a7c5939737c6dca75a831e89658ecb6f590fb7752ac777d221937b9/Pygments-2.8.1-py3-none-any.whl (983 kB)
|████████████████████████████████| 983 kB 2.7 MB/s
Requirement already satisfied: decorator in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (4.4.2)
Collecting traitlets>=4.2
Downloading http://mirrors.aliyun.com/pypi/packages/f6/7d/3ecb0ebd0ce8dcdfa7bd47ab85c1d4a521e6770ef283d0824f5804994dfe/traitlets-5.0.5-py3-none-any.whl (100 kB)
|████████████████████████████████| 100 kB 4.0 MB/s
Collecting pexpect>4.3
Downloading http://mirrors.aliyun.com/pypi/packages/39/7b/88dbb785881c28a102619d46423cb853b46dbccc70d3ac362d99773a78ce/pexpect-4.8.0-py2.py3-none-any.whl (59 kB)
|████████████████████████████████| 59 kB 5.9 MB/s
Collecting jedi>=0.16
Downloading http://mirrors.aliyun.com/pypi/packages/f9/36/7aa67ae2663025b49e8426ead0bad983fee1b73f472536e9790655da0277/jedi-0.18.0-py2.py3-none-any.whl (1.4 MB)
|████████████████████████████████| 1.4 MB 3.7 MB/s
Collecting backcall
Downloading http://mirrors.aliyun.com/pypi/packages/4c/1c/ff6546b6c12603d8dd1070aa3c3d273ad4c07f5771689a7b69a550e8c951/backcall-0.2.0-py2.py3-none-any.whl (11 kB)
Requirement already satisfied: setuptools>=18.5 in /usr/local/python-3.7.5/lib/python3.7/site-packages (from IPython) (41.2.0)
Collecting parso<0.9.0,>=0.8.0
Downloading http://mirrors.aliyun.com/pypi/packages/a9/c4/d5476373088c120ffed82f34c74b266ccae31a68d665b837354d4d8dc8be/parso-0.8.2-py2.py3-none-any.whl (94 kB)
|████████████████████████████████| 94 kB 6.0 MB/s
Collecting ptyprocess>=0.5
Downloading http://mirrors.aliyun.com/pypi/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl (13 kB)
Collecting wcwidth
Downloading http://mirrors.aliyun.com/pypi/packages/59/7c/e39aca596badaf1b78e8f547c807b04dae603a433d3e7a7e04d67f2ef3e5/wcwidth-0.2.5-py2.py3-none-any.whl (30 kB)
Collecting ipython-genutils
Downloading http://mirrors.aliyun.com/pypi/packages/fa/bc/9bd3b5c2b4774d5f33b2d544f1460be9df7df2fe42f352135381c347c69a/ipython_genutils-0.2.0-py2.py3-none-any.whl (26 kB)
Installing collected packages: wcwidth, ptyprocess, parso, ipython-genutils, traitlets, pygments, prompt-toolkit, pickleshare, pexpect, jedi, backcall, IPython
WARNING: The script pygmentize is installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
WARNING: The scripts iptest, iptest3, ipython and ipython3 are installed in '/usr/local/python-3.7.5/bin' which is not on PATH.
Consider adding this directory to PATH or, if you prefer to suppress this warning, use --no-warn-script-location.
Successfully installed IPython-7.22.0 backcall-0.2.0 ipython-genutils-0.2.0 jedi-0.18.0 parso-0.8.2 pexpect-4.8.0 pickleshare-0.7.5 prompt-toolkit-3.0.18 ptyprocess-0.7.0 pygments-2.8.1 traitlets-5.0.5 wcwidth-0.2.5
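Neither installation ran into any dependency problems. Next, we can save these installed libraries in a Docker image so that they do not have to be installed again next time. After leaving the current container with exit, use docker ps to list the most recent containers: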
[dechin-root mindspore]# docker ps -n 3
CONTAINER ID IMAGE COMMAND CREATED STATUS PORTS NAMES
2a6c33894e53 98a3 "/bin/bash" 13 minutes ago Exited (0) 7 seconds ago upbeat_tharp
625ee5f4ee95 ea1c "bash" 9 days ago Exited (0) 9 days ago zealous_mccarthy
ded2cb29290a kivy/buildozer "buildozer bash -c '…" 9 days ago Exited (1) 9 days ago exciting_lumiere
The first entry is the MindSpore container that we want to save, so we use docker commit to store its changes in a new image:
[dechin-root mindspore]# docker commit 2a6c mindspore
sha256:3a6951d9b9009f93027748ecec78078efff1fb36599a5786bcbc667e72119392
The output above indicates that the command succeeded. Listing the local images again:
[dechin-root mindspore]# docker images
REPOSITORY TAG IMAGE ID CREATED SIZE
mindspore latest 3a6951d9b900 31 seconds ago 1.22GB
swr.cn-south-1.myhuaweicloud.com/mindspore/mindspore-cpu 1.1.1 98a3f041e3d4 2 months ago 1.18GB
Our base image is now ready; it takes up about 40 MB more than the original image. To close this section, note that the base system of the image provided by MindSpore is Ubuntu 18.04:
root@b8955ba28950:/home# cat /etc/issue
Ubuntu 18.04.5 LTS \n \l
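Fitting a Linear Function with MindSpore
Suppose we have a series of scattered points, shown as red dots in the figure below; these can be regarded as the data we want to train on. The green line in the figure is the true function: we start from a real linear function and generate a series of points with random noise added. The goal, of course, is to recover the linear function from these points so that it can be used to predict the function value at the next position. Applied to quantitative finance, the same idea could be used to predict the next stock price, although the function involved there would be far more complex.
[Figure: noisy sample points (red) and the true function (green)]
For the function in the figure, we use:
$$f(x) = 2x + 3$$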
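Generating the Scattered Dataset
The noise is added in the get_data function. The dataset is generated as follows: first draw a series of random values of the independent variable $x$ in the range $[-10, 10]$, then draw a series of normally distributed random numbers as noise, and add this noise to the corresponding function values $f(x)$ to obtain the raw data. Note that the function does not hand the data back with return; it produces the samples one at a time with yield.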
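To make the yield behavior concrete, here is a minimal sketch of the same idea using only the standard library instead of NumPy. The name get_data_sketch and the seed parameter are our own additions for reproducibility and are not part of the original script:

```python
import random

def get_data_sketch(num, w=2.0, b=3.0, seed=None):
    """Yield `num` noisy samples of y = w * x + b, one (x, y) pair at a time."""
    rng = random.Random(seed)
    for _ in range(num):
        x = rng.uniform(-10.0, 10.0)   # random point in [-10, 10]
        noise = rng.gauss(0.0, 1.0)    # standard normal noise
        yield x, w * x + b + noise

gen = get_data_sketch(3, seed=0)
first = next(gen)   # samples are produced lazily, only when requested
rest = list(gen)    # consume the two remaining samples
print(first, len(rest))
```

Because the function is a generator, nothing is computed until a sample is requested, which is why the script wraps the call in list(...) when it needs all the points at once.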
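The second step is to convert the dataset into a format that MindSpore recognizes, mindspore.dataset.GeneratorDataset. Besides assigning a column name to each of $x$ and $y$, we can also set the batch size (batch) and the number of repetitions (repeat) of the dataset; the batch size can affect the final training speed.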
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds

def get_data(num, w=2.0, b=3.0):
    # Yield (x, y) pairs sampled from y = w * x + b plus Gaussian noise
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))  # generate 50 random points with noise
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3  # expected function values
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    # Wrap the generated samples in a MindSpore dataset, then batch and repeat it
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1

ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
print("The dataset size of ds_train:", ds_train.get_dataset_size())
dict_datasets = next(ds_train.create_dict_iterator())

print(dict_datasets.keys())
print("The x label value shape:", dict_datasets["data"].shape)
print("The y label value shape:", dict_datasets["label"].shape)
Running the code above gives:
root@b8955ba28950:/home# python test_linear.py
The dataset size of ds_train: 100
dict_keys(['data', 'label'])
The x label value shape: (16, 1)
The y label value shape: (16, 1)
At this point we have constructed a training set of 1600 samples, split into 100 batches of 16 samples each.
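The relationship between the sample count, the batch size, and the reported dataset size can be illustrated without MindSpore. The helper below, batch_and_repeat, is a hypothetical stand-in for the batch and repeat steps and only mimics how they group the data:

```python
def batch_and_repeat(samples, batch_size, repeat_size=1):
    """Split samples into consecutive batches (the last one may be shorter), then repeat."""
    batches = [samples[i:i + batch_size] for i in range(0, len(samples), batch_size)]
    return batches * repeat_size

samples = list(range(1600))
batches = batch_and_repeat(samples, batch_size=16)
print(len(batches), len(batches[0]))  # 100 16
```

With 1600 samples and a batch size of 16 this gives 1600 / 16 = 100 batches, matching the dataset size printed above; a repeat count of 2 would double that to 200.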
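Building the Fitting Model and Initial Parameters
With mindspore.nn.Dense we can construct a linear fitting model:
$$f(x) = wx + b$$
The official documentation describes this layer, a fully connected layer with an optional activation function, as follows:
[Figure: screenshot of the official documentation for mindspore.nn.Dense]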
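The initial values of weight and bias are defined by tensors. The call nn.Dense(1, 1, Normal(0.02), Normal(0.02)) specifies one input channel and one output channel, so the weight is a tensor of shape (1, 1) and the bias is a tensor of shape (1,). Both are filled with random initial values drawn from a normal distribution $N(0, \sigma)$, here with $\sigma = 0.02$. In this example we can print these initial parameters: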
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn

def get_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1

ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())

class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        # One input, one output; weight and bias both start from Normal(0.02)
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x

net = LinearNet()
model_params = net.trainable_params()
# Print the initial weight and bias
for param in model_params:
    print(param, param.asnumpy())
The output is as follows; the weight is a 1×1 array and the bias is a one-element array:
root@b8955ba28950:/home# python test_linear.py
Parameter (name=fc.weight) [[-0.00252427]]
Parameter (name=fc.bias) [0.00694926]
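To see what a layer with one input and one output actually computes, the following sketch mimics it in plain Python. TinyDense is a made-up class for illustration and is not MindSpore's implementation; it only reproduces the shapes printed above and the formula f(x) = wx + b:

```python
import random

class TinyDense:
    """A 1-input, 1-output linear layer: y = w * x + b."""

    def __init__(self, sigma=0.02, seed=None):
        rng = random.Random(seed)
        self.weight = [[rng.gauss(0.0, sigma)]]  # nested list, like shape (1, 1)
        self.bias = [rng.gauss(0.0, sigma)]      # flat list, like shape (1,)

    def __call__(self, x):
        return self.weight[0][0] * x + self.bias[0]

layer = TinyDense(seed=0)
print(layer.weight, layer.bias, layer(1.0))
```

With such small initial values the layer outputs something close to zero for every input, which is why the untrained line is almost flat.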
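Although the code above prints the two parameter values, they are not very intuitive. We can draw the function corresponding to these parameters on the earlier scatter plot to see what it looks like: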
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Tensor

def get_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1

ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())

class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x

net = LinearNet()
model_params = net.trainable_params()

# Line given by the initial (untrained) weight and bias
x_model_label = np.array([-10, 10, 0.1])
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0]
                 + Tensor(model_params[1]).asnumpy()[0])

plt.axis([-10, 10, -20, 25])
plt.scatter(x_eval_label, y_eval_label, color="red", s=5)  # noisy samples
plt.plot(x_model_label, y_model_label, color="blue")       # initial model
plt.plot(x_target_label, y_target_label, color="green")    # true function
plt.savefig('initial.png')
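Running it produces an image named initial.png in the current directory:
[Figure: initial.png, the scatter plot with the initial model line (blue) and the true function (green)]
As we can see, the function given by the current parameters is still far from the one we expect.
Training and Visualization
With the groundwork above in place, we can finally start training. In machine learning, we first need to define a function that measures how good a result is, generally called the loss function. The smaller the loss, the better the result; in our function-fitting problem, that means the better the fit. Here we use the mean squared error (Mean Square Error, MSE).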
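For reference, the mean squared error is commonly written as below. The notation is our own rather than taken from the original post: $N$ is the number of samples, $y_i$ the observed value, and $f(x_i)$ the model's prediction for sample $i$.

```latex
\mathrm{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( f(x_i) - y_i \right)^2
```

For the linear model in this post, $f(x_i) = w x_i + b$, so the loss is a quadratic function of $w$ and $b$.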
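Mean squared error is one of the most commonly used loss functions: a deviation in either direction increases the loss, and because the error is squared, large deviations are penalized heavily. Once the loss function is chosen, we need a forward-propagation network to compute it; here we directly use the interface that MindSpore provides, mindspore.nn.loss.MSELoss:
[Figure: screenshot of the official documentation for mindspore.nn.loss.MSELoss]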
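As a plain-Python illustration of what a mean squared error loss computes on one batch (a sketch only; the function name mse_loss is ours, and MindSpore's own implementation works on tensors):

```python
def mse_loss(predictions, labels):
    """Mean of the squared differences between predictions and labels."""
    if len(predictions) != len(labels):
        raise ValueError("predictions and labels must have the same length")
    return sum((p - t) ** 2 for p, t in zip(predictions, labels)) / len(labels)

# Errors of +1 and -1 contribute equally; an error of 3 contributes 9.
print(mse_loss([1.0, -1.0, 3.0], [0.0, 0.0, 0.0]))
```

The example shows both properties mentioned above: the sign of the deviation does not matter, and the largest deviation dominates the result.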
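After computing the loss for the current parameters, we need to update the parameters and compute the loss for the next set, in order to determine which direction to move in to reach the minimum of the loss. This parameter update is carried out by the backward-propagation network. Gradient descent is one commonly used update algorithm; an earlier blog post covers it in more detail.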
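The basic update rule of gradient descent is usually written as below. The symbols are our own notation: $\theta$ stands for the parameters (here $w$ and $b$), $\eta$ for the learning rate, and $L$ for the loss function.

```latex
\theta_{t+1} = \theta_t - \eta \, \nabla_{\theta} L(\theta_t)
```

Each step moves the parameters a small distance against the gradient, that is, in the direction in which the loss decreases fastest locally.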
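In MindSpore, the optimizer interface used here is mindspore.nn.Momentum:
[Figure: screenshot of the official documentation for mindspore.nn.Momentum]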
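To give a feel for what a momentum optimizer does on this problem, here is a self-contained sketch in plain Python. It uses the classical momentum update (a velocity that accumulates past gradients) with the same learning rate, momentum, and batch size as the script; it is an assumption-laden illustration, not MindSpore's implementation, and the names fit_with_momentum, vw, and vb are ours. The parameters start at zero instead of a random initializer, and the data is regenerated with a fixed seed.

```python
import random

def fit_with_momentum(samples, lr=0.005, momentum=0.9, batch_size=16):
    """Fit y = w * x + b by minimizing the batch MSE with a classical momentum update."""
    w, b = 0.0, 0.0      # start from zero instead of a random initializer
    vw, vb = 0.0, 0.0    # velocities (accumulated gradients) for w and b
    for start in range(0, len(samples), batch_size):
        batch = samples[start:start + batch_size]
        # Gradients of the batch MSE with respect to w and b
        gw = sum(2 * (w * x + b - y) * x for x, y in batch) / len(batch)
        gb = sum(2 * (w * x + b - y) for x, y in batch) / len(batch)
        # The velocity keeps a decaying sum of past gradients ...
        vw = momentum * vw + gw
        vb = momentum * vb + gb
        # ... and the parameters move against it, scaled by the learning rate
        w -= lr * vw
        b -= lr * vb
    return w, b

rng = random.Random(0)
data = []
for _ in range(1600):
    x = rng.uniform(-10.0, 10.0)
    data.append((x, 2.0 * x + 3.0 + rng.gauss(0.0, 1.0)))

w, b = fit_with_momentum(data)
print(w, b)
```

One pass over the 1600 samples corresponds to 100 update steps, the same number of steps as one epoch in the script; the printed values should end up in the neighborhood of 2 and 3.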
Once all of these pieces are defined, mindspore.Model can be used to wrap them up and run the training.
# test_linear.py
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mindspore import dataset as ds
from mindspore.common.initializer import Normal
from mindspore import nn, Tensor, Model
import time
from IPython import display
from mindspore.train.callback import Callback, LossMonitor

def get_data(num, w=2.0, b=3.0):
    for _ in range(num):
        x = np.random.uniform(-10.0, 10.0)
        noise = np.random.normal(0, 1)
        y = x * w + b + noise
        yield np.array([x]).astype(np.float32), np.array([y]).astype(np.float32)

eval_data = list(get_data(50))
x_target_label = np.array([-10, 10, 0.1])
y_target_label = x_target_label * 2 + 3
x_eval_label, y_eval_label = zip(*eval_data)

def create_dataset(num_data, batch_size=16, repeat_size=1):
    input_data = ds.GeneratorDataset(list(get_data(num_data)), column_names=['data', 'label'])
    input_data = input_data.batch(batch_size)
    input_data = input_data.repeat(repeat_size)
    return input_data

data_number = 1600
batch_number = 16
repeat_number = 1

ds_train = create_dataset(data_number, batch_size=batch_number, repeat_size=repeat_number)
dict_datasets = next(ds_train.create_dict_iterator())

class LinearNet(nn.Cell):
    def __init__(self):
        super(LinearNet, self).__init__()
        self.fc = nn.Dense(1, 1, Normal(0.02), Normal(0.02))

    def construct(self, x):
        x = self.fc(x)
        return x

net = LinearNet()
model_params = net.trainable_params()

x_model_label = np.array([-10, 10, 0.1])
y_model_label = (x_model_label * Tensor(model_params[0]).asnumpy()[0][0]
                 + Tensor(model_params[1]).asnumpy()[0])

# Network, loss function and optimizer, wrapped into a Model for training
net = LinearNet()
net_loss = nn.loss.MSELoss()
opt = nn.Momentum(net.trainable_params(), learning_rate=0.005, momentum=0.9)
model = Model(net, net_loss, opt)

fig = plt.figure()
ims = []  # frames of the animation

def plot_model_and_datasets(net, eval_data):
    # Draw the current model line and the target line, and record both as frames
    weight = net.trainable_params()[0]
    bias = net.trainable_params()[1]
    x = np.arange(-10, 10, 0.1)
    y = x * Tensor(weight).asnumpy()[0][0] + Tensor(bias).asnumpy()[0]
    x1, y1 = zip(*eval_data)
    x_target = x
    y_target = x_target * 2 + 3

    plt.axis([-11, 11, -20, 25])
    plt.scatter(x1, y1, color="red", s=5)
    im = plt.plot(x, y, color="blue")
    ims.append(im)
    im1 = plt.plot(x_target, y_target, color="green")
    ims.append(im1)
    time.sleep(0.2)

class ImageShowCallback(Callback):
    def __init__(self, net, eval_data):
        self.net = net
        self.eval_data = eval_data

    def step_end(self, run_context):
        # Called after every training step
        plot_model_and_datasets(self.net, self.eval_data)
        display.clear_output(wait=True)

epoch = 1
imageshow_cb = ImageShowCallback(net, eval_data)
model.train(epoch, ds_train, callbacks=[imageshow_cb], dataset_sink_mode=False)

plot_model_and_datasets(net, eval_data)
for net_param in net.trainable_params():
    print(net_param, net_param.asnumpy())

ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
ani.save('train.gif', writer='pillow')
The output is as follows:
root@b8955ba28950:/home# python test_linear.py
WARNING: 'ControlDepend' is deprecated from version 1.1 and will be removed in a future version, use 'Depend' instead.
[WARNING] ME(444:140374496206976,MainProcess):2021-04-14-09:28:58.738.627 [mindspore/ops/operations/array_ops.py:2302] WARN_DEPRECATED: The usage of Pack is deprecated. Please use Stack.
Parameter (name=fc.weight) [[1.8964282]]
Parameter (name=fc.bias) [3.0266616]
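After the run finishes, an animated image named train.gif is generated in the current directory, showing the whole training process:
[Figure: train.gif, animation of the fitted line during training]
The red points are the sampled data, the green line is the original function, and the blue line is the function being trained; we can see the two lines getting closer and closer. The final fitted function is:
$$y = 1.8964282x + 3.0266616$$
compared with the expected:
$$y = 2x + 3$$
There is still a small gap, and there are many possible reasons for it: it may come from the randomly generated points, or a fit over this range may simply carry an error of this size. We will not go into that here. At this point, we have successfully used MindSpore to complete a function-fitting task.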
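One way to probe the first of these possible causes is to compare against the closed-form least-squares line for a sample of the same kind. The sketch below uses only the standard library; least_squares_line is our own helper, and the data is regenerated with a fixed seed, so these are not the same points as in the run above.

```python
import random

def least_squares_line(samples):
    """Closed-form ordinary least squares fit of y = w * x + b."""
    n = len(samples)
    mean_x = sum(x for x, _ in samples) / n
    mean_y = sum(y for _, y in samples) / n
    cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in samples)
    var_x = sum((x - mean_x) ** 2 for x, _ in samples)
    w = cov_xy / var_x
    return w, mean_y - w * mean_x

rng = random.Random(0)
samples = []
for _ in range(1600):
    x = rng.uniform(-10.0, 10.0)
    samples.append((x, 2.0 * x + 3.0 + rng.gauss(0.0, 1.0)))

print(least_squares_line(samples))
```

If the least-squares line lands much closer to $y = 2x + 3$ than the trained parameters do, that would suggest the remaining gap comes mainly from the single epoch of mini-batch updates rather than from the noise in the points.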
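Drawing Animated Function Plots with Python
In the previous section we fitted a linear function with MindSpore, and the final code already used a method for drawing animations; here we pull it out into its own section. The tool we use is matplotlib.animation. The first step is to create the figure object and a frame list outside the training loop: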
fig = plt.figure()
ims = []
Here ims stores what is drawn in each frame. The second step is to append the plot objects that change during training to ims:
im = plt.plot(x, y, color="blue")
ims.append(im)
im1 = plt.plot(x_target, y_target, color="green")
ims.append(im1)
Finally, build an animation from the figure object fig and the frame collection ims, and save it to a local file:
ani = animation.ArtistAnimation(fig, ims, interval=500, repeat_delay=1000)
ani.save('train.gif', writer='pillow')
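The parameters of the animation.ArtistAnimation interface are as follows:
[Figure: screenshot of the documentation for matplotlib.animation.ArtistAnimation]
Here the interval between frames is set to 500 ms, and repeat_delay=1000 adds a 1000 ms pause before the animation starts over; since repeat defaults to True, the animation keeps looping. The final result was already shown in the previous section, so we do not repeat it here. Note that generating the animation can take quite a while, and that a GIF animation has to be produced and saved through animation: plt.savefig only writes a single static image.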
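A quick estimate shows why the animation is long. The sketch below assumes that every entry of ims is displayed for interval milliseconds and that the training script above ran 100 steps plus the one final call to plot_model_and_datasets, each appending two entries; animation_loop_seconds is a helper name of our own.

```python
def animation_loop_seconds(num_frames, interval_ms, repeat_delay_ms):
    """Approximate length of one loop: each frame lasts interval_ms, then the repeat delay."""
    return (num_frames * interval_ms + repeat_delay_ms) / 1000.0

# 100 training steps plus one final call, two entries appended to ims per call
frames = (100 + 1) * 2
print(frames, animation_loop_seconds(frames, interval_ms=500, repeat_delay_ms=1000))
```

Under these assumptions one loop of the animation has 202 frames and lasts well over a minute; a smaller interval, or appending a frame only every few steps, would shorten it.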
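Summary
Many machine learning algorithms are built on function fitting. Here we considered one of the simplest and most common cases, fitting a linear function, and trained on the data with MindSpore. By constructing a mean squared error loss and combining a forward-propagation network with a backward-propagation network, we largely succeeded in fitting the given linear function. At the end we also introduced matplotlib's animation module for generating animations, which gives a visual view of the whole training process.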