关于数据分析之线性回归

2022-03-11 17:01:57 浏览数 (1)

前文是讲从csv读取到pandas,本文是讲csv读取到numpy数组中的三种方法,当然genfromtext代码量最少,也最友好。

代码语言:javascript复制
import csv
import os
import numpy as np
import matplotlib.pyplot as plt
# genfromtxt方式
pathfilename=os.path.abspath('.') '\02-traindata.csv'
my_data=np.genfromtxt(pathfilename,delimiter=',',skip_header=0)
print(my_data)

# loadtxt方式
f=open(pathfilename,'r')
my_data=np.loadtxt(f,delimiter=',')
print(my_data)

# csv.reader方式
with open(pathfilename,'r') as f:
    data_iter=csv.reader(f,delimiter=',')
    data=[data for data in data_iter]
my_data=np.asarray(data,dtype=float)
print(my_data)

关于线性回归的预测,可以用numpy自带的polyfit,也可以用scikit-learn,一般建议用后者,次数越高训练集准确率越高,但过度拟合会导致测试集的预测准确率降低。

代码语言:javascript复制
# 构造训练数据
x=np.arange(1,10.1)
y=0.9*x np.sin(x)
# 构造测试数据
testx=np.arange(-2,12,0.5)
plt.plot(x,y,'o')
plt.show()
代码语言:javascript复制
# 一次线性回归求解
model=np.polyfit(x,y,deg=1)
# [0.85886294 0.36737264]
testy=np.polyval(model,testx)
plt.plot(x,y,'o',testx,testy,'x')
plt.show()
代码语言:javascript复制
# 二次线性回归求解
model=np.polyfit(x,y,deg=2)
# [0.03238583 0.50261879 1.07986094]
testx=np.arange(-2,12,0.5)
testy=np.polyval(model,testx)
plt.plot(x,y,'o',testx,testy,'x')
plt.show()
代码语言:javascript复制
# 三次线性回归求解
model=np.polyfit(x,y,deg=3)
# [-0.02581586  0.45834759 -1.46196846  3.29486208]
testx=np.arange(-2,12,0.5)
testy=np.polyval(model,testx)
plt.plot(x,y,'o',testx,testy,'x')
plt.show()

0 人点赞