1.涉及公式
1.1 高斯分布公式
概率密度函数
1.2 二项分布公式
换句话说,一枚公平的硬币有正面结果的概率(正面)p = 0.5。如果你掷硬币 20 次,平均值为 20 * 0.5 = 10;你会期望得到10个正面
1.3 方差
继续以硬币为例,n 是投掷硬币的次数,p 是正面朝上的概率
1.4 标准差
换句话说,标准差是方差的平方根。
1.5 概率密度函数
2.编写高斯类
代码语言:javascript复制import math
import matplotlib.pyplot as plt
class Gaussian():
""" 高斯分布类,用于计算和可视化高斯分布.
Attributes:
均值(float),表示分布的均值。stdev(float)表示分布数据的标准偏差。_ list(float列表):从数据文件中提取的浮点列表
"""
def __init__(self, mu = 0, sigma = 1):
self.mean = mu
self.stdev = sigma
self.data = []
def calculate_mean(self):
"""函数计算数据集的平均值.
Args:
None
Returns:
float: mean of the data set
"""
avg = 1.0 * sum(self.data) / len(self.data)
self.mean = avg
return self.mean
def calculate_stdev(self, sample=True):
"""函数计算数据集的标准偏差.
Args:
sample (bool): 数据是代表样本还是总体
Returns:
float: 数据集的标准偏差
"""
if sample:
n = len(self.data) - 1
else:
n = len(self.data)
mean = self.mean
sigma = 0
for d in self.data:
sigma = (d - mean) ** 2
sigma = math.sqrt(sigma / n)
self.stdev = sigma
return self.stdev
def read_data_file(self, file_name, sample=True):
"""函数从txt文件读入数据。txt文件应该具有每行一个数字(浮动)。这些数字存储在数据属性中。读取文件后,计算平均值和标准偏差
Args:
file_name (string): name of a file to read from
Returns:
None
"""
with open(file_name) as file:
data_list = []
line = file.readline()
while line:
data_list.append(int(line))
line = file.readline()
file.close()
self.data = data_list
self.mean = self.calculate_mean()
self.stdev = self.calculate_stdev(sample)
def plot_histogram(self):
"""函数使用matplotlib pyplot库输出实例变量数据的直方图.
Args:
None
Returns:
None
"""
plt.hist(self.data)
plt.title('Histogram of Data')
plt.xlabel('data')
plt.ylabel('count')
def pdf(self, x):
"""高斯分布的概率密度函数计算器.
Args:
x (float): 计算概率密度函数的点
Returns:
float: 输出的概率密度函数
"""
return (1.0 / (self.stdev * math.sqrt(2*math.pi))) * math.exp(-0.5*((x - self.mean) / self.stdev) ** 2)
def plot_histogram_pdf(self, n_spaces = 50):
"""函数绘制数据的归一化直方图,并沿相同范围绘制概率密度函数
Args:
n_spaces (int): number of data points
Returns:
list: x values for the pdf plot
list: y values for the pdf plot
"""
mu = self.mean
sigma = self.stdev
min_range = min(self.data)
max_range = max(self.data)
# 计算x值之间的间隔
interval = 1.0 * (max_range - min_range) / n_spaces
x = []
y = []
# calculate the x values to visualize
for i in range(n_spaces):
tmp = min_range interval*i
x.append(tmp)
y.append(self.pdf(tmp))
# make the plots
fig, axes = plt.subplots(2,sharex=True)
fig.subplots_adjust(hspace=.5)
axes[0].hist(self.data, density=True)
axes[0].set_title('Normed Histogram of Data')
axes[0].set_ylabel('Density')
axes[1].plot(x, y)
axes[1].set_title('Normal Distribution for n Sample Mean and Sample Standard Deviation')
axes[0].set_ylabel('Density')
plt.show()
return x, y
3.测试高斯类
代码语言:javascript复制import unittest
class TestGaussianClass(unittest.TestCase):
def setUp(self):
self.gaussian = Gaussian(25, 2)
def test_initialization(self):
self.assertEqual(self.gaussian.mean, 25, 'incorrect mean')
self.assertEqual(self.gaussian.stdev, 2, 'incorrect standard deviation')
def test_pdf(self):
self.assertEqual(round(self.gaussian.pdf(25), 5), 0.19947,
'pdf function does not give expected result')
def test_meancalculation(self):
self.gaussian.read_data_file('numbers.txt', True)
self.assertEqual(self.gaussian.calculate_mean(),
sum(self.gaussian.data) / float(len(self.gaussian.data)), 'calculated mean not as expected')
def test_stdevcalculation(self):
self.gaussian.read_data_file('numbers.txt', True)
self.assertEqual(round(self.gaussian.stdev, 2), 92.87, 'sample standard deviation incorrect')
self.gaussian.read_data_file('numbers.txt', False)
self.assertEqual(round(self.gaussian.stdev, 2), 88.55, 'population standard deviation incorrect')
tests = TestGaussianClass()
tests_loaded = unittest.TestLoader().loadTestsFromModule(tests)
unittest.TextTestRunner().run(tests_loaded)