将Keras深度学习模型部署为Web应用程序

2018-12-18 15:10:44 浏览数 (1)

编译:yxy

出品:ATYUN订阅号

建立一个很酷的机器学习项目确实很不错,但如果你希望其他人能够看到你的作品怎么办呢?当然,你可以将整个项目放在GitHub上,但这只能给程序员看,如果你想给自己家里的老人看呢?GitHub肯定不行,所以我们想要的是将我们的深度学习模型部署成世界上任何人都轻易访问的Web应用程序。

在本文中,我们将看到如何编写一个Web应用程序获取经过训练的RNN,并使用户生成新的专利摘要。这个项目建立在RNN示例项目:详解使用RNN撰写专利摘要文章的基础上,但你不需要知道如何创建RNN。我们现在只将其视为一个黑盒子:我们输入一个起始序列,它输出一个全新的专利摘要,然后将其在浏览器上显示!

http://www.atyun.com/32461.html

一般来说,数据科学家开发模型,前端工程师负责展示。但在这个项目中,我们将不得不同时扮演这两个角色,并深入研究Web开发(尽管几乎的都用Python写)。

这个项目需要结合:

  • Flask:用Python创建一个基本的Web应用程序
  • Keras:部署训练好的RNN
  • 使用Jinja模板库进行模板化
  • 用于编写网页的HTML和CSS

最终我们得到一个Web应用程序,允许用户使用训练好的RNN生成全新的专利摘要:

方法

我们的目标是尽快启动和运行Web应用程序。因此,我选择了Flask,它使我们可以用Python编写应用程序。我不喜欢乱糟糟的样式所以几乎所有的CSS都是复制和粘贴的。

Flask的基本Web应用程序

在Python中构建Web应用程序的最快方法是使用Flask。要制作我们自己的应用,我们可以:

代码语言:javascript复制
from flaskimport Flask
代码语言:javascript复制
app= Flask(__name__)
代码语言:javascript复制
代码语言:javascript复制
@app.route("/")
代码语言:javascript复制
def hello():
代码语言:javascript复制
    return "<h1>Not Much Going On Here</h1>"
代码语言:javascript复制
app.run(host='0.0.0.0', port=50000)

如果你复制粘贴此代码并运行它,将能够在localhost:50000上看到自己的Web应用程序。当然,我们要做的肯定不仅仅是这些,所以我们使用一个稍微复杂的函数,它基本上做同样的事情:处理来自浏览器的请求并将一些内容作为HTML提供。对于我们的主页面,我们希望向用户显示一个表单(Form),使用户可以输入一些详细信息。

用户输入表格

当我们的用户到达应用程序的主页面时,我们将向他们展示一个包含三个参数的表单:

  1. 输入RNN的起始序列或随机选择
  2. 选择RNN预测的多样性
  3. 选择RNN输出的字数

要在Python中构建表单,我们将使用wtforms 。制作表单的代码是:

代码语言:javascript复制
from wtformsimport (Form, TextField, validators, SubmitField,
代码语言:javascript复制
DecimalField, IntegerField)
代码语言:javascript复制
代码语言:javascript复制
class ReusableForm(Form):
代码语言:javascript复制
    """User entry form for entering specifics for generation"""
代码语言:javascript复制
    # Starting seed
代码语言:javascript复制
    seed= TextField("Enter a seed string or 'random':", validators=[
代码语言:javascript复制
                     validators.InputRequired()])
代码语言:javascript复制
    # Diversity of predictions
代码语言:javascript复制
    diversity= DecimalField('Enter diversity:', default=0.8,
代码语言:javascript复制
                             validators=[validators.InputRequired(),
代码语言:javascript复制
                                         validators.NumberRange(min=0.5,max=5.0,
代码语言:javascript复制
                                         message='Diversity must be between 0.5 and 5.')])
代码语言:javascript复制
    # Number of words
代码语言:javascript复制
    words= IntegerField('Enter number of words to generate:',
代码语言:javascript复制
                         default=50, validators=[validators.InputRequired(),
代码语言:javascript复制
                                                 validators.NumberRange(min=10,max=100,
代码语言:javascript复制
                                                 message='Number of words must be between 10 and 100')])
代码语言:javascript复制
    # Submit button
代码语言:javascript复制
    submit= SubmitField("Enter")

这将创建一个如下所示的表单(带有main.css的样式):

代码中的validator确保用户输入正确的信息。例如,我们检查所有输入框已填充且diversity介于0.5和5之间。必须满足这些条件才能接受表单。

验证错误

我们实际使用Flask提供表单服务的方式是使用模板。

模板

模板是一个带有基本框架的文档,我们需要添加详细信息。对于Flask Web应用程序,我们可以使用Jinja模板库将Python代码传递给HTML文档。例如,在我们的main函数中,我们将表单的内容发送到一个名为index.html的模板。

代码语言:javascript复制
from flaskimport render_template
代码语言:javascript复制
代码语言:javascript复制
# Home page
代码语言:javascript复制
@app.route("/", methods=['GET','POST'])
代码语言:javascript复制
def home():
代码语言:javascript复制
    """Home page of app with form"""
