Python: Example Code for an HTTP API with the Tornado Web Framework (Async + Thread Pool)
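
This post walks through a small HTTP service built with Tornado that wraps Salesforce CodeGen models. The server exposes three routes (/ping, /aix1, /aix2) and offloads the blocking model inference to a ThreadPoolExecutor via run_on_executor, so the IOLoop stays responsive while a completion is generated. The full server code follows.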

2022-09-27

# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
# pip3 install tornado

import tornado.ioloop
import tornado.web
import tornado.gen
from concurrent.futures import ThreadPoolExecutor
from tornado.concurrent import run_on_executor
import json

from aixcoder.aixcode import AIXCode

AIXCode1 = AIXCode('codegen-350M-multi')
AIXCode2 = AIXCode('codegen-2B-multi')


def get_body_json(body):
    body_decode = body.decode()
    body_json = json.loads(body_decode)
    return body_json


class PingHandler(tornado.web.RequestHandler):
    @tornado.gen.coroutine
    def get(self):
        print(f'request:{self.request.full_url()}')
        self.write("Pong!")

    @tornado.gen.coroutine
    def post(self):
        print(f'request:{self.request.full_url()}')
        body_json = get_body_json(self.request.body)
        print(f'request:{body_json}')
        self.write("Pong!")


class AIX1Handler(tornado.web.RequestHandler):
    executor = ThreadPoolExecutor(32)

    @run_on_executor
    def aixcode(self, x):
        return AIXCode1.aixcode(x)

    @tornado.gen.coroutine
    def get(self):
        """get请求"""
        print(f'request:{self.request.full_url()}')
        x = self.get_argument('x')
        y = yield self.aixcode(x)
        self.write(y)

    @tornado.gen.coroutine
    def post(self):
        """POST request"""
        print(f'request:{self.request.full_url()}')
        body_json = get_body_json(self.request.body)
        print(f'request:{body_json}')
        x = body_json.get("x")
        y = yield self.aixcode(x)
        self.write(y)


class AIX2Handler(tornado.web.RequestHandler):
    executor = ThreadPoolExecutor(32)

    @run_on_executor
    def aixcode(self, x):
        return AIXCode2.aixcode(x)

    @tornado.gen.coroutine
    def get(self):
        """get请求"""
        print(f'request:{self.request.full_url()}')
        x = self.get_argument('x')
        y = yield self.aixcode(x)
        self.write(y)

    @tornado.gen.coroutine
    def post(self):
        """POST request"""
        print(f'request:{self.request.full_url()}')
        body_json = get_body_json(self.request.body)
        print(f'request:{body_json}')
        x = body_json.get("x")
        y = yield self.aixcode(x)
        self.write(y)
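
# Note: AIX1Handler and AIX2Handler are identical apart from the model they
# call. A possible refactor (a sketch, not part of the original code) is a
# factory that builds a handler around a given model:
#
#   def make_handler(model):
#       class Handler(AIX1Handler):
#           @run_on_executor
#           def aixcode(self, x):
#               return model.aixcode(x)
#       return Handler
#
# which could then be registered as (r"/aix2", make_handler(AIXCode2)).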


if __name__ == "__main__":
    # register routes
    app = tornado.web.Application([
        (r"/ping", PingHandler),
        (r"/aix1", AIX1Handler),
        (r"/aix2", AIX2Handler),
    ])

    # listen on a port
    port = 8888
    app.listen(port)
    print(f'AIXCoder Started, Listening on Port:{port}')
    # start the application
    tornado.ioloop.IOLoop.instance().start()
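
With the server running, the endpoints can be exercised from a small client. Below is a minimal sketch using only the standard library; the host, port, and example prompt are assumptions matching the code above:

import json
import urllib.parse
import urllib.request

BASE = 'http://localhost:8888'  # assumed host/port from the server above

# health check: GET /ping returns "Pong!"
print(urllib.request.urlopen(f'{BASE}/ping').read().decode())

# completion via GET with a query parameter
qs = urllib.parse.urlencode({'x': 'def fib(n):'})
print(urllib.request.urlopen(f'{BASE}/aix1?{qs}').read().decode())

# completion via POST with a JSON body
req = urllib.request.Request(
    f'{BASE}/aix1',
    data=json.dumps({'x': 'def fib(n):'}).encode(),
    method='POST',
)
print(urllib.request.urlopen(req).read().decode())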

The AIXCode class used above is defined as follows:

# Copyright (c) 2022, salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause

# models_nl = ['codegen-350M-nl', 'codegen-2B-nl', 'codegen-6B-nl', 'codegen-16B-nl']
# models_pl = ['codegen-350M-multi', 'codegen-2B-multi', 'codegen-6B-multi', 'codegen-16B-multi',
#              'codegen-350M-mono',
#              'codegen-2B-mono', 'codegen-6B-mono', 'codegen-16B-mono']

import os
import re
import time
import random

import torch

from transformers import GPT2TokenizerFast
from aixcoder.codegen.modeling_codegen import CodeGenForCausalLM


########################################################################
# util
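# print_time: a context manager that prints a label on entry and the block's
# wall-clock duration on exit, e.g. `with print_time('loading'): ...`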
class print_time:
    def __init__(self, desc):
        self.desc = desc

    def __enter__(self):
        print(self.desc)
        self.t = time.time()

    def __exit__(self, type, value, traceback):
        print(f'{self.desc} took {time.time() - self.t:.02f}s')


def set_env():
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'


def set_seed(seed, deterministic=True):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = deterministic
        torch.backends.cudnn.benchmark = not deterministic
        # torch.use_deterministic_algorithms(deterministic)


def cast(model, fp16=True):
    # fp16 casting is disabled here; the model is kept in fp32
    # if fp16:
    #     model.half()
    return model


########################################################################
# model


def create_model(ckpt, fp16=False):
    # if fp16:
    #     return CodeGenForCausalLM.from_pretrained(ckpt, revision='float16', torch_dtype=torch.float16, low_cpu_mem_usage=True)
    # else:
    return CodeGenForCausalLM.from_pretrained(ckpt)


def create_tokenizer():
    t = GPT2TokenizerFast.from_pretrained('gpt2')
    t.max_model_input_sizes['gpt2'] = 1e20
    return t


def include_whitespace(t, n_min=2, n_max=20, as_special_tokens=False):
    t.add_tokens([' ' * n for n in reversed(range(n_min, n_max))], special_tokens=as_special_tokens)
    return t


def include_tabs(t, n_min=2, n_max=20, as_special_tokens=False):
    t.add_tokens(['\t' * n for n in reversed(range(n_min, n_max))], special_tokens=as_special_tokens)
    return t


def create_custom_gpt2_tokenizer():
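    # The added runs of spaces and tabs let code indentation encode as single
    # tokens, matching the whitespace-extended GPT-2 vocabulary the CodeGen
    # PL checkpoints were trained with.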
    t = create_tokenizer()
    t = include_whitespace(t=t, n_min=2, n_max=32, as_special_tokens=False)
    t = include_tabs(t=t, n_min=2, n_max=10, as_special_tokens=False)
    return t


########################################################################
# sample

MAX_LENGTH_SAMPLE = 512


def sample(
        model,
        tokenizer,
        context,
        pad_token_id,
        num_return_sequences=1,
        temp=0.2,
        top_p=0.95,
        max_length_sample=MAX_LENGTH_SAMPLE,
        max_length=2048
):
    input_ids = tokenizer(
        context,
        truncation=True,
        padding=True,
        max_length=max_length,
        return_tensors='pt',
    ).input_ids

    input_ids_len = input_ids.shape[1]
    assert input_ids_len < max_length

    with torch.no_grad():
        input_ids = input_ids.to()  # no-op here; pass a device (e.g. 'cuda') to move the tensor
        tokens = model.generate(
            input_ids,
            do_sample=True,
            num_return_sequences=num_return_sequences,
            temperature=temp,
            max_length=input_ids_len + max_length_sample,
            top_p=top_p,
            pad_token_id=pad_token_id,
            use_cache=True,
        )
        text = tokenizer.batch_decode(tokens[:, input_ids_len:, ...])

    return text


def truncate(completion):
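    """Heuristically truncate a sampled completion: keep at most the first
    top-level print statement and function definition, then cut at the
    earliest terminator (a top-level '#' comment, docstring quotes,
    <|endoftext|>, or three consecutive newlines)."""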
    def find_re(string, pattern, start_pos):
        m = pattern.search(string, start_pos)
        return m.start() if m else -1

    terminals = [
        re.compile(r, re.MULTILINE)
        for r in
        [
            '^#',
            re.escape('<|endoftext|>'),
            "^'''",
            '^"""',
            '\n\n\n'
        ]
    ]

    prints = list(re.finditer('^print', completion, re.MULTILINE))
    if len(prints) > 1:
        completion = completion[:prints[1].start()]

    defs = list(re.finditer('^def', completion, re.MULTILINE))
    if len(defs) > 1:
        completion = completion[:defs[1].start()]

    start_pos = 0

    terminals_pos = [pos for pos in [find_re(completion, terminal, start_pos) for terminal in terminals] if pos != -1]
    if len(terminals_pos) > 0:
        return completion[:min(terminals_pos)]
    else:
        return completion


class AIXCode:
    def __init__(self, model_name):
        # preamble
        set_env()
        set_seed(42, deterministic=True)

        ckpt = f'/Users/bytedance/githubcode/CodeGen/checkpoints/{model_name}'  # adjust to your local checkpoint directory

        # load
        with print_time(f'{model_name} loading parameters'):
            model = create_model(ckpt=ckpt, fp16=False).to()  # no-op; pass a device such as 'cuda' to move the model

        with print_time(f'{model_name} loading tokenizer'):
            tokenizer = create_custom_gpt2_tokenizer()
            tokenizer.padding_side = 'left'
            tokenizer.pad_token = 50256  # GPT-2's <|endoftext|> token id

        self.model = model
        self.tokenizer = tokenizer

    def aixcode(self, context_string):
        # sample
        with print_time(f'{context_string} ... AIXCoding >>>'):
            completion = sample(model=self.model,
                                tokenizer=self.tokenizer,
                                context=context_string,
                                pad_token_id=50256,
                                num_return_sequences=1,
                                temp=0.2,
                                top_p=0.95,
                                max_length_sample=MAX_LENGTH_SAMPLE)[0]

            truncation = truncate(completion)

            return context_string + truncation
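
Used standalone (assuming a checkpoint exists at the hard-coded path above), the class completes a prompt directly; the model name and prompt here are illustrative:

coder = AIXCode('codegen-350M-multi')
print(coder.aixcode('def fibonacci(n):'))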

Reference: https://blog.csdn.net/rensihui/article/details/80474706
