跳过正文
  1. Posts/

Migrate from Tornado to FastAPI

·3 分钟· · · #PYTHON #Tornado #FastAPI

心水 FastAPI 的 ASGI、依赖注入、类型注解和自动在线文档很久,但迫于项目历史依赖(立项的时候 FastAPI 还没出来),迟迟没有迁移的动力。直到某次又翻到过期的接口文档,实在不想继续过文档和代码分割的日子了,于是决定开整。

迫于旧接口太多,一下迁移明显不现实,想了两个过渡方案:

  1. Nginx 分流,旧接口继续使用 Tornado 驱动,新接口走 FastAPI
  2. 完全抛弃 Tornado 驱动,利用 WSGI 将旧接口分流、转换成 Tornado 兼容的响应

方案一实施起来最简单,但有个问题是 Session 共享不好解决,另外增加了额外的运维部署复杂度,于是pass;

走方案二的话,就需要写一个中间件来将旧接口分流、转换给原 Tornado 的实现,代码如下:

经典的 Tornado hello world

import asyncio
import tornado


class MainHandler(tornado.web.RequestHandler):
    """Tornado request handler that answers GET with a fixed greeting."""

    def get(self):
        # Write the greeting into the response body.
        self.write("Hello, world")


def make_app():
    """Build the Tornado application with its single ``/api/v1/`` route."""
    routes = [(r"/api/v1/", MainHandler)]
    return tornado.web.Application(routes)


async def main():
    """Listen on port 8888 with the Tornado application, then wait forever."""
    make_app().listen(8888)
    # The event is never set, so this await keeps the coroutine alive.
    await asyncio.Event().wait()


# Run the server only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

本来是打算手撸 ASGI -> WSGI 的,刚好刷推看到 Django 的 a2wsgi,能省不少手撕 WSGI 协议的活,于是果断抄之,最终成品如下:

import asyncio
import contextvars
import functools
import urllib.parse as urllib_parse
from asyncio import Future
from typing import Any, List, Tuple

import tornado
import uvicorn
from a2wsgi.types import Environ, StartResponse
from a2wsgi.wsgi import Body, WSGIResponder, build_environ
from fastapi import FastAPI
from starlette.types import Receive, Scope, Send
from tornado import httputil
from tornado.escape import native_str
from tornado.web import Application


# WSGI has no facilities for flow control, so just return an already-done
# Future when the interface requires it.
def _dummy_future():
    """Return a ``Future`` whose result is already set to ``None``."""
    resolved = Future()
    resolved.set_result(None)
    return resolved


class _WSGIRequestContext:
    """Holds the client address and URL scheme for one request.

    ``str()`` of an instance is the stored ``remote_ip``.
    """

    def __init__(self, remote_ip, protocol):
        self.remote_ip, self.protocol = remote_ip, protocol

    def __str__(self):
        return self.remote_ip


class _WSGIConnection(httputil.HTTPConnection):
    """In-memory connection that buffers a Tornado response for WSGI.

    Headers are forwarded to the WSGI ``start_response`` callable; body
    chunks are collected in ``_write_buffer``.  The number of bytes written
    is checked against ``Content-Length`` when that header is present.
    """

    def __init__(self, method, start_response, context):
        self.method = method
        self.start_response = start_response
        self.context = context
        # Body chunks in the order they were written.
        self._write_buffer = []
        # Set to True by finish().
        self._finished = False
        # Bytes still owed according to Content-Length; None means "unknown".
        self._expected_content_remaining = None
        # Last HTTPOutputError raised by write()/finish(), if any.
        self._error = None

    def set_close_callback(self, callback):
        # WSGI cannot detect a connection closed mid-request, so the
        # callback is deliberately dropped.
        pass

    def write_headers(self, start_line, headers, chunk=None, callback=None):
        """Forward the status line and headers, then write ``chunk`` if given."""
        if self.method == "HEAD":
            remaining = 0
        elif "Content-Length" in headers:
            remaining = int(headers["Content-Length"])
        else:
            remaining = None
        self._expected_content_remaining = remaining

        status_line = "%d %s" % (start_line.code, start_line.reason)
        header_pairs = [
            (native_str(name), native_str(value))
            for (name, value) in headers.get_all()
        ]
        self.start_response(status_line, header_pairs)

        if chunk is not None:
            self.write(chunk, callback)
        elif callback is not None:
            callback()
        return _dummy_future()

    def write(self, chunk, callback=None):
        """Buffer ``chunk``; raise if it exceeds the declared Content-Length."""
        remaining = self._expected_content_remaining
        if remaining is not None:
            remaining -= len(chunk)
            self._expected_content_remaining = remaining
            if remaining < 0:
                self._error = httputil.HTTPOutputError(
                    "Tried to write more data than Content-Length"
                )
                raise self._error
        self._write_buffer.append(chunk)
        if callback is not None:
            callback()
        return _dummy_future()

    def finish(self):
        """Mark the response complete; raise if fewer bytes were written
        than Content-Length declared."""
        owed = self._expected_content_remaining
        if owed is None or owed == 0:
            self._finished = True
            return
        self._error = httputil.HTTPOutputError(
            f"Tried to write {self._expected_content_remaining} bytes less than Content-Length"
        )
        raise self._error