代码语言:javascript复制
    # Create form
代码语言:javascript复制
    form= ReusableForm(request.form)
代码语言:javascript复制
代码语言:javascript复制
    # Send template information to index.html
代码语言:javascript复制
    return render_template('index.html', form=form)

当用户到达主页时,我们的应用程序将提供带有form详细信息的index.html息。模板是一个简单的html脚手架,我们用{{variable}} 语法引用python变量。

代码语言:javascript复制
<!DOCTYPE html>
代码语言:javascript复制
<html>
代码语言:javascript复制
代码语言:javascript复制
<head>
代码语言:javascript复制
  <title>RNN Patent Writing</title>
代码语言:javascript复制
  <link rel="stylesheet" href="/static/css/main.css">
代码语言:javascript复制
  <link rel="shortcut icon" href="/static/images/lstm.ico">
代码语言:javascript复制
代码语言:javascript复制
</head>
代码语言:javascript复制
代码语言:javascript复制
<body>
代码语言:javascript复制
  <divclass="container">
代码语言:javascript复制
    <h1>
代码语言:javascript复制
      <center>Writing Novel Patent Abstracts with Recurrent Neural Networks</center>
代码语言:javascript复制
    </h1>
代码语言:javascript复制
代码语言:javascript复制
    {% block content %}
代码语言:javascript复制
    {%for message in form.seed.errors %}
代码语言:javascript复制
    <divclass="flash">{{ message }}</div>
代码语言:javascript复制
    {%endfor %}
代码语言:javascript复制
代码语言:javascript复制
    {%for message in form.diversity.errors %}
代码语言:javascript复制
    <divclass="flash">{{ message }}</div>
代码语言:javascript复制
    {%endfor %}
代码语言:javascript复制
代码语言:javascript复制
    {%for message in form.words.errors %}
代码语言:javascript复制
    <divclass="flash">{{ message }}</div>
代码语言:javascript复制
    {%endfor %}
代码语言:javascript复制
代码语言:javascript复制
    <form method=post>
代码语言:javascript复制
代码语言:javascript复制
      {{ form.seed.label }}
代码语言:javascript复制
      {{ form.seed }}
代码语言:javascript复制
代码语言:javascript复制
      {{ form.diversity.label }}
代码语言:javascript复制
      {{ form.diversity }}
代码语言:javascript复制
代码语言:javascript复制
      {{ form.words.label }}
代码语言:javascript复制
      {{ form.words }}
代码语言:javascript复制
代码语言:javascript复制
      {{ form.submit }}
代码语言:javascript复制
    </form>
代码语言:javascript复制
    {% endblock %}
代码语言:javascript复制
代码语言:javascript复制
  </div>
代码语言:javascript复制
</body>
代码语言:javascript复制
代码语言:javascript复制
</html>

对于表单中的每个错误(那些无法验证的条目),错误将flash。除此之外,此文件将显示上面的表单。

当用户输入信息并点击submit(POST请求)时,如果信息是正确的,我们希望将输入转移到正确的函数以使用经过训练的RNN进行预测。这意味着要修改home() 。

代码语言:javascript复制
from flaskimport request
代码语言:javascript复制
# User defined utility functions
代码语言:javascript复制
from utilsimport generate_random_start, generate_from_seed
代码语言:javascript复制
代码语言:javascript复制
# Home page
代码语言:javascript复制
@app.route("/", methods=['GET','POST'])
代码语言:javascript复制
def home():
代码语言:javascript复制
    """Home page of app with form"""
代码语言:javascript复制
代码语言:javascript复制
    # Create form
代码语言:javascript复制
    form= ReusableForm(request.form)
代码语言:javascript复制
代码语言:javascript复制
    # On form entry and all conditions met
代码语言:javascript复制
    if request.method== 'POST' and form.validate():
代码语言:javascript复制
        # Extract information
代码语言:javascript复制
        seed= request.form['seed']
代码语言:javascript复制
        diversity= float(request.form['diversity'])
代码语言:javascript复制
        words= int(request.form['words'])
代码语言:javascript复制
        # Generate a random sequence
代码语言:javascript复制
        if seed== 'random':
代码语言:javascript复制
            return render_template('random.html',
代码语言:javascript复制
                                   input=generate_random_start(model=model,
代码语言:javascript复制
                                                               graph=graph,
代码语言:javascript复制
                                                               new_words=words,
代码语言:javascript复制
                                                               diversity=diversity))
代码语言:javascript复制
        # Generate starting from a seed sequence
代码语言:javascript复制
        else:
代码语言:javascript复制
            return render_template('seeded.html',
代码语言:javascript复制
                                   input=generate_from_seed(model=model,
代码语言:javascript复制
                                                            graph=graph,
代码语言:javascript复制
                                                            seed=seed,
代码语言:javascript复制
                                                            new_words=words,
代码语言:javascript复制
                                                            diversity=diversity))
