西部数码主机 | 阿里云主机| 虚拟主机 | 服务器 | 返回乐道官网
当前位置: 主页 > 开发教程 > python教程 >

初探Python3的异步IO编程

时间:2016-01-07 16:06来源:未知 作者:好模板 点击:
先简单介绍下各种 IO 模型: 最容易做的是阻塞 IO,即读写数据时,需要等待操作完成,才能继续执行。进阶的做法就是用多线程来处理需要 IO 的部分,缺点是开销会有些大。 接着是非

先简单介绍下各种 IO 模型:

最容易做的是阻塞 IO,即读写数据时,需要等待操作完成,才能继续执行。进阶的做法就是用多线程来处理需要 IO 的部分,缺点是开销会有些大。

接着是非阻塞 IO,即读写数据时,如果暂时不可读写,则立刻返回,而不等待。因为不知道什么时候是可读写的,所以轮询时可能会浪费 CPU 时间。

然后是 IO 复用,即在读写数据前,先检查哪些描述符是可读写的,再去读写。select 和 poll 就是这样做的,它们会遍历所有被监视的描述符,查看是否满足,这个检查的过程是阻塞的。而 epoll、kqueue 和 /dev/poll 则做了些改进,事先注册需要检查哪些描述符的哪些事件,当状态发生变化时,内核会调用对应的回调函数,将这些描述符保存下来;下次获取可用的描述符时,直接返回这些发生变化的描述符即可。

再之后是信号驱动,即描述符就绪时,内核发送 SIGIO 信号,再由信号处理程序去处理这些信号即可。不过信号处理的时机是从内核态返回用户态时,感觉也得把这些事件收集起来才好处理,有点像模拟 IO 复用了。

最后是异步 IO,即读写数据时,只注册事件,内核完成读写后(读取的数据会复制到用户态),再调用事件处理函数。这整个过程都不会阻塞调用线程,不过实现它的操作系统比较少,Windows 上有比较成熟的 IOCP,Linux 上的 AIO 则有不少缺点。

虽然真正的异步 IO 需要中间任何步骤都没有阻塞,这对于某些只是偶尔需要处理 IO 请求的情况确实有用(比如文本编辑器偶尔保存一下文件);但对于服务器端编程的大多数情况而言,它的主线程就是用来处理 IO 请求的,如果在空闲时不阻塞在 IO 等待上,也没有别的事情能做,所以本文就不纠结这个异步是否名副其实了。

在 Python 2 的时代,高性能的网络编程主要是使用 Twisted、Tornado 和 gevent 这三个库。

我对 Twisted 不熟,只知道它的缺点是比较重,性能相对而言并不算好。

Tornado 平时用得比较多,缺点是写异步调用时特别麻烦。

gevent 我只能算接触过,缺点是不太干净。

由于它们都各自有一个 IO loop,不好混用,而 Tornado 的 web 框架相对而言比较完善,因此成了我的首选。

而从 Python 3.4 开始,标准库里又新增了 asyncio 这个模块。

从原理上来说,它和 Tornado 其实差不多,都是注册 IO 事件,然后在 IO loop 中等待事件发生,然后调用相应的处理函数。

不同之处在于 Python 3 增加了一些新的特性,而 Tornado 需要兼容 Python 2,所以写起来会比较麻烦。

举例来说,Python 3.3 可以在 generator 中 return 返回值(相当于 raise StopIteration),而 Tornado 中需要 raise 一个 Return 对象。此外,Python 3.3 还增加了 yield from 语法,减轻了在 generator 中处理另一个 generator 的工作量(省去了循环和 try … except …)。

不过,虽然 asyncio 有那么多得天独厚的优势,却不一定比 Tornado 的性能更好,所以我写个简单的例子测试一下。

比较方法就是写个最简单的 HTTP 服务器,不做任何检查,读取到任何内容都输出一个 hello world,并断开连接。

测试的客户端就懒得写了,直接用 ab 即可:

ab -n 10000 -c 10 "http://0.0.0.0:8000/"
ab -n 10000 -c 10 "http://0.0.0.0:8000/"

Tornado 版是这样:

from tornado.gen import coroutine
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer

class Server(TCPServer):
    @coroutine
    def handle_stream(self, stream, address):
        try:
            yield stream.read_bytes(1024, partial=True)
            yield stream.write(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            stream.close()

server = Server()
server.bind(8000)
server.start(1)
IOLoop.current().start()
fromtornado.genimportcoroutine
fromtornado.ioloopimportIOLoop
fromtornado.tcpserverimportTCPServer
 
class Server(TCPServer):
    @coroutine
    defhandle_stream(self, stream, address):
        try:
            yieldstream.read_bytes(1024, partial=True)
            yieldstream.write(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            stream.close()
 
server = Server()
server.bind(8000)
server.start(1)
IOLoop.current().start()

在我的电脑上大概 4000 QPS。

asyncio 版是这样:

import asyncio

class Server(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        try:
            self.transport.write(b'HTTP/1.1 200 OKrnrnhello world')
        finally:
            self.transport.close()

loop = asyncio.get_event_loop()
server = loop.create_server(Server, '', 8000)
loop.run_until_complete(server)
loop.run_forever()
importasyncio
 
class Server(asyncio.Protocol):
    defconnection_made(self, transport):
        self.transport = transport
 
    defdata_received(self, data):
        try:
            self.transport.write(b'HTTP/1.1 200 OKrnrnhello world')
        finally:
            self.transport.close()
 
loop = asyncio.get_event_loop()
server = loop.create_server(Server, '', 8000)
loop.run_until_complete(server)
loop.run_forever()

在我的电脑上大概 3000 QPS,比 Tornado 版慢了一些。此外,asyncio 的 transport 在 write 时不用 yield from,这点可能有些不一致。

asyncio 还有个高级版的 API:

import asyncio

@asyncio.coroutine
def handle(reader, writer):
    yield from reader.read(1024)
    writer.write(b'HTTP/1.1 200 OKrnrnhello world')
    yield from writer.drain()
    writer.close()

loop = asyncio.get_event_loop()
task = asyncio.start_server(handle, '', 8000, loop=loop)
server = loop.run_until_complete(task)
loop.run_forever()
importasyncio
 
@asyncio.coroutine
defhandle(reader, writer):
    yieldfromreader.read(1024)
    writer.write(b'HTTP/1.1 200 OKrnrnhello world')
    yieldfromwriter.drain()
    writer.close()
 
loop = asyncio.get_event_loop()
task = asyncio.start_server(handle, '', 8000, loop=loop)
server = loop.run_until_complete(task)
loop.run_forever()

在我的电脑上大概 2200 QPS。这下读写都要 yield from 了,一致性上来说会好些。

以框架的性能而言,其实都够用,开销都不超过 1 毫秒,而 web 请求一般都需要 10 毫秒的以上的处理时间。

于是顺便再测一下和 MySQL 的搭配,即在每个请求内调用一下 SELECT 1,然后输出返回值。

因为自己懒得写客户端了,于是就用现成的 tornado_mysql 和 aiomysql 来测试了。原理应该都差不多,发送写请求后就返回,等收到可读事件时再获取内容。

Tornado 版是这样:

from tornado.gen import coroutine
from tornado.ioloop import IOLoop
from tornado.tcpserver import TCPServer
from tornado_mysql import pools

class Server(TCPServer):
    @coroutine
    def handle_stream(self, stream, address):
        try:
            yield stream.read_bytes(1024, partial=True)
            cursor = yield POOL.execute(b'SELECT 1')
            data = cursor.fetchone()
            yield stream.write('HTTP/1.1 200 OKrnrn{0[0]}'.format(data).encode())  # Python 3.5 的 bytes 才能用 % 格式化
        finally:
            stream.close()

POOL = pools.Pool(
    dict(host='127.0.0.1', port=3306, user='root', passwd='123', db='mysql'),
    max_idle_connections=10,
    max_open_connections=10)

server = Server()
server.bind(8000)
server.start(1)
IOLoop.current().start()
fromtornado.genimportcoroutine
fromtornado.ioloopimportIOLoop
fromtornado.tcpserverimportTCPServer
fromtornado_mysqlimportpools
 
class Server(TCPServer):
    @coroutine
    defhandle_stream(self, stream, address):
        try:
            yieldstream.read_bytes(1024, partial=True)
            cursor = yieldPOOL.execute(b'SELECT 1')
            data = cursor.fetchone()
            yieldstream.write('HTTP/1.1 200 OKrnrn{0[0]}'.format(data).encode())  # Python 3.5 的 bytes 才能用 % 格式化
        finally:
            stream.close()
 
POOL = pools.Pool(
    dict(host='127.0.0.1', port=3306, user='root', passwd='123', db='mysql'),
    max_idle_connections=10,
    max_open_connections=10)
 
server = Server()
server.bind(8000)
server.start(1)
IOLoop.current().start()

在我的电脑上大概 680 QPS。

asyncio 版是这样:

import asyncio

import aiomysql

class Server(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport

class Server(asyncio.Protocol):
    def connection_made(self, transport):
        self.transport = transport

    def data_received(self, data):
        @asyncio.coroutine
        def handle():
            with (yield from pool) as conn:
                cursor = yield from conn.cursor()
                yield from cursor.execute(b'SELECT 1')
                result = yield from cursor.fetchone()
            try:
                self.transport.write('HTTP/1.1 200 OKrnrn{0[0]}'.format(result).encode())
            finally:
                self.transport.close()
        loop.create_task(handle())  # 或者 asyncio.async(handle())

@asyncio.coroutine
def get_pool():
    return(yield from aiomysql.create_pool(host='127.0.0.1', port=3306, user='root', password='123', loop=loop))

loop = asyncio.get_event_loop()
pool = loop.run_until_complete(get_pool())

server = loop.create_server(Server, '', 8000)
loop.run_until_complete(server)
loop.run_forever()
importasyncio
 
importaiomysql
 
class Server(asyncio.Protocol):
    defconnection_made(self, transport):
        self.transport = transport
 
class Server(asyncio.Protocol):
    defconnection_made(self, transport):
        self.transport = transport
 
    defdata_received(self, data):
        @asyncio.coroutine
        defhandle():
            with (yieldfrompool) as conn:
                cursor = yieldfromconn.cursor()
                yieldfromcursor.execute(b'SELECT 1')
                result = yieldfromcursor.fetchone()
            try:
                self.transport.write('HTTP/1.1 200 OKrnrn{0[0]}'.format(result).encode())
            finally:
                self.transport.close()
        loop.create_task(handle())  # 或者 asyncio.async(handle())
 
@asyncio.coroutine
defget_pool():
    return(yieldfromaiomysql.create_pool(host='127.0.0.1', port=3306, user='root', password='123', loop=loop))
 
loop = asyncio.get_event_loop()
pool = loop.run_until_complete(get_pool())
 
server = loop.create_server(Server, '', 8000)
loop.run_until_complete(server)
loop.run_forever()

在我的电脑上大概 1250 QPS,比 Tornado 版快了不少。不过写起来比较蛋疼,因为 data_received 方法里不能直接用 yield from。

用 cProfile 看了下,Tornado 版在 tornado.gen 和 functools 模块里花了不少时间,可能是异步调用过多了吧。但如果不做异步库的开发者,而只就使用者的体验而言,Tornado 会显得更加灵活和易用。不过 asyncio 的高级 API 应该也能提供类似的体验。

顺便再用底层 socket 模块写个服务器试试。先用 poll 看看,错误处理什么的就先不做了:

from functools import partial
import select
import socket

class Server:
    def __init__(self):
        self._sock = socket.socket()
        self._poll = select.poll()
        self._handlers = {}
        self._fd_events = {}

    def start(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)

        handlers = self._handlers
        poll = self._poll
        self.add_handler(sock.fileno(), self._accept, select.POLLIN)

        while True:
            poll_events = poll.poll(1)
            for fd, event in poll_events:
                handler = handlers.get(fd)
                if handler:
                    handler()

    def _accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            except OSError:
                break
            else:
                conn.setblocking(0)
                fd = conn.fileno()
                self.add_handler(fd, partial(self._read, conn), select.POLLIN)

    def _read(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.recv(1024)
        except:
            conn.close()
            raise
        else:
            self.add_handler(fd, partial(self._write, conn), select.POLLOUT)

    def _write(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.send(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            conn.close()

    def add_handler(self, fd, handler, event):
        self._handlers[fd] = handler
        self.register(fd, event)

    def remove_handler(self, fd):
        self._handlers.pop(fd, None)
        self.unregister(fd)

    def register(self, fd, event):
        if fd in self._fd_events:
            raise IOError("fd %s already registered" % fd)
        self._poll.register(fd, event)
        self._fd_events[fd] = event

    def unregister(self, fd):
        event = self._fd_events.pop(fd, None)
        if event is not None:
            self._poll.unregister(fd)

Server().start()
fromfunctoolsimportpartial
importselect
importsocket
 
class Server:
    def__init__(self):
        self._sock = socket.socket()
        self._poll = select.poll()
        self._handlers = {}
        self._fd_events = {}
 
    defstart(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)
 
        handlers = self._handlers
        poll = self._poll
        self.add_handler(sock.fileno(), self._accept, select.POLLIN)
 
        while True:
            poll_events = poll.poll(1)
            for fd, eventin poll_events:
                handler = handlers.get(fd)
                if handler:
                    handler()
 
    def_accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            exceptOSError:
                break
            else:
                conn.setblocking(0)
                fd = conn.fileno()
                self.add_handler(fd, partial(self._read, conn), select.POLLIN)
 
    def_read(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.recv(1024)
        except:
            conn.close()
            raise
        else:
            self.add_handler(fd, partial(self._write, conn), select.POLLOUT)
 
    def_write(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.send(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            conn.close()
 
    defadd_handler(self, fd, handler, event):
        self._handlers[fd] = handler
        self.register(fd, event)
 
    defremove_handler(self, fd):
        self._handlers.pop(fd, None)
        self.unregister(fd)
 
    defregister(self, fd, event):
        if fdin self._fd_events:
            raiseIOError("fd %s already registered" % fd)
        self._poll.register(fd, event)
        self._fd_events[fd] = event
 
    defunregister(self, fd):
        event = self._fd_events.pop(fd, None)
        if eventis not None:
            self._poll.unregister(fd)
 
Server().start()

在我的电脑上大概 7700 QPS,优势巨大。

再用 kqueue 试试(我用的是 OS X):

from functools import partial
import select
import socket

class Server:
    def __init__(self):
        self._sock = socket.socket()
        self._kqueue = select.kqueue()
        self._handlers = {}
        self._fd_events = {}

    def start(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)

        self.add_handler(sock.fileno(), self._accept, select.KQ_FILTER_READ)
        handlers = self._handlers

        while True:
            kevents = self._kqueue.control(None, 1000, 1)
            for kevent in kevents:
                fd = kevent.ident
                handler = handlers.get(fd)
                if handler:
                    handler()

    def _accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            except OSError:
                break
            else:
                conn.setblocking(0)
                fd = conn.fileno()
                self.add_handler(fd, partial(self._read, conn), select.KQ_FILTER_READ)

    def _read(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.recv(1024)
        except:
            conn.close()
            raise
        else:
            self.add_handler(fd, partial(self._write, conn), select.KQ_FILTER_WRITE)

    def _write(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.send(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            conn.close()

    def add_handler(self, fd, handler, event):
        self._handlers[fd] = handler
        self.register(fd, event)

    def remove_handler(self, fd):
        self._handlers.pop(fd, None)
        self.unregister(fd)

    def register(self, fd, event):
        if fd in self._fd_events:
            raise IOError("fd %s already registered" % fd)
        self._control(fd, event, select.KQ_EV_ADD)
        self._fd_events[fd] = event

    def unregister(self, fd):
        event = self._fd_events.pop(fd, None)
        if event is not None:
            self._control(fd, event, select.KQ_EV_DELETE)

    def _control(self, fd, event, flags):
        change_list = (select.kevent(fd, event, flags),)
        self._kqueue.control(change_list, 0)

Server().start()
fromfunctoolsimportpartial
importselect
importsocket
 
class Server:
    def__init__(self):
        self._sock = socket.socket()
        self._kqueue = select.kqueue()
        self._handlers = {}
        self._fd_events = {}
 
    defstart(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)
 
        self.add_handler(sock.fileno(), self._accept, select.KQ_FILTER_READ)
        handlers = self._handlers
 
        while True:
            kevents = self._kqueue.control(None, 1000, 1)
            for keventin kevents:
                fd = kevent.ident
                handler = handlers.get(fd)
                if handler:
                    handler()
 
    def_accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            exceptOSError:
                break
            else:
                conn.setblocking(0)
                fd = conn.fileno()
                self.add_handler(fd, partial(self._read, conn), select.KQ_FILTER_READ)
 
    def_read(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.recv(1024)
        except:
            conn.close()
            raise
        else:
            self.add_handler(fd, partial(self._write, conn), select.KQ_FILTER_WRITE)
 
    def_write(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.send(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            conn.close()
 
    defadd_handler(self, fd, handler, event):
        self._handlers[fd] = handler
        self.register(fd, event)
 
    defremove_handler(self, fd):
        self._handlers.pop(fd, None)
        self.unregister(fd)
 
    defregister(self, fd, event):
        if fdin self._fd_events:
            raiseIOError("fd %s already registered" % fd)
        self._control(fd, event, select.KQ_EV_ADD)
        self._fd_events[fd] = event
 
    defunregister(self, fd):
        event = self._fd_events.pop(fd, None)
        if eventis not None:
            self._control(fd, event, select.KQ_EV_DELETE)
 
    def_control(self, fd, event, flags):
        change_list = (select.kevent(fd, event, flags),)
        self._kqueue.control(change_list, 0)
 
Server().start()

在我的电脑上大概 7200 QPS,比 poll 版稍慢。不过因为只有 10 个并发连接,而且没有慢速网络的影响,所以 poll 的性能好并不奇怪。

再试试 Python 3.4 新增的 selectors 模块,它的 DefaultSelector 会自动选择所在平台最高效的实现,asyncio 就用到了这个模块。

import selectors
import socket

class Server:
    def __init__(self):
        self._sock = socket.socket()
        self._selector = selectors.DefaultSelector()

    def start(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)

        selector = self._selector
        self.add_handler(sock.fileno(), self._accept, selectors.EVENT_READ)

        while True:
            events = selector.select(1)
            for key, event in events:
                handler, data = key.data
                if data:
                    handler(**data)
                else:
                    handler()

    def _accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            except OSError:
                break
            else:
                conn.setblocking(0)
                fd = conn.fileno()
                self.add_handler(fd, self._read, selectors.EVENT_READ, {'conn': conn})

    def _read(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.recv(1024)
        except:
            conn.close()
            raise
        else:
            self.add_handler(fd, self._write, selectors.EVENT_WRITE, {'conn': conn})

    def _write(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.send(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            conn.close()

    def add_handler(self, fd, handler, event, data=None):
        self._selector.register(fd, event, (handler, data))

    def remove_handler(self, fd):
        self._selector.unregister(fd)

Server().start()
importselectors
importsocket
 
class Server:
    def__init__(self):
        self._sock = socket.socket()
        self._selector = selectors.DefaultSelector()
 
    defstart(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)
 
        selector = self._selector
        self.add_handler(sock.fileno(), self._accept, selectors.EVENT_READ)
 
        while True:
            events = selector.select(1)
            for key, eventin events:
                handler, data = key.data
                if data:
                    handler(**data)
                else:
                    handler()
 
    def_accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            exceptOSError:
                break
            else:
                conn.setblocking(0)
                fd = conn.fileno()
                self.add_handler(fd, self._read, selectors.EVENT_READ, {'conn': conn})
 
    def_read(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.recv(1024)
        except:
            conn.close()
            raise
        else:
            self.add_handler(fd, self._write, selectors.EVENT_WRITE, {'conn': conn})
 
    def_write(self, conn):
        fd = conn.fileno()
        self.remove_handler(fd)
        try:
            conn.send(b'HTTP 1.0 200 OKrnrnhello world')
        finally:
            conn.close()
 
    defadd_handler(self, fd, handler, event, data=None):
        self._selector.register(fd, event, (handler, data))
 
    defremove_handler(self, fd):
        self._selector.unregister(fd)
 
Server().start()

在我的电脑上大概 6100 QPS,成绩也还不错。

从这些测试来看,如果想自己实现一个舍弃了一些功能和兼容性的 Tornado,应该能比它稍快一点,不过似乎没多大必要。所以暂时不纠结性能了,还是从使用的便利性上来考虑。Tornado 可以用 yield 取代 callback,我们也来实现这个 feature。

实现前先得了解下 yield。

当一个函数内部出现了 yield 语句时,它就不再是一个单纯的函数了,而是一个生成器函数,调用它并不会执行它的代码,而是返回一个生成器。

调用这个生成器的 send 方法时,才会执行内部的代码。当执行到 yield 时,这个 send 方法就返回了,调用者可以得到其返回值。

send 方法在第一次调用时,参数必须为 None。Python 2 中可以用它的 next 方法,Python 3 中改成了 __next__ 方法,还可以用内置的 next 函数来调用。

send 方法可以被多次调用,参数会作为 yield 的返回值,回到生成器内上一次执行的地方,并继续执行下去。

当生成器的代码执行完时,会抛出一个 StopIteration 的异常。Python 3.3 开始可以在生成器里使用 return,返回值可以从 StopIteration 异常的 value 属性获取。

for … in … 循环会自动捕获 StopIteration 异常,并作为循环停止的条件。

由此可见,yield 可以用于跳转。而我们要做的,则是在遇到 IO 请求时,用 yield 返回 IO loop;当事件发生时,找到对应的生成器,用 send 方法继续执行即可。为了简单起见,我就在 poll 版的基础上进行改造了:

from collections import deque
import select
import socket
from types import GeneratorType

class Stream:
    def __init__(self, sock, loop):
        sock.setblocking(0)
        self._sock = sock
        self._loop = loop

    def close(self):
        self._sock.close()

    def read(self, size=1024):
        sock = self._sock
        fd = sock.fileno()
        try:
            data = sock.recv(size)
        except OSError as e:
            if e.errno == socket.EAGAIN or socket.EWOULDBLOCK:
                self._loop.add_handler(fd, self.read(size), select.POLLIN)
                yield
            else:
                raise
        else:
            return data
        finally:
            self._loop.remove_handler(fd)

    def write(self, data):
        sock = self._sock
        fd = sock.fileno()
        try:
            try:
                sent_bytes = sock.send(data)
            except OSError as e:
                if e.errno not in (socket.EAGAIN, socket.EWOULDBLOCK):
                    raise
            else:
                if sent_bytes == len(data):
                    return
                data = data[sent_bytes:]

            self._loop.add_handler(fd, self.write(data), select.POLLOUT)
            yield

            while data:
                try:
                    sent_bytes = sock.send(data)
                except OSError as e:
                    if e.errno not in (socket.EAGAIN, socket.EWOULDBLOCK):
                        raise
                else:
                    if sent_bytes == len(data):
                        return
                    data = data[sent_bytes:]
                yield
        finally:
            self._loop.remove_handler(fd)

class IOLoop:
    def __init__(self):
        self._poll = select.poll()
        self._handlers = {}
        self._fd_events = {}

    def start(self):
        handlers = self._handlers
        poll = self._poll

        while True:
            poll_events = poll.poll(1)
            for fd, event in poll_events:
                handler = handlers.get(fd)
                if handler:
                    if callable(handler):
                        handler()
                    else:
                        stack = handler
                        while True:
                            generator, value = stack[-1]
                            try:
                                value = generator.send(value)
                                if isinstance(value, GeneratorType):
                                    stack.append([value, None])
                                else:
                                    break
                            except StopIteration as e:
                                stack.pop()
                                if stack:
                                    stack[-1][-1] = e.value
                                else:
                                    break

    def add_handler(self, fd, handler, event):
        if isinstance(handler, GeneratorType):
            self._handlers[fd] = deque([[handler, None]])
        else:
            self._handlers[fd] = handler
        self.register(fd, event)

    def remove_handler(self, fd):
        self._handlers.pop(fd, None)
        self.unregister(fd)

    def update_handler(self, fd, event):
        self.modify(fd, event)

    def register(self, fd, event):
        if fd in self._fd_events:
            raise IOError("fd %s already registered" % fd)
        self._poll.register(fd, event)
        self._fd_events[fd] = event

    def unregister(self, fd):
        event = self._fd_events.pop(fd, None)
        if event is not None:
            self._poll.unregister(fd)

    def modify(self, fd, event):
        self._poll.modify(fd, event)
        self._fd_events[fd] = event

class Server:
    def __init__(self):
        self._sock = socket.socket()
        self._loop = IOLoop()
        self._stream = Stream(self._sock, self._loop)

    def start(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)

        self._loop.add_handler(sock.fileno(), self._accept, select.POLLIN)
        self._loop.start()

    def _accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            except OSError:
                break
            else:
                stream = Stream(conn, self._loop)
                fd = conn.fileno()
                self._loop.add_handler(fd, self._handle(stream), select.POLLIN)

    def _handle(self, stream):
        yield stream.read()
        yield stream.write(b'HTTP 1.0 200 OK\r\n\r\nhello world')

Server().start()
fromcollectionsimportdeque
importselect
importsocket
fromtypesimportGeneratorType
 
class Stream:
    def__init__(self, sock, loop):
        sock.setblocking(0)
        self._sock = sock
        self._loop = loop
 
    defclose(self):
        self._sock.close()
 
    defread(self, size=1024):
        sock = self._sock
        fd = sock.fileno()
        try:
            data = sock.recv(size)
        exceptOSErroras e:
            if e.errno == socket.EAGAINor socket.EWOULDBLOCK:
                self._loop.add_handler(fd, self.read(size), select.POLLIN)
                yield
            else:
                raise
        else:
            return data
        finally:
            self._loop.remove_handler(fd)
 
    defwrite(self, data):
        sock = self._sock
        fd = sock.fileno()
        try:
            try:
                sent_bytes = sock.send(data)
            exceptOSErroras e:
                if e.errnonot in (socket.EAGAIN, socket.EWOULDBLOCK):
                    raise
            else:
                if sent_bytes == len(data):
                    return
                data = data[sent_bytes:]
 
            self._loop.add_handler(fd, self.write(data), select.POLLOUT)
            yield
 
            while data:
                try:
                    sent_bytes = sock.send(data)
                exceptOSErroras e:
                    if e.errnonot in (socket.EAGAIN, socket.EWOULDBLOCK):
                        raise
                else:
                    if sent_bytes == len(data):
                        return
                    data = data[sent_bytes:]
                yield
        finally:
            self._loop.remove_handler(fd)
 
class IOLoop:
    def__init__(self):
        self._poll = select.poll()
        self._handlers = {}
        self._fd_events = {}
 
    defstart(self):
        handlers = self._handlers
        poll = self._poll
 
        while True:
            poll_events = poll.poll(1)
            for fd, eventin poll_events:
                handler = handlers.get(fd)
                if handler:
                    if callable(handler):
                        handler()
                    else:
                        stack = handler
                        while True:
                            generator, value = stack[-1]
                            try:
                                value = generator.send(value)
                                if isinstance(value, GeneratorType):
                                    stack.append([value, None])
                                else:
                                    break
                            exceptStopIterationas e:
                                stack.pop()
                                if stack:
                                    stack[-1][-1] = e.value
                                else:
                                    break
 
    defadd_handler(self, fd, handler, event):
        if isinstance(handler, GeneratorType):
            self._handlers[fd] = deque([[handler, None]])
        else:
            self._handlers[fd] = handler
        self.register(fd, event)
 
    defremove_handler(self, fd):
        self._handlers.pop(fd, None)
        self.unregister(fd)
 
    defupdate_handler(self, fd, event):
        self.modify(fd, event)
 
    defregister(self, fd, event):
        if fdin self._fd_events:
            raiseIOError("fd %s already registered" % fd)
        self._poll.register(fd, event)
        self._fd_events[fd] = event
 
    defunregister(self, fd):
        event = self._fd_events.pop(fd, None)
        if eventis not None:
            self._poll.unregister(fd)
 
    defmodify(self, fd, event):
        self._poll.modify(fd, event)
        self._fd_events[fd] = event
 
class Server:
    def__init__(self):
        self._sock = socket.socket()
        self._loop = IOLoop()
        self._stream = Stream(self._sock, self._loop)
 
    defstart(self):
        sock = self._sock
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.setblocking(0)
        sock.bind(('', 8000))
        sock.listen(100)
 
        self._loop.add_handler(sock.fileno(), self._accept, select.POLLIN)
        self._loop.start()
 
    def_accept(self):
        for i in range(100):
            try:
                conn, address = self._sock.accept()
            exceptOSError:
                break
            else:
                stream = Stream(conn, self._loop)
                fd = conn.fileno()
                self._loop.add_handler(fd, self._handle(stream), select.POLLIN)
 
    def_handle(self, stream):
        yieldstream.read()
        yieldstream.write(b'HTTP 1.0 200 OK\r\n\r\nhello world')
 
Server().start()

在我的电脑上大概 5300 QPS。

虽然成绩比较尴尬,但毕竟用起来比前一个版本好多了。至于慢的原因,我估计是自己维护了一个堆栈的原因(也可能是有什么 bug,毕竟写这个感觉太跳跃了,能运行起来就谢天谢地了)。实现时做了两点假设:

  1. handler 为 generator 时,视为异步方法。
  2. 在异步方法中 yield None 时,视为等待 IO;yield / yield from 异步方法时,则是等待方法返回。

实现细节也没什么好说的了,只是觉得在实现 Stream 的 read / write 方法时,调用 IOLoop.add_handler 方法不太优雅。其实可以直接 yield 一个 fd 和 event,在 IOLoop.start 方法中再去注册。不过这个重构其实蛮小的,我就不再贴一次代码了,感兴趣的可以自己试试。

(责任编辑:好模板)
顶一下
(0)
0%
踩一下
(0)
0%
------分隔线----------------------------
栏目列表
热点内容