class WSGIAdapter:
    """Adapt a Tornado application to an awaitable WSGI-style callable.

    Calling an instance with a WSGI ``environ`` and ``start_response``
    rebuilds a ``tornado.httputil.HTTPServerRequest`` from the environ,
    awaits the wrapped application with that request, and returns the
    response body chunks buffered by ``_WSGIConnection``.
    """

    def __init__(self, app):
        # Callable that is invoked (and awaited) with the rebuilt request.
        self.app = app

    @staticmethod
    def _quote_wsgi_path(value: str) -> str:
        """Percent-encode a WSGI path component.

        PEP 3333 environ strings carry raw bytes decoded as ISO-8859-1.
        Quoting that text directly would re-encode every non-ASCII byte as
        UTF-8 and double-encode the path, so the original bytes are
        recovered first.  Text that is not latin-1 encodable (a
        non-conforming server) is quoted unchanged.
        """
        try:
            raw = value.encode("latin1")
        except UnicodeEncodeError:
            return urllib_parse.quote(value)
        return urllib_parse.quote(raw)

    async def __call__(
        self, environ: Environ, start_response: StartResponse
    ) -> List[bytes]:
        """Serve one request described by ``environ``.

        Returns:
            The list of body chunks written by the Tornado handler.

        Raises:
            tornado.httputil.HTTPOutputError: if the handler wrote more or
                fewer bytes than its ``Content-Length``.
            Exception: if the handler returned without finishing the request.
        """
        method = environ["REQUEST_METHOD"]

        # Rebuild the request URI from its WSGI pieces.
        uri = self._quote_wsgi_path(environ.get("SCRIPT_NAME", ""))
        uri += self._quote_wsgi_path(environ.get("PATH_INFO", ""))
        if environ.get("QUERY_STRING"):
            uri += "?" + environ["QUERY_STRING"]

        # CONTENT_TYPE / CONTENT_LENGTH have dedicated keys; every other
        # request header arrives as HTTP_<NAME>.
        headers = httputil.HTTPHeaders()
        if environ.get("CONTENT_TYPE"):
            headers["Content-Type"] = environ["CONTENT_TYPE"]
        if environ.get("CONTENT_LENGTH"):
            headers["Content-Length"] = environ["CONTENT_LENGTH"]
        for key in environ:
            if key.startswith("HTTP_"):
                headers[key[5:].replace("_", "-")] = environ[key]

        # Only read a body when the client declared its length.
        if headers.get("Content-Length"):
            body = await environ["wsgi.input"].aread(int(headers["Content-Length"]))
        else:
            body = b""

        protocol = environ["wsgi.url_scheme"]
        remote_ip = environ.get("REMOTE_ADDR", "")
        if environ.get("HTTP_HOST"):
            host = environ["HTTP_HOST"]
        else:
            host = environ["SERVER_NAME"]

        connection = _WSGIConnection(
            method, start_response, _WSGIRequestContext(remote_ip, protocol)
        )
        request = httputil.HTTPServerRequest(
            method,
            uri,
            "HTTP/1.1",
            headers=headers,
            body=body,
            host=host,
            connection=connection,
        )
        request._parse_body()
        await self.app(request)

        # Surface output errors recorded by the connection, then make sure
        # the handler actually completed the response.
        if connection._error:
            raise connection._error
        if not connection._finished:
            raise Exception("request did not finish synchronously")
        return connection._write_buffer


class TornadoMiddleware:
    """ASGI application that serves HTTP requests from a Tornado application.

    HTTP scopes are handled by ``_WSGIResponder`` around a ``WSGIAdapter``;
    websocket scopes are closed with code 1000; other scope types are ignored.
    """

    def __init__(self, app) -> None:
        self.app = WSGIAdapter(app)
        self.executor = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        scope_type = scope["type"]

        if scope_type == "http":
            handle = _WSGIResponder(self.app, self.executor)
            return await handle(scope, receive, send)

        if scope_type == "websocket":
            close_message = {"type": "websocket.close", "code": 1000}
            await send(close_message)

        return None


class _Body(Body):
    """Request body reader with awaitable read methods.

    Data is pulled from ``self.receive()`` into ``self.buffer`` on demand.
    """

    async def _areceive_more_data(self) -> bytes:
        """Await the next message and return its body (``b""`` once exhausted)."""
        if self._has_more:
            message = await self.receive()
            self._has_more = message.get("more_body", False)
            return message.get("body", b"")
        return b""

    async def aread(self, size: int = -1) -> bytes:
        """Read up to ``size`` bytes, or everything when ``size`` is -1."""
        read_all = size == -1

        # Keep pulling until enough bytes are buffered or the body ends.
        while read_all or len(self.buffer) < size:
            self.buffer.extend(await self._areceive_more_data())
            if not self._has_more:
                break

        if read_all:
            data = bytes(self.buffer)
            self.buffer.clear()
            return data

        data = bytes(self.buffer[:size])
        del self.buffer[:size]
        return data


