之前写过一个教程,教大家如何自己训练出一个文本生成的模型,然后用LightSeq来加速推理: 用了这个技术,我让模型训练和推理快了好几倍
这篇文章是我用AI生成出来的
但是,训练好模型之后,别人如果没有显卡的话,就没法体验到快乐了呀!有一个办法,那就是把模型部署在GPU服务器上,然后别人直接发送请求进行访问就行了。更进一步,还可以做成网页,更方便互动!
申请GPU服务器
这里我发现了一个能白嫖的GPU服务器: https://www.autodl.com
注册送10块钱,只需要申请一个便宜点的GPU就行了,几毛钱每小时。但一定要注意了,如果用的是int8量化模型,一定要选择计算能力>=7.5的显卡!下面的表格可以查看NVIDIA所有显卡的计算能力: https://en.wikipedia.org/wiki/CUDA
服务端部署
按照之前的教程,训练并导出模型之后,就可以用LightSeq进行部署了。
用ssh连接服务器之后,安装一些必要的python库:
代码语言:javascript复制pip3 install lightseq transformers
然后就可以开始部署了,下面是一个简单的服务端代码。这个代码从6006端口接收用户请求,然后转换成id,送给LightSeq推理,最后还原成文本,发送回去。
代码语言:javascript复制import os
import socket
import threading
import time
from transformers import BertTokenizer
import lightseq.inference as lsi
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.bind(('127.0.0.1', 6006))
s.listen(5)
os.system("wget -nc https://zenodo.org/record/7233565/files/aiai97.hdf5")
tokenizer = BertTokenizer.from_pretrained("uer/gpt2-chinese-cluecorpussmall")
model = lsi.QuantGpt("aiai97.hdf5", 16)
def tcplink(sock, addr):
print('Accept new connection from %s:%s...' % addr)
while True:
data = sock.recv(1024)
if not data or len(data.decode('utf-8')) <= 0:
break
time.sleep(1)
print(data.decode('utf-8'))
inputs_ids = tokenizer([data.decode('utf-8')], return_tensors="pt", padding=True)["input_ids"]
ls_res_ids = model.sample(inputs_ids)
ls_res = tokenizer.batch_decode(ls_res_ids, skip_special_tokens=True)
res = ''.join(ls_res[0].split())
print(res)
sock.send(res.encode('utf-8'))
sock.close()
print('Connection from %s:%s closed.' % addr)
while True:
sock, addr = s.accept()
t = threading.Thread(target=tcplink, args=(sock, addr))
t.start()
客户端请求
然后任何人就可以使用起来啦!控制台可以看到公网域名和端口号:
然后在任意电脑上,用下面代码就可以请求啦:
代码语言:javascript复制import socket
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(('region-4.autodl.com', 40977))
s.send(input("请输入句子前缀:n").encode('utf-8'))
print(s.recv(1024).decode('utf-8', 'ignore'))
s.close()
注意这里的域名和端口号改成你自己的,然后运行得到结果:
网页互动
如果你觉得命令行黑漆漆的不方便,那也可以做一个网页,给别人更好的体验!
首先安装gradio库:
代码语言:javascript复制pip3 install gradio
然后客户端代码改成下面这样就行了:
代码语言:javascript复制import socket
import gradio as gr
def predict(text):
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.connect(('region-4.autodl.com', 40977))
s.send(text.encode('utf-8'))
res = s.recv(1024).decode('utf-8', 'ignore')
s.close()
return res
gr.Interface(fn=predict,
inputs=["text"],
outputs=["text"]).launch(share=True)
然后运行就可以看到下面提示:
这里会显示一个内网ip和公网地址,公网地址可以分享给你的小伙伴体验,打开后是这样的: