IO多路复用及ThreadingTCPServer源码阅读

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了IO多路复用及ThreadingTCPServer源码阅读相关的知识,希望对你有一定的参考价值。

IO多路复用

socket模块是阻塞的,通过socket建立的服务端可以接收多个请求,但只能同时处理一个请求,其他请求都被阻塞。可以通过IO多路复用解决这个问题,socketserver内部使用的就是IO多路复用以及多线程和多进程。

IO多路复用就是指通过一种机制可以监视多个描述符,一旦某个描述符就绪(读写就绪),就通知程序进行相应读写操作。

Linux中的select,poll,epoll都是IO多路复用的机制。

select
 
select最早于1983年出现在4.2BSD中,它通过一个select()系统调用来监视多个文件描述符的数组,当select()返回后,该数组中就绪的文件描述符便会被内核修改标志位,使得进程可以获得这些文件描述符从而进行后续的读写操作。
select目前几乎在所有的平台上支持,其良好跨平台支持也是它的一个优点,事实上从现在看来,这也是它所剩不多的优点之一。
select的一个缺点在于单个进程能够监视的文件描述符的数量存在最大限制,在Linux上一般为1024,不过可以通过修改宏定义甚至重新编译内核的方式提升这一限制。
另外,select()所维护的存储大量文件描述符的数据结构,随着文件描述符数量的增大,其复制的开销也线性增长。同时,由于网络响应时间的延迟使得大量TCP连接处于非活跃状态,但调用select()会对所有socket进行一次线性扫描,所以这也浪费了一定的开销。
 
poll
 
poll在1986年诞生于System V Release 3,它和select在本质上没有多大差别,但是poll没有最大文件描述符数量的限制。
poll和select同样存在一个缺点就是,包含大量文件描述符的数组被整体复制于用户态和内核的地址空间之间,而不论这些文件描述符是否就绪,它的开销随着文件描述符数量的增加而线性增大。
另外,select()和poll()将就绪的文件描述符告诉进程后,如果进程没有对其进行IO操作,那么下次调用select()和poll()的时候将再次报告这些文件描述符,所以它们一般不会丢失就绪的消息,这种方式称为水平触发(Level Triggered)。
 
epoll
 
直到Linux2.6才出现了由内核直接支持的实现方法,那就是epoll,它几乎具备了之前所说的一切优点,被公认为Linux2.6下性能最好的多路I/O就绪通知方法。
epoll可以同时支持水平触发和边缘触发(Edge Triggered,只告诉进程哪些文件描述符刚刚变为就绪状态,它只说一遍,如果我们没有采取行动,那么它将不会再次告知,这种方式称为边缘触发),理论上边缘触发的性能要更高一些,但是代码实现相当复杂。
epoll同样只告知那些就绪的文件描述符,而且当我们调用epoll_wait()获得就绪文件描述符时,返回的不是实际的描述符,而是一个代表就绪描述符数量的值,你只需要去epoll指定的一个数组中依次取得相应数量的文件描述符即可,这里也使用了内存映射(mmap)技术,这样便彻底省掉了这些文件描述符在系统调用时复制的开销。
另一个本质的改进在于epoll采用基于事件的就绪通知方式。在select/poll中,进程只有在调用一定的方法后,内核才对所有监视的文件描述符进行扫描,而epoll事先通过epoll_ctl()来注册一个文件描述符,一旦基于某个文件描述符就绪时,内核会采用类似callback的回调机制,迅速激活这个文件描述符,当进程调用epoll_wait()时便得到通知。

Python

python中有一个select模块,其中提供了:select,poll,epoll三个方法,分别调用系统的select,poll和epoll从而实现IO多路复用。

1 Windows Python:
2     提供: select
3 Mac Python:
4     提供: select
5 Linux Python:
6     提供: select、poll、epoll

谈一下select方法:

句柄列表11, 句柄列表22, 句柄列表33 = select.select(句柄序列1, 句柄序列2, 句柄序列3, 超时时间)
 
参数: 可接受四个参数(前三个必须)
返回值:三个列表
 
select方法用来监视文件句柄,如果句柄发生变化,则获取该句柄。
1、当 参数1 序列中的句柄发生可读时(accetp和read),则获取发生变化的句柄并添加到 返回值1 序列中
2、当 参数2 序列中含有句柄时,则将该序列中所有的句柄添加到 返回值2 序列中
3、当 参数3 序列中的句柄发生错误时,则将该发生错误的句柄添加到 返回值3 序列中
4、当 超时时间 未设置,则select会一直阻塞,直到监听的句柄发生变化
   当 超时时间 = 1时,那么如果监听的句柄均无任何变化,则select会阻塞 1 秒,之后返回三个空列表,如果监听的句柄有变化,则直接执行。

例子:
通过select和socket模块实现的伪IO多路复用,客户端输入什么,服务端就返回response + 客户端输入内容

#!/usr/bin/env python
# coding=utf-8

import socket
import select

sk = socket.socket()
sk.bind((\'127.0.0.1\', 9999, ))
sk.listen(5)
inputs = [sk, ]
messages = {}
# messages = {
#  hexm: [消息1,消息2】
#  zhuxj: [消息1,消息2】
# }
# inputs = [sk, hexm, zhuxj, ly] # 服务端sk,客户端对象
outputs = []
while True:
    # sk监听哪个对象,只要有变化,新连接来了,rlist = [sk], 否则rlist=[], 如果一个连接sk来了,rlist=[sk],如果两个sk1,sk2同时来了,rlist=【sk1,sk2】
    # 1 超时时间

    # 监听sk(服务器端)对象,如果sk对象发生变化,表示有客户端来连接了,此时rlist值为服务端[sk]
    # 监听conn对象,如果conn发生变化,表示客户端有新消息发送过来,此时rlist值为[客户端]
    rlist, wlist, e = select.select(inputs, outputs, [], 1)
    # wlist所有给我发消息的人
    # r就是sk
    # rlist = [hexm]
    # rlist = [zhuxij, ly]
    print(len(inputs), len(rlist), len(outputs)) # inputs里面多少对象, rlist表示多少客户端对象发生变化
    # 只有连接进来才for循环,不然rlist一直为空
    for r in rlist:
        if r == sk:
            # 新客户端来连接
            print(r)
            # conn 是socket对象, 每个客户端的socket对象
            conn, address = r.accept()
            conn.sendall(bytes(\'hello\', encoding=\'utf-8\'))
            # 新客户来连接,
            messages[conn] = []
            inputs.append(conn)
        else:
            print(\'-----\')
            try:
                ret = r.recv(1024)
                if not ret:  # 接受空消息,主动抛出异常,断开连接
                    raise Exception(\'断开连接\')
                else:
                    outputs.append(r)
                    # 客户发的消息加入这个客户的消息列表
                    messages[r].append(ret)
            # 有人发消息
            except Exception as e:
                # 断开连接,移除
                inputs.remove(r)
                del messages[r]  # 删除这个用户消息

    # 所有给我发过消息的人
    for w in wlist:
        msg = messages[w].pop()
        resp = msg + bytes(\'response\', encoding=\'utf-8\')
        w.sendall(resp)
        outputs.remove(w)
server
#!/usr/bin/env python
# coding=utf-8

import socket
sk = socket.socket()
sk.connect(("127.0.0.1", 9999))
data = sk.recv(1024)
print(data)
while True:
    ret = input(">>>")
    print(ret)
    sk.sendall(bytes(ret, encoding=\'utf-8\'))
    print(sk.recv(1024))
sk.close()
View Code

此处的Socket服务端相比与原生的Socket,他支持当某一个请求不再发送数据时,服务器端不会等待而是可以去处理其他请求的数据。但是,如果每个请求的耗时比较长时,select版本的服务器端也无法完成同时操作。

ThreadingTCPServer

 ThreadingTCPServer实现的socket服务器内部会为每个client创建一个线程,该线程用来和客户端进行交互。

使用ThreadingTCPServer

  • 创建一个继承自 SocketServer.BaseRequestHandler 的类
  • 类中必须定义一个名称为 handle 的方法
  • 启动ThreadingTCPServer
#!/usr/bin/env python
# coding=utf-8

import socketserver
import subprocess

class MyServer(socketserver.BaseRequestHandler):

    def handle(self):
        self.request.sendall(bytes(\'welcome\', encoding=\'utf-8\'))
        while True:
            data = self.request.recv(1024)
            if len(data) == 0: break
            print(self.client_address, data.decode())
            cmd = subprocess.Popen(data, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            cmd_res = cmd.stdout.read()
            if not cmd_res:
                cmd_res = cmd.stderr.read()
            if len(cmd_res) == 0:  # cmd has not output
                cmd_res = b"cmd has output"
            self.request.send(cmd_res)

if __name__ == \'__main__\':
    server = socketserver.ThreadingTCPServer((\'0.0.0.0\', 8000), MyServer)
    server.serve_forever()
ThreadingTCPServer
#!/usr/bin/env python
# coding=utf-8

import socket

ip_port = (\'127.0.0.1\', 8000)

s = socket.socket()
s.connect(ip_port)
welcome_msg = s.recv(1024)
print(welcome_msg.decode())
while True:
    send_data = input(">> input message: ").strip()
    if len(send_data) == 0: continue
    s.send(bytes(send_data, encoding=\'utf-8\'))  # 发送输入的命令
    tag = s.recv(1024)  # 接收tag
    tag = str(tag, encoding=\'utf-8\')
    if tag == \'exit\':
        break
    elif tag.startswith(\'Ready\'):  # 如果收到Ready和包长度
        msg_size = int(tag.split(\'|\')[-1])
    else:
        print(tag)
        continue
    start_tag = \'Start\'
    s.send(bytes(start_tag, encoding=\'utf-8\'))  # 发送开始标志
    recv_size = 0
    recv_data = b\'\'
    while recv_size < msg_size:
        recv_msg = s.recv(1024)
        recv_data += recv_msg
        recv_size += len(recv_msg)
        print(\'MSG SIZE %s RECV_SIZE %s\' % (msg_size, recv_size))

    print(str(recv_data, encoding = \'utf-8\'))
s.close()
View Code

源码剖析

ThreadingTCPServer的类关系图如下:

 

内部调用流程:

  • 执行TCPServer.__init__方法,创建服务端socket对象并绑定IP和端口并监听。
  • 执行BaseServer.__init__方法,将MyServer类赋值给self.RequestHandlerClass
  • 执行BaseServer.server_forever方法,while循环监听是否有客户端请求到达。
  • 当客户端连接到服务器,执行ThreadingMixIn.process_request 方法,创建一个 “线程” 用来处理请求
  • 执行 ThreadingMixIn.process_request_thread 方法
  • 执行 BaseServer.finish_request 方法,执行 self.RequestHandlerClass()  即:执行 自定义 MyRequestHandler 的构造方法(自动调用基类BaseRequestHandler的构造方法,在该构造方法中又会调用 MyRequestHandler的handle方法)

相关源码:

class TCPServer(BaseServer):

    """Base class for various socket-based server classes.

    Defaults to synchronous IP stream (i.e., TCP).

    Methods for the caller:

    - __init__(server_address, RequestHandlerClass, bind_and_activate=True)
    - serve_forever(poll_interval=0.5)
    - shutdown()
    - handle_request()  # if you don\'t use serve_forever()
    - fileno() -> int   # for selector

    Methods that may be overridden:

    - server_bind()
    - server_activate()
    - get_request() -> request, client_address
    - handle_timeout()
    - verify_request(request, client_address)
    - process_request(request, client_address)
    - shutdown_request(request)
    - close_request(request)
    - handle_error()

    Methods for derived classes:

    - finish_request(request, client_address)

    Class variables that may be overridden by derived classes or
    instances:

    - timeout
    - address_family
    - socket_type
    - request_queue_size (only for stream sockets)
    - allow_reuse_address

    Instance variables:

    - server_address
    - RequestHandlerClass
    - socket

    """

    address_family = socket.AF_INET

    socket_type = socket.SOCK_STREAM

    request_queue_size = 5

    allow_reuse_address = False

    def __init__(self, server_address, RequestHandlerClass, bind_and_activate=True):
        """Constructor.  May be extended, do not override."""
        BaseServer.__init__(self, server_address, RequestHandlerClass)
        self.socket = socket.socket(self.address_family,
                                    self.socket_type)
        if bind_and_activate:
            try:
                self.server_bind()
                self.server_activate()
            except:
                self.server_close()
                raise

    def server_bind(self):
        """Called by constructor to bind the socket.

        May be overridden.

        """
        if self.allow_reuse_address:
            self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.socket.bind(self.server_address)
        self.server_address = self.socket.getsockname()

    def server_activate(self):
        """Called by constructor to activate the server.

        May be overridden.

        """
        self.socket.listen(self.request_queue_size)

    def server_close(self):
        """Called to clean-up the server.

        May be overridden.

        """
        self.socket.close()

    def fileno(self):
        """Return socket file number.

        Interface required by selector.

        """
        return self.socket.fileno()

    def get_request(self):
        """Get the request and client address from the socket.

        May be overridden.

        """
        return self.socket.accept()

    def shutdown_request(self, request):
        """Called to shutdown and close an individual request."""
        try:
            #explicitly shutdown.  socket.close() merely releases
            #the socket and waits for GC to perform the actual close.
            request.shutdown(socket.SHUT_WR)
        except OSError:
            pass #some platforms may raise ENOTCONN here
        self.close_request(request)

    def close_request(self, request):
        """Called to clean up an individual request."""
        request.close()
TCPServer
class BaseServer:

    """Base class for server classes.

    Methods for the caller:

    - __init__(server_address, RequestHandlerClass)
    - serve_forever(poll_interval=0.5)
    - shutdown()
    - handle_request()  # if you do not use serve_forever()
    - fileno() -> int   # for selector

    Methods that may be overridden:

    - server_bind()
    - server_activate()
    - get_request() -> request, client_address
    - handle_timeout()
    - verify_request(request, client_address)
    - server_close()
    - process_request(request, client_address)
    - shutdown_request(request)
    - close_request(request)
    - service_actions()
    - handle_error()

    Methods for derived classes:

    - finish_request(request, client_address)

    Class variables that may be overridden by derived classes or
    instances:

    - timeout
    - address_family
    - socket_type
    - allow_reuse_address

    Instance variables:

    - RequestHandlerClass
    - socket

    """

    timeout = None

    def __init__(self, server_address, RequestHandlerClass):
        """Constructor.  May be extended, do not override."""
        self.server_address = server_address
        self.RequestHandlerClass = RequestHandlerClass
        self.__is_shut_down = threading.Event()
        self.__shutdown_request = False

    def server_activate(self):
        """Called by constructor to activate the server.

        May be overridden.

        """
        pass

    def serve_forever(self, poll_interval=0.5):
        """Handle one request at a time until shutdown.

        Polls for shutdown every poll_interval seconds. Ignores
        self.timeout. If you need to do periodic tasks, do them in
        another thread.
        """
        self.__is_shut_down.clear()
        try:
            # XXX: Consider using another file descriptor or connecting to the
            # socket to wake this up instead of polling. Polling reduces our
            # responsiveness to a shutdown request and wastes cpu at all other
            # times.
            with _ServerSelector() as selector:
                selector.register(self, selectors.EVENT_READ)

                while not self.__shutdown_request:
                    ready = selector.select(poll_interval)
                    if ready:
                        self._handle_request_noblock()

                    self.service_actions()
        finally:
            self.__shutdown_request = False
            self.__is_shut_down.set()

    def shutdown(self):
        """Stops the serve_forever loop.

        Blocks until the loop has finished. This must be called while
        serve_forever() is running in another thread, or it will
        deadlock.
        """
        self.__shutdown_request = True
        self.__is_shut_down.wait()

    def service_actions(self):
        """Called by the serve_forever() loop.

        May be overridden by a subclass / Mixin to implement any code that
        needs to be run during the loop.
        """
        pass

    # The distinction between handling, getting, processing and finishing a
    # request is fairly arbitrary.  Remember:
    #
    # - handle_request() is the top-level call.  It calls selector.select(),
    #   get_request(), verify_request() and process_request()
    # - get_request() is different for stream or datagram sockets
    # - process_request() is the place that may fork a new process or create a
    #   new thread to finish the request
    # - finish_request() instantiates the request handler class; this
    #   constructor will handle the request all by itself

    def handle_request(self):
        """Handle one request, possibly blocking.

        Respects self.timeout.
        """
        # Support people who used socket.settimeout() to escape
        # handle_request before self.timeout was available.
        timeout = self.socket.gettimeout()
        if timeout is None:
            timeout = self.timeout
        elif self.timeout is not None:
            timeout = min(timeout, self.timeout)
        if timeout is not None:
            deadline = time() + timeout

        # Wait until a request arrives or the timeout expires - the loop is
        # necessary to accommodate early wakeups due to EINTR.
        with _ServerSelector() as selector:
            selector.register(self, selectors.EVENT_READ)

            while True:
                ready = selector.select(timeout)
                if ready:
                    return self._handle_request_noblock()
                else:
                    if timeout is not None:
                        timeout = deadline - time()
                        if timeout < 0:
                            return self.handle_timeout()

    def _handle_request_noblock(self):
        """Handle one request, without blocking.

        I assume that selector.select() has returned that the socket is
        readable before this function was called, so there should be no risk of
        blocking in get_request().
        """
        try:
            request, client_address = self.get_request()
        except OSError:
            return
        if self.verify_request(request, client_address):
            try:
                self.process_request(request, client_address)
            except:
                self.handle_error(request, client_address)
                self.shutdown_request(request)
        else:
            self.shutdown_request(request)

    def handle_timeout(self):
        """Called if no new request arrives within self.timeout.

        Overridden by ForkingMixIn.
        """
        pass

    def verify_request(self, request, client_address):
        """Verify the request.  May be overridden.

        Return True if we should proceed with this request.

        """
        return True

    def process_request(self, request, client_address):
        """Call finish_request.

        Overridden by ForkingMixIn and ThreadingMixIn.

        """
        self.finish_request(request, client_address)
        self.shutdown_request(request)

    def server_close(self):
        """Called to clean-up the server.

        May be overridden.

        """
        pass

    def finish_request(self, request, client_address):
        """Finish one request by instantiating RequestHandlerClass."""
        self.RequestHandlerClass(request, client_address, self)

    def shutdown_request(self, request):
        """Called to shutdown and close an individual request."""
        self.close_request(request)

    def close_request(self, request):
        """Called to clean up an individual request."""
        pass

    def handle_error(self, request, client_address):
        """Handle an error gracefully.  May be overridden.

        The default is to print a traceback and continue.

        """
        print(\'-\'*40)
        print(\'Exception happened during processing of request from\', end=\' \')
        print(client_address)
        import traceback
        traceback.print_exc() # XXX But this goes to stderr!
        print(\'-\'*40)
BaseServer
class ThreadingMixIn:
    """Mix-in class to handle each request in a new thread."""

    # Decides how threads will act upon termination of the
    # main process
    daemon_threads = False

    def process_request_thread(self, request, client_address):
        """Same as in BaseServer but as a thread.

        In addition, exception handling is done here.

        """
        try:
            self.finish_request(request, client_address)
            self.shutdown_request(request)
        except:
            self.handle_error(request, client_address)
            self.shutdown_request(request)

    def process_request(self, request, client_address):
        """Start a new thread to process the request."""
        t = threading.Thread(target = self.process_request_thread,
                             args = (request, client_address))
        t.daemon = self.daemon_threads
        t.start()
ThreadingMixIn
class ThreadingTCPServer(ThreadingMixIn, TCPServer): pass
ThreadingTCPServer
class BaseRequestHandler:

    """Base class for request handler classes.

    This class is instantiated for each request to be handled.  The
    constructor sets the instance variables request, client_address
    and server, and then calls the handle() method.  To implement a
    specific service, all you need to do is to derive a class which
    defines a handle() method.

    The handle() method can find the request as self.request, the
    client address as self.client_address, and the server (in case it
    needs access to per-server information) as self.server.  Since a
    separate instance is created for each request, the handle() method
    can define other arbitrary instance variables.

    """

    def __init__(self, request, client_address, server):
        self.request = request
        self.client_address = client_address
        self.server = server
        self.setup()
        try:
            self.handle()
        finally:
            self.finish()

    def setup(self):
        pass

    def handle(self):
        pass

    def finish(self):
        pass
BaseRequestHandler

 

以上是关于IO多路复用及ThreadingTCPServer源码阅读的主要内容,如果未能解决你的问题,请参考以下文章

PythonSocket编程IO多路复用SocketServer

Python----Socket编程IO多路复用SocketServer

python-- IO多路复用(selectpollepoll)介绍及实现

20第七周-网络编程 - IO多路复用及selectpollepoll模式详解

PHP实现系统编程 --- 网络Socket及IO多路复用网摘

Linux IO多路复用之epoll网络编程及源码(转)