代码语言:javascript复制
    # Send template information to index.html
代码语言:javascript复制
    return render_template('index.html', form=form)

现在,当用户点击submit并且信息正确时,输入被发送到generate_random_start或generate_from_seed中(取决于输入)。这些函数使用经过训练的Keras模型生成具有用户指定的diversity和num_words的新专利。这些函数的输出依次被发送到random.html或seeded.html任一模板作为网页。

用预训练的Keras模型进行预测

model参数是经过训练的Keras模型,其加载如下:

代码语言:javascript复制
from keras.modelsimport load_model
代码语言:javascript复制
import tensorflow as tf
代码语言:javascript复制
代码语言:javascript复制
def load_keras_model():
代码语言:javascript复制
    """Load in the pre-trained model"""
代码语言:javascript复制
    global model
代码语言:javascript复制
    model= load_model('../models/train-embeddings-rnn.h5')
代码语言:javascript复制
    # Required for model to work
代码语言:javascript复制
    global graph
代码语言:javascript复制
    graph= tf.get_default_graph()
代码语言:javascript复制
代码语言:javascript复制
load_keras_model()

(这tf.get_default_graph()是基于以下gist的变通方案。)

gist:https://gist.github.com/eyesonlyhack/2f0b20f1e73aaf5e9b83f49415f3601a

在这里我没有展示util函数的全部内容(https://github.com/WillKoehrsen/recurrent-neural-networks/blob/master/deployment/utils.py),你所需要了解的是,他们使用训练过的Keras模型和参数,并对一个新的专利摘要进行预测。

这些函数都返回HTML格式的Python字符串。此字符串将发送到另一个模板以显示为网页。例如,generate_random_start返回格式化的html进入random.html:

代码语言:javascript复制
<!DOCTYPE html>
代码语言:javascript复制
<html>
代码语言:javascript复制
代码语言:javascript复制
<header>
代码语言:javascript复制
    <title>Random Starting Abstract
代码语言:javascript复制
    </title>
代码语言:javascript复制
代码语言:javascript复制
    <link rel="stylesheet" href="/static/css/main.css">
代码语言:javascript复制
    <link rel="shortcut icon" href="/static/images/lstm.ico">
代码语言:javascript复制
    <ul>
代码语言:javascript复制
        <li><a href="/">Home</a></li>
代码语言:javascript复制
    </ul>
代码语言:javascript复制
</header>
代码语言:javascript复制
代码语言:javascript复制
<body>
代码语言:javascript复制
    <divclass="container">
代码语言:javascript复制
        {% block content%}
代码语言:javascript复制
        {{input|safe}}
代码语言:javascript复制
        {% endblock%}
代码语言:javascript复制
    </div>
代码语言:javascript复制
</body>
代码语言:javascript复制
代码语言:javascript复制
</html>

在这里,我们再次使用Jinja模板引擎来显示格式化的HTML。由于Python字符串已经格式化为HTML,我们所要做的就是使用{{input|safe}}(input是Python变量)来显示它。然后我们可以像使用其他html模板一样用main.css设置此页面的样式。

输出

函数generate_random_start选择随机专利摘要作为起始序列,并根据它进行预测。然后显示起始序列,RNN生成的输出和实际输出:

随机启动序列的输出。

函数generate_from_seed采用用户提供的启动序列,然后使用训练好的RNN构建输出。输出显示如下:

从起始种子序列得到的输出

虽然结果并不总是完全正确,但它们确实表明RNN已经掌握了英语的基础知识。它经过训练可以预测前50个单词中的下一个单词,并且已经学会了如何编写一个略有说服力的专利摘要!根据预测的多样性,输出可能完全是随机的或循环的。

运行应用程序

要自己运行应用程序,只需下载存储库,到deployment目录的python run_keras_server.py 。这将立即使web应用程序在localhost:10000上可用。

根据家庭WiFi的配置方式,你应该能够使用你的IP地址从网络上的任何计算机访问该应用程序。

下一步

在个人计算机上运行的Web应用程序非常适合与朋友和家人共享。不过,我绝对不会建议在你的家庭网络中向所有人开放这个网站!此,我们将在AWS EC2实例上设置应用程序,并将其提供给全世界(会在下节提供)。

为了改进应用程序,我们可以改变样式(通过main.css),或许还可以添加更多选项,比如选择预训练好的网络。个人项目的好处是,你可以随心所欲地去做。如果您想玩这个应用程序,请下载代码并开始使用。

结论

在本文中,我们了解了如何将经过训练的Keras深度学习模型部署为Web应用程序。这需要许多不同的技术,包括RNN,Web应用程序,模板,HTML,CSS,当然还有Python。

虽然这只是一个基础的应用程序,但它表明你可以用相对较少的努力开始使用深度学习来构建Web应用程序。没有多少人可以将深度学习模型部署为Web应用程序,但如果如果按本文操作,那么你就可以!

GitHub:https://github.com/WillKoehrsen/recurrent-neural-networks

0 人点赞