class _WSGIResponder(WSGIResponder):
    """Responder that awaits the wrapped app directly.

    Overrides ``__call__``, ``start_response`` and adds ``awsgi`` so that
    ``self.app`` (an awaitable callable taking ``environ`` and
    ``start_response``) is awaited on the event loop.
    NOTE(review): relies on attributes provided by a2wsgi's ``WSGIResponder``
    (``loop``, ``sender``, ``send``, ``send_queue``, ``send_event``,
    ``exc_info``, ``response_started``) — confirm against the pinned
    a2wsgi version.
    """

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle one HTTP scope: run the app and deliver its messages."""
        body = _Body(self.loop, receive)
        environ = build_environ(scope, body)
        sender = None
        try:
            # Background task running self.sender(send).
            sender = self.loop.create_task(self.sender(send))
            # Create the awsgi coroutine inside a copy of the current context.
            context = contextvars.copy_context()
            func = functools.partial(context.run, self.awsgi)
            await func(environ, self.start_response)
            # Presumably None is the end-of-stream marker for the sender
            # task — confirm against a2wsgi.
            self.send_queue.append(None)
            self.send_event.set()
            await asyncio.wait_for(sender, None)
            # Re-raise an error that was reported through start_response.
            if self.exc_info is not None:
                raise self.exc_info[0].with_traceback(
                    self.exc_info[1], self.exc_info[2]
                )
        finally:
            # Do not leave the sender task running if we exit early.
            if sender and not sender.done():
                sender.cancel()  # pragma: no cover

    def start_response(
        self,
        status: str,
        response_headers: List[Tuple[str, str]],
        exc_info: Any = None,
    ) -> None:
        """WSGI ``start_response``: emit ``http.response.start`` once.

        ``exc_info`` is always recorded; the start message is only sent the
        first time this is called.
        """
        self.exc_info = exc_info
        if not self.response_started:
            self.response_started = True
            # The status looks like "200 OK"; only the numeric code is sent.
            status_code_string, _ = status.split(" ", 1)
            status_code = int(status_code_string)
            # Header names are lower-cased; names and values become latin-1 bytes.
            headers = [
                (name.strip().encode("latin1").lower(), value.strip().encode("latin1"))
                for name, value in response_headers
            ]
            self.send(
                {
                    "type": "http.response.start",
                    "status": status_code,
                    "headers": headers,
                }
            )

    async def awsgi(self, environ: Environ, start_response: StartResponse) -> None:
        """Await the app, then send each returned chunk as a body message."""
        for chunk in await self.app(environ, start_response):
            self.send({"type": "http.response.body", "body": chunk, "more_body": True})

        # A final empty body message (no "more_body") ends the response.
        self.send({"type": "http.response.body", "body": b""})


class HandleDelegate(tornado.web._HandlerDelegate):
    """Handler delegate whose ``execute`` returns the handler's ``_execute``
    result so the caller can await it."""

    def execute(self):
        settings = self.application.settings

        # With template caching disabled, reset every cached loader.
        if not settings.get("compiled_template_cache", True):
            with tornado.web.RequestHandler._template_loader_lock:
                for loader in tornado.web.RequestHandler._template_loaders.values():
                    loader.reset()

        # Same for the static-file hash cache.
        if not settings.get("static_hash_cache", True):
            tornado.web.StaticFileHandler.reset()

        handler = self.handler_class(
            self.application, self.request, **self.handler_kwargs
        )
        self.handler = handler
        transforms = [make(self.request) for make in self.application.transforms]

        return handler._execute(transforms, *self.path_args, **self.path_kwargs)


class TornadoApplication(Application):
    """Tornado ``Application`` that dispatches through ``HandleDelegate``."""

    def get_handler_delegate(
        self,
        request,
        target_class,
        target_kwargs=None,
        path_args=None,
        path_kwargs=None,
    ):
        # Same arguments as received, but wrapped in the awaitable-friendly
        # HandleDelegate.
        return HandleDelegate(
            self, request, target_class, target_kwargs, path_args, path_kwargs
        )


class MainHandler(tornado.web.RequestHandler):
    """Legacy Tornado handler that answers GET with a fixed greeting."""

    def get(self):
        # Write the greeting into the response body.
        self.write("Hello, world")


def make_app():
    """Build the ``TornadoApplication`` with its single ``/api/v1/`` route."""
    routes = [(r"/api/v1/", MainHandler)]
    return TornadoApplication(routes)


async def main():
    """Listen on port 8888 with the Tornado application, then wait forever.

    NOTE(review): the ``__main__`` guard in this listing starts uvicorn and
    does not call this coroutine — it looks like a leftover; confirm.
    """
    make_app().listen(8888)
    # The event is never set, so this await keeps the coroutine alive.
    await asyncio.Event().wait()


# ASGI application that serves the new endpoints.
app = FastAPI()
# Requests under /api/v1 are handed to the Tornado application through
# TornadoMiddleware.
app.mount("/api/v1", TornadoMiddleware(make_app()))


@app.get("/api/v2/")
def hello_world():
    """Return the greeting payload for ``GET /api/v2/``."""
    payload = {"message": "Hello, World!"}
    return payload


# Serve the FastAPI app with uvicorn when this file is run as a script.
if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8888)

测一下:

~ curl http://127.0.0.1:8888/api/v2/
{"message":"Hello, World!"}%
~ curl http://127.0.0.1:8888/api/v1/
Hello, world%

可以看出 v1 接口请求被 FastAPI 通过挂载(`app.mount`)机制交给 TornadoMiddleware,再由 TornadoMiddleware 完成 ASGI 到 WSGI 再到 Tornado 接口代码的转换工作,完美解决了旧 Tornado 接口与新 FastAPI 接口共存的问题,且不影响现有项目部署流程,十分 nice。

# NOTE: I am not responsible for any expired content.
create@2023-06-09T12:32:25+08:00
update@2023-12-27T11:08:32+08:00
comment@https://github.com/ferstar/blog/issues/77

相关文章