python 用Python实现的微型redis-like服务器

Posted

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了python 用Python实现的微型redis-like服务器相关的知识,希望对你有一定的参考价值。

# github: https://github.com/coleifer/simpledb

try:
    import gevent
    from gevent import socket
    from gevent.pool import Pool
    from gevent.server import StreamServer
except ImportError:
    import socket
    Pool = StreamServer = None

from collections import defaultdict
from collections import deque
from collections import namedtuple
from io import BytesIO
from socket import error as socket_error
import datetime
import heapq
import importlib
import json
import logging
import optparse
import sys
try:
    import socketserver as ss
except ImportError:
    import SocketServer as ss


# Version string of this module.
__version__ = '0.1.0'

# Module-level logger; handlers and level are attached by configure_logger().
logger = logging.getLogger(__name__)


class ThreadedStreamServer(object):
    """TCP server that handles each connection in its own thread.

    Exposes ``serve_forever()`` and invokes ``handler(request, client_address)``
    for every accepted connection.
    """

    def __init__(self, address, handler):
        """Remember the ``(host, port)`` address and the connection callback."""
        self.address, self.handler = address, handler

    def serve_forever(self):
        """Build the socketserver-based server and serve until it stops."""
        # Bind the callback to a local so the nested handler class closes over
        # it instead of looking it up on the request handler's own ``self``.
        callback = self.handler

        class RequestHandler(ss.BaseRequestHandler):
            def handle(self):
                return callback(self.request, self.client_address)

        class ThreadedServer(ss.ThreadingMixIn, ss.TCPServer):
            allow_reuse_address = True

        server = ThreadedServer(self.address, RequestHandler)
        self.stream_server = server
        server.serve_forever()


class CommandError(Exception):
    """Raised when a command is malformed or cannot be carried out.

    The description is stored on ``message`` and is also forwarded to
    ``Exception`` so that ``str(exc)``, ``exc.args`` and tracebacks show it.

    :param message: human-readable description of the failure.
    """
    def __init__(self, message):
        self.message = message
        super(CommandError, self).__init__(message)


"""
Protocol is based on Redis wire protocol.
Client sends requests as an array of bulk strings.
Server replies, indicating response type using the first byte:
* "+" - simple string
* "-" - error
* ":" - integer
* "$" - bulk string
* "^" - bulk unicode string
* "@" - json string (uses bulk string rules)
* "*" - array
* "%" - dict
Simple strings: "+string content\r\n"  <-- cannot contain newlines
Error: "-Error message\r\n"
Integers: ":1337\r\n"
Bulk String: "$number of bytes\r\nstring data\r\n"
* Empty string: "$0\r\n\r\n"
* NULL: "$-1\r\n"
Array: "*number of elements\r\n...elements..."
* Empty array: "*0\r\n"
And a new data-type, dictionaries: "%number of elements\r\n...elements..."
"""
# Python 3 has no ``unicode`` or ``basestring`` builtins; alias them so the
# rest of the module can use the Python 2 names on either interpreter.
if sys.version_info[0] == 3:
    unicode = str
    basestring = (bytes, str)


def encode(s):
    """Return ``s`` as bytes.

    Bytes are returned unchanged, unicode text is UTF-8 encoded, and any other
    object is converted with ``str()`` and then UTF-8 encoded.
    """
    if isinstance(s, bytes):
        return s
    text = s if isinstance(s, unicode) else str(s)
    return text.encode('utf-8')


def decode(s):
    """Return ``s`` as text.

    Bytes are decoded as UTF-8, unicode text is returned unchanged, and any
    other object is converted with ``str()``.
    """
    if isinstance(s, bytes):
        return s.decode('utf-8')
    return s if isinstance(s, unicode) else str(s)


# Value type for an error reply: produced when parsing a "-" message and
# serialized back with the "-" prefix.
Error = namedtuple('Error', ('message',))


class ProtocolHandler(object):
    """Parse and serialize values using the Redis-like wire protocol.

    Every message starts with a one-byte type prefix (``+ - : $ ^ @ * %``).
    ``handle_request`` reads one complete value from a binary file-like object
    and ``write_response`` writes one complete value to it.
    """

    def __init__(self):
        # Dispatch table: type-prefix byte -> parser for the rest of the value.
        self.handlers = {
            b'+': self.handle_simple_string,
            b'-': self.handle_error,
            b':': self.handle_integer,
            b'$': self.handle_string,
            b'^': self.handle_unicode,
            b'@': self.handle_json,
            b'*': self.handle_array,
            b'%': self.handle_dict,
        }

    def handle_simple_string(self, socket_file):
        """Read a single line and return it without the trailing CRLF."""
        return socket_file.readline().rstrip(b'\r\n')

    def handle_error(self, socket_file):
        """Read an error line and return it wrapped in ``Error``."""
        return Error(socket_file.readline().rstrip(b'\r\n'))

    def handle_integer(self, socket_file):
        """Read a numeric line and return an ``int``, or a ``float``.

        Anything ``int()`` rejects (``1.5``, ``1e+20``, ``inf``) is parsed as
        a float, so every float written by ``_write`` can be read back.

        :raises ValueError: if the line is not a number at all.
        """
        number = socket_file.readline().rstrip(b'\r\n')
        try:
            return int(number)
        except ValueError:
            return float(number)

    def handle_string(self, socket_file):
        """Read a length-prefixed bulk string; a length of -1 yields None."""
        length = int(socket_file.readline().rstrip(b'\r\n'))
        if length == -1:
            return None
        # The payload is followed by a CRLF which is read and then discarded.
        length += 2
        return socket_file.read(length)[:-2]

    def handle_unicode(self, socket_file):
        """Read a bulk string and decode it as UTF-8 (None for NULL)."""
        raw = self.handle_string(socket_file)
        return None if raw is None else raw.decode('utf-8')

    def handle_json(self, socket_file):
        """Read a bulk string and parse it as JSON (None for NULL)."""
        raw = self.handle_string(socket_file)
        return None if raw is None else json.loads(raw)

    def handle_array(self, socket_file):
        """Read an element count, then that many nested values, as a list."""
        num_elements = int(socket_file.readline().rstrip(b'\r\n'))
        return [self.handle_request(socket_file) for _ in range(num_elements)]

    def handle_dict(self, socket_file):
        """Read an item count, then alternating keys and values, as a dict."""
        num_items = int(socket_file.readline().rstrip(b'\r\n'))
        elements = [self.handle_request(socket_file)
                    for _ in range(num_items * 2)]
        return dict(zip(elements[::2], elements[1::2]))

    def handle_request(self, socket_file):
        """Read and return one complete value from ``socket_file``.

        A line whose first byte is not a known type prefix is returned as raw
        bytes (prefix byte included, CRLF stripped).

        :raises EOFError: if the stream is exhausted.
        """
        first_byte = socket_file.read(1)
        if not first_byte:
            raise EOFError()

        # Look the handler up explicitly so that a KeyError raised *inside* a
        # handler is not mistaken for an unknown type prefix.
        handler = self.handlers.get(first_byte)
        if handler is None:
            rest = socket_file.readline().rstrip(b'\r\n')
            return first_byte + rest
        return handler(socket_file)

    def write_response(self, socket_file, data):
        """Serialize ``data`` and send it in a single write, then flush.

        The value is fully serialized into a buffer first, so nothing reaches
        ``socket_file`` if serialization fails.
        """
        buf = BytesIO()
        self._write(buf, data)
        socket_file.write(buf.getvalue())
        socket_file.flush()

    def _write(self, buf, data):
        """Append the wire representation of ``data`` to ``buf``.

        :raises TypeError: if ``data`` (or a nested item) has no wire
            representation; writing nothing would leave an enclosing array or
            dict short of the element count already announced.
        """
        if isinstance(data, bytes):
            buf.write(b'$%d\r\n%s\r\n' % (len(data), data))
        elif data is True or data is False:
            # Booleans are checked before ints because bool subclasses int.
            buf.write(b':%d\r\n' % (1 if data else 0))
        elif isinstance(data, float):
            # repr() keeps the fractional part; '%d' would truncate it.
            buf.write(b':' + repr(data).encode('ascii') + b'\r\n')
        elif isinstance(data, int):
            buf.write(b':%d\r\n' % data)
        elif data is None:
            buf.write(b'$-1\r\n')
        elif isinstance(data, unicode):
            bdata = data.encode('utf-8')
            buf.write(b'^%d\r\n%s\r\n' % (len(bdata), bdata))
        elif isinstance(data, Error):
            # Error is a namedtuple, so it must be matched before tuples.
            buf.write(b'-%s\r\n' % encode(data.message))
        elif isinstance(data, (list, tuple)):
            buf.write(b'*%d\r\n' % len(data))
            for item in data:
                self._write(buf, item)
        elif isinstance(data, dict):
            buf.write(b'%%%d\r\n' % len(data))
            for key in data:
                self._write(buf, key)
                self._write(buf, data[key])
        elif isinstance(data, datetime.datetime):
            self._write(buf, str(data))
        else:
            raise TypeError('Cannot serialize value of type %s'
                            % type(data).__name__)


class Shutdown(Exception):
    """Raised by the SHUTDOWN command to signal that the server should stop."""


class QueueServer(object):
    """In-memory data server speaking a Redis-like protocol.

    Keeps queues, key/value pairs, hashes, sets and a timestamp-ordered
    schedule in process memory and dispatches client requests to the command
    table built by ``get_commands``.
    """

    def __init__(self, host='127.0.0.1', port=31337, max_clients=64,
                 use_gevent=True):
        """Create the underlying stream server and the empty data stores.

        :param host: host of the address to serve on.
        :param port: port of the address to serve on.
        :param max_clients: size of the gevent pool (only used with gevent).
        :param use_gevent: use gevent's ``StreamServer`` when true, otherwise
            ``ThreadedStreamServer``.
        :raises Exception: if ``use_gevent`` is true but gevent is missing.
        """
        self._host = host
        self._port = port
        self._max_clients = max_clients
        if use_gevent:
            if Pool is None or StreamServer is None:
                raise Exception('gevent not installed. Please install gevent '
                                'or instantiate QueueServer with '
                                'use_gevent=False')

            self._pool = Pool(max_clients)
            self._server = StreamServer(
                (self._host, self._port),
                self.connection_handler,
                spawn=self._pool)
        else:
            self._server = ThreadedStreamServer(
                (self._host, self._port),
                self.connection_handler)

        self._commands = self.get_commands()
        self._protocol = ProtocolHandler()

        self._hashes = defaultdict(dict)
        self._kv = {}
        self._queues = defaultdict(deque)
        self._schedule = []
        self._sets = defaultdict(set)

        # Counters reported by the INFO command.
        self._active_connections = 0
        self._commands_processed = 0
        self._command_errors = 0
        self._connections = 0

    def get_commands(self):
        """Return the mapping of upper-case command name (bytes) to callable."""
        return dict((
            # Queue commands.
            (b'LPUSH', self.lpush),
            (b'RPUSH', self.rpush),
            (b'LPOP', self.lpop),
            (b'RPOP', self.rpop),
            (b'LREM', self.lrem),
            (b'LLEN', self.llen),
            (b'LINDEX', self.lindex),
            (b'LRANGE', self.lrange),
            (b'LSET', self.lset),
            (b'LTRIM', self.ltrim),
            (b'RPOPLPUSH', self.rpoplpush),
            (b'LFLUSH', self.lflush),

            # K/V commands.
            (b'APPEND', self.kv_append),
            (b'DECR', self.kv_decr),
            (b'DECRBY', self.kv_decrby),
            (b'DELETE', self.kv_delete),
            (b'EXISTS', self.kv_exists),
            (b'GET', self.kv_get),
            (b'GETSET', self.kv_getset),
            (b'INCR', self.kv_incr),
            (b'INCRBY', self.kv_incrby),
            (b'MGET', self.kv_mget),
            (b'MPOP', self.kv_mpop),
            (b'MSET', self.kv_mset),
            (b'POP', self.kv_pop),
            (b'SET', self.kv_set),
            (b'SETNX', self.kv_setnx),
            (b'LEN', self.kv_len),
            (b'FLUSH', self.kv_flush),

            # Hash commands.
            (b'HDEL', self.hdel),
            (b'HEXISTS', self.hexists),
            (b'HGET', self.hget),
            (b'HGETALL', self.hgetall),
            (b'HINCRBY', self.hincrby),
            (b'HKEYS', self.hkeys),
            (b'HLEN', self.hlen),
            (b'HMGET', self.hmget),
            (b'HMSET', self.hmset),
            (b'HSET', self.hset),
            (b'HSETNX', self.hsetnx),
            (b'HVALS', self.hvals),

            # Set commands.
            (b'SADD', self.sadd),
            (b'SCARD', self.scard),
            (b'SDIFF', self.sdiff),
            (b'SDIFFSTORE', self.sdiffstore),
            (b'SINTER', self.sinter),
            (b'SINTERSTORE', self.sinterstore),
            (b'SISMEMBER', self.sismember),
            (b'SMEMBERS', self.smembers),
            (b'SPOP', self.spop),
            (b'SREM', self.srem),
            (b'SUNION', self.sunion),
            (b'SUNIONSTORE', self.sunionstore),

            # Schedule commands.
            (b'ADD', self.schedule_add),
            (b'READ', self.schedule_read),
            (b'FLUSH_SCHEDULE', self.schedule_flush),
            (b'LENGTH_SCHEDULE', self.schedule_length),

            # Misc.
            (b'INFO', self.info),
            (b'FLUSHALL', self.flush_all),
            (b'SHUTDOWN', self.shutdown),
        ))

    # -- Queue commands ----------------------------------------------------

    def lpush(self, queue, *values):
        """Push values onto the left end one at a time; return their count."""
        self._queues[queue].extendleft(values)
        return len(values)

    def rpush(self, queue, *values):
        """Append values to the right end; return how many were given."""
        self._queues[queue].extend(values)
        return len(values)

    def lpop(self, queue):
        """Remove and return the leftmost item, or None if the queue is empty."""
        try:
            return self._queues[queue].popleft()
        except IndexError:
            return None

    def rpop(self, queue):
        """Remove and return the rightmost item, or None if the queue is empty."""
        try:
            return self._queues[queue].pop()
        except IndexError:
            return None

    def lrem(self, queue, value):
        """Remove the first occurrence of value; return 1 if removed else 0."""
        try:
            self._queues[queue].remove(value)
        except ValueError:
            return 0
        else:
            return 1

    def llen(self, queue):
        """Return the number of items in the queue."""
        return len(self._queues[queue])

    def lindex(self, queue, idx):
        """Return the item at idx, or None when idx is out of range."""
        try:
            return self._queues[queue][idx]
        except IndexError:
            return None

    def lset(self, queue, idx, value):
        """Replace the item at idx; return 1 on success, 0 if out of range."""
        try:
            self._queues[queue][idx] = value
        except IndexError:
            return 0
        else:
            return 1

    def ltrim(self, queue, start, stop):
        """Keep only the slice ``[start:stop]`` of the queue."""
        self._queues[queue] = deque(list(self._queues[queue])[start:stop])

    def rpoplpush(self, src, dest):
        """Move the rightmost item of src to the left of dest.

        Returns 1 if an item was moved, 0 if src was empty.
        """
        try:
            self._queues[dest].appendleft(self._queues[src].pop())
        except IndexError:
            return 0
        else:
            return 1

    def lrange(self, queue, start, end=None):
        """Return the slice ``[start:end]`` of the queue as a list."""
        return list(self._queues[queue])[start:end]

    def lflush(self, queue):
        """Empty the queue and return how many items it held."""
        qlen = self.llen(queue)
        self._queues[queue].clear()
        return qlen

    # -- Key/value commands ------------------------------------------------

    def kv_append(self, key, value):
        """Concatenate value onto the stored value and return the result.

        A missing key is initialised with ``value`` itself, so the operation
        works for both bytes and text instead of assuming ``''`` as the seed.
        """
        if key in self._kv:
            self._kv[key] += value
        else:
            self._kv[key] = value
        return self._kv[key]

    def _kv_incr(self, key, n):
        """Add n to the value at key (treated as 0 if missing); return it."""
        self._kv.setdefault(key, 0)
        self._kv[key] += n
        return self._kv[key]

    def kv_decr(self, key):
        """Decrement the value at key by one and return the new value."""
        return self._kv_incr(key, -1)

    def kv_decrby(self, key, n):
        """Decrement the value at key by n and return the new value."""
        return self._kv_incr(key, -1 * n)

    def kv_delete(self, key):
        """Delete key; return 1 if it existed, otherwise 0."""
        if key in self._kv:
            del self._kv[key]
            return 1
        return 0

    def kv_exists(self, key):
        """Return 1 if key is present, otherwise 0."""
        return 1 if key in self._kv else 0

    def kv_get(self, key):
        """Return the value at key, or None if it is missing."""
        return self._kv.get(key)

    def kv_getset(self, key, value):
        """Store value at key and return the previous value (or None)."""
        orig = self._kv.get(key)
        self._kv[key] = value
        return orig

    def kv_incr(self, key):
        """Increment the value at key by one and return the new value."""
        return self._kv_incr(key, 1)

    def kv_incrby(self, key, n):
        """Increment the value at key by n and return the new value."""
        return self._kv_incr(key, n)

    def kv_mget(self, *keys):
        """Return the values for keys, with None for missing ones."""
        return [self._kv.get(key) for key in keys]

    def kv_mpop(self, *keys):
        """Remove keys and return their values, with None for missing ones."""
        return [self._kv.pop(key, None) for key in keys]

    def kv_mset(self, __data=None, **kwargs):
        """Store every pair from the mapping and/or keyword arguments.

        Returns the number of pairs supplied.
        """
        n = 0
        if __data is not None:
            self._kv.update(__data)
            n += len(__data)
        if kwargs:
            self._kv.update(kwargs)
            n += len(kwargs)
        return n

    def kv_pop(self, key):
        """Remove key and return its value, or None if it is missing."""
        return self._kv.pop(key, None)

    def kv_set(self, key, value):
        """Store value at key and return 1."""
        self._kv[key] = value
        return 1

    def kv_setnx(self, key, value):
        """Store value only if key is absent; return 1 if stored else 0."""
        if key in self._kv:
            return 0
        self._kv[key] = value
        return 1

    def kv_len(self):
        """Return the number of stored keys."""
        return len(self._kv)

    def kv_flush(self):
        """Remove every key and return how many there were."""
        kvlen = self.kv_len()
        self._kv.clear()
        return kvlen

    def _decode_timestamp(self, timestamp):
        """Parse ``Y-m-d H:M:S[.f]`` (bytes or text) into a datetime.

        :raises CommandError: if the timestamp does not match the format.
        """
        timestamp = decode(timestamp)
        fmt = '%Y-%m-%d %H:%M:%S'
        if '.' in timestamp:
            fmt = fmt + '.%f'
        try:
            return datetime.datetime.strptime(timestamp, fmt)
        except ValueError:
            raise CommandError('Timestamp must be formatted Y-m-d H:M:S')

    # -- Hash commands -----------------------------------------------------

    def hdel(self, key, field):
        """Delete field from the hash; return 1 if it existed else 0."""
        if self.hexists(key, field):
            del self._hashes[key][field]
            return 1
        return 0

    def hexists(self, key, field):
        """Return 1 if field is present in the hash, otherwise 0."""
        return 1 if field in self._hashes[key] else 0

    def hget(self, key, field):
        """Return the value of field in the hash, or None if missing."""
        return self._hashes[key].get(field)

    def hgetall(self, key):
        """Return the dict stored at key (the internal object, not a copy)."""
        return self._hashes[key]

    def hincrby(self, key, field, incr=1):
        """Add incr to field (treated as 0 if missing); return the new value."""
        self._hashes[key].setdefault(field, 0)
        self._hashes[key][field] += incr
        return self._hashes[key][field]

    def hkeys(self, key):
        """Return the list of fields in the hash."""
        return list(self._hashes[key])

    def hlen(self, key):
        """Return the number of fields in the hash."""
        return len(self._hashes[key])

    def hmget(self, key, *fields):
        """Return a dict of field -> value (None for missing fields)."""
        accum = {}
        for field in fields:
            accum[field] = self._hashes[key].get(field)
        return accum

    def hmset(self, key, data):
        """Merge the mapping into the hash; return the number of pairs given."""
        self._hashes[key].update(data)
        return len(data)

    def hset(self, key, field, value):
        """Set field in the hash and return 1."""
        self._hashes[key][field] = value
        return 1

    def hsetnx(self, key, field, value):
        """Set field only if it is absent; return 1 if set else 0."""
        if field not in self._hashes[key]:
            self._hashes[key][field] = value
            return 1
        return 0

    def hvals(self, key):
        """Return the list of values in the hash."""
        return list(self._hashes[key].values())

    # -- Set commands ------------------------------------------------------

    def sadd(self, key, *members):
        """Add members to the set and return the resulting set size."""
        self._sets[key].update(members)
        return len(self._sets[key])

    def scard(self, key):
        """Return the number of members in the set."""
        return len(self._sets[key])

    def sdiff(self, key, *keys):
        """Return members of the first set that are in none of the others."""
        src = set(self._sets[key])
        for other in keys:
            src -= self._sets[other]
        return list(src)

    def sdiffstore(self, dest, key, *keys):
        """Store the difference of the sets at dest and return its size."""
        src = set(self._sets[key])
        for other in keys:
            src -= self._sets[other]
        self._sets[dest] = src
        return len(src)

    def sinter(self, key, *keys):
        """Return the members common to all the given sets."""
        src = set(self._sets[key])
        for other in keys:
            src &= self._sets[other]
        return list(src)

    def sinterstore(self, dest, key, *keys):
        """Store the intersection of the sets at dest and return its size."""
        src = set(self._sets[key])
        for other in keys:
            src &= self._sets[other]
        self._sets[dest] = src
        return len(src)

    def sismember(self, key, member):
        """Return 1 if member is in the set, otherwise 0."""
        return 1 if member in self._sets[key] else 0

    def smembers(self, key):
        """Return the members of the set as a list."""
        return list(self._sets[key])

    def spop(self, key, n=1):
        """Remove and return up to n arbitrary members as a list."""
        accum = []
        for _ in range(n):
            try:
                accum.append(self._sets[key].pop())
            except KeyError:
                # The set ran out of members before n were popped.
                break
        return accum

    def srem(self, key, *members):
        """Remove members from the set; return how many were present."""
        ct = 0
        for member in members:
            try:
                self._sets[key].remove(member)
            except KeyError:
                pass
            else:
                ct += 1
        return ct

    def sunion(self, key, *keys):
        """Return the members present in any of the given sets."""
        src = set(self._sets[key])
        for other in keys:
            src |= self._sets[other]
        return list(src)

    def sunionstore(self, dest, key, *keys):
        """Store the union of the sets at dest and return its size."""
        src = set(self._sets[key])
        for other in keys:
            src |= self._sets[other]
        self._sets[dest] = src
        return len(src)

    # -- Schedule commands -------------------------------------------------

    def schedule_add(self, timestamp, data):
        """Schedule data for the given timestamp and return 1.

        :raises CommandError: if the timestamp cannot be parsed.
        """
        dt = self._decode_timestamp(timestamp)
        heapq.heappush(self._schedule, (dt, data))
        return 1

    def schedule_read(self, timestamp=None):
        """Pop and return every item scheduled at or before ``timestamp``.

        When ``timestamp`` is omitted the current local time is used; the
        previous behaviour tried to parse ``None`` and always failed.

        :raises CommandError: if a supplied timestamp cannot be parsed.
        """
        if timestamp is None:
            dt = datetime.datetime.now()
        else:
            dt = self._decode_timestamp(timestamp)
        accum = []
        while self._schedule and self._schedule[0][0] <= dt:
            _, data = heapq.heappop(self._schedule)
            accum.append(data)
        return accum

    def schedule_flush(self):
        """Discard all scheduled items and return how many there were."""
        schedulelen = self.schedule_length()
        self._schedule = []
        return schedulelen

    def schedule_length(self):
        """Return the number of scheduled items."""
        return len(self._schedule)

    # -- Misc commands -----------------------------------------------------

    def info(self):
        """Return the connection and command counters as a dict."""
        return {
            'active_connections': self._active_connections,
            'commands_processed': self._commands_processed,
            'command_errors': self._command_errors,
            'connections': self._connections}

    def flush_all(self):
        """Clear every data store and return 1."""
        self._hashes = defaultdict(dict)
        self._queues = defaultdict(deque)
        self._sets = defaultdict(set)
        self.kv_flush()
        self.schedule_flush()
        return 1

    def add_command(self, command, callback):
        """Register callback under command (text names are UTF-8 encoded)."""
        if isinstance(command, unicode):
            command = command.encode('utf-8')
        self._commands[command] = callback

    def shutdown(self):
        """Signal server shutdown by raising ``Shutdown``."""
        raise Shutdown('shutting down')

    def run(self):
        """Serve connections until the underlying server stops."""
        self._server.serve_forever()

    # -- Connection handling -----------------------------------------------

    def connection_handler(self, conn, address):
        """Serve one client connection until it disconnects.

        :param conn: connected socket object.
        :param address: peer address; the first two items are host and port.
        """
        host, port = address[0], address[1]
        logger.info('Connection received: %s:%s', host, port)
        socket_file = conn.makefile('rwb')
        self._connections += 1
        self._active_connections += 1
        try:
            while True:
                try:
                    self.request_response(socket_file)
                except EOFError:
                    logger.info('Client went away: %s:%s', host, port)
                    break
                except socket_error:
                    # A broken connection cannot recover; retrying would spin
                    # forever on the same error.
                    logger.info('Connection lost: %s:%s', host, port)
                    break
                except Exception:
                    logger.exception('Error processing command.')
        finally:
            # Runs on every exit path, including the KeyboardInterrupt used
            # for shutdown, so the active-connection count stays accurate.
            self._active_connections -= 1
            try:
                socket_file.close()
            except socket_error:
                # Closing flushes buffered output, which can fail on a dead
                # connection; there is nothing left to do about it.
                pass

    def request_response(self, socket_file):
        """Read one request, execute it and write back the reply.

        :raises EOFError: if the client closed the connection.
        :raises KeyboardInterrupt: after acknowledging a SHUTDOWN command.
        """
        data = self._protocol.handle_request(socket_file)

        try:
            resp = self.respond(data)
        except Shutdown:
            logger.info('Shutting down')
            self._protocol.write_response(socket_file, 1)
            raise KeyboardInterrupt()
        except CommandError as command_error:
            resp = Error(command_error.message)
            self._command_errors += 1
        except Exception:
            logger.exception('Unhandled error')
            resp = Error('Unhandled server error')
        else:
            self._commands_processed += 1

        self._protocol.write_response(socket_file, resp)

    def respond(self, data):
        """Dispatch a parsed request to its command and return the result.

        :param data: list of ``[command, arg, ...]``, or a whitespace
            separated string that is split into such a list.
        :raises CommandError: for an unusable request, an empty request, a
            non-string command name or an unknown command.
        """
        if not isinstance(data, list):
            try:
                data = data.split()
            except AttributeError:
                raise CommandError('Unrecognized request type.')

        if not data:
            raise CommandError('Missing command name.')

        if not isinstance(data[0], basestring):
            raise CommandError('First parameter must be command name.')

        # Command keys are bytes, so text names are encoded before lookup.
        command = encode(data[0]).upper()
        if command not in self._commands:
            raise CommandError('Unrecognized command: %s' % decode(command))
        logger.debug('Received %s', decode(command))

        return self._commands[command](*data[1:])


class Client(object):
    """Blocking client for the server.

    Each command is sent as an array of arguments and the reply is read
    synchronously. The connection is opened lazily on the first command.
    """

    def __init__(self, host='127.0.0.1', port=31337):
        """Remember the server address; no connection is made yet."""
        self._host = host
        self._port = port
        self._socket = None
        self._fh = None
        self._protocol = ProtocolHandler()
        self._is_closed = True

    def connect(self):
        """Open a new connection, closing any existing one first.

        Returns True. If connecting fails the new socket is closed and the
        error is re-raised.
        """
        if not self._is_closed:
            self.close()

        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            sock.connect((self._host, self._port))
        except Exception:
            # Do not leak the descriptor when the server is unreachable.
            sock.close()
            raise
        self._socket = sock
        self._fh = sock.makefile('rwb')
        self._is_closed = False
        return True

    def close(self):
        """Close the connection.

        Returns False if it was already closed, otherwise True. The file
        object created by ``makefile`` is closed as well: it holds its own
        reference to the socket, so closing only the socket object would
        leave the connection open.
        """
        if self._is_closed:
            return False

        try:
            if self._fh is not None:
                self._fh.close()
        except socket.error:
            # Flushing buffered data to a dead peer may fail; the connection
            # is being discarded anyway.
            pass
        finally:
            self._socket.close()
            self._is_closed = True
        return True

    def execute(self, *args):
        """Send ``args`` as one request and return the decoded reply.

        :raises CommandError: if the server replied with an error.
        """
        if self._is_closed:
            self.connect()
        self._protocol.write_response(self._fh, args)
        resp = self._protocol.handle_request(self._fh)
        if isinstance(resp, Error):
            raise CommandError(resp.message)
        return resp

    def command(cmd):
        """Build a method that executes ``cmd`` with the caller's arguments."""
        def method(self, *args):
            return self.execute(cmd.encode('utf-8'), *args)
        return method

    lpush = command('LPUSH')
    rpush = command('RPUSH')
    lpop = command('LPOP')
    rpop = command('RPOP')
    lrem = command('LREM')
    llen = command('LLEN')
    lindex = command('LINDEX')
    lrange = command('LRANGE')
    lset = command('LSET')
    ltrim = command('LTRIM')
    rpoplpush = command('RPOPLPUSH')
    lflush = command('LFLUSH')

    append = command('APPEND')
    decr = command('DECR')
    decrby = command('DECRBY')
    delete = command('DELETE')
    exists = command('EXISTS')
    get = command('GET')
    getset = command('GETSET')
    incr = command('INCR')
    incrby = command('INCRBY')
    mget = command('MGET')
    mpop = command('MPOP')
    mset = command('MSET')
    pop = command('POP')
    set = command('SET')
    setnx = command('SETNX')
    length = command('LEN')
    flush = command('FLUSH')

    hdel = command('HDEL')
    hexists = command('HEXISTS')
    hget = command('HGET')
    hgetall = command('HGETALL')
    hincrby = command('HINCRBY')
    hkeys = command('HKEYS')
    hlen = command('HLEN')
    hmget = command('HMGET')
    hmset = command('HMSET')
    hset = command('HSET')
    hsetnx = command('HSETNX')
    hvals = command('HVALS')

    sadd = command('SADD')
    scard = command('SCARD')
    sdiff = command('SDIFF')
    sdiffstore = command('SDIFFSTORE')
    sinter = command('SINTER')
    sinterstore = command('SINTERSTORE')
    sismember = command('SISMEMBER')
    smembers = command('SMEMBERS')
    spop = command('SPOP')
    srem = command('SREM')
    sunion = command('SUNION')
    sunionstore = command('SUNIONSTORE')

    add = command('ADD')
    read = command('READ')
    flush_schedule = command('FLUSH_SCHEDULE')
    length_schedule = command('LENGTH_SCHEDULE')

    info = command('INFO')
    flushall = command('FLUSHALL')
    shutdown = command('SHUTDOWN')


def get_option_parser():
    """Build the ``optparse`` parser for the server's command-line options."""
    # Each entry is (flags, keyword arguments) for one ``add_option`` call,
    # listed in the order the options appear in the help output.
    option_specs = (
        (('-d', '--debug'),
         dict(action='store_true', dest='debug',
              help='Log debug messages.')),
        (('-e', '--errors'),
         dict(action='store_true', dest='error',
              help='Log error messages only.')),
        (('-t', '--use-threads'),
         dict(action='store_false', default=True, dest='use_gevent',
              help='Use threads instead of gevent.')),
        (('-H', '--host'),
         dict(default='127.0.0.1', dest='host',
              help='Host to listen on.')),
        (('-m', '--max-clients'),
         dict(default=64, dest='max_clients',
              help='Maximum number of clients.', type=int)),
        (('-p', '--port'),
         dict(default=31337, dest='port',
              help='Port to listen on.', type=int)),
        (('-l', '--log-file'),
         dict(dest='log_file', help='Log file.')),
        (('-x', '--extension'),
         dict(action='append', dest='extensions',
              help='Import path for Python extension module(s).')),
    )

    parser = optparse.OptionParser()
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    return parser


def configure_logger(options):
    """Attach handlers to the module logger and set its level.

    A stream handler is always added; a file handler is added when
    ``options.log_file`` is set. The level is DEBUG with ``options.debug``,
    ERROR with ``options.error``, and INFO otherwise.
    """
    logger.addHandler(logging.StreamHandler())
    log_path = options.log_file
    if log_path:
        logger.addHandler(logging.FileHandler(log_path))

    level = logging.INFO
    if options.debug:
        level = logging.DEBUG
    elif options.error:
        level = logging.ERROR
    logger.setLevel(level)


def load_extensions(server, extensions):
    """Import each extension module and call its ``initialize(server)``.

    A module that cannot be imported is logged and skipped. A module without
    an ``initialize`` attribute is logged and the ``AttributeError`` is
    re-raised.
    """
    for name in extensions:
        try:
            module = importlib.import_module(name)
        except ImportError:
            logger.exception('Could not import extension %s' % name)
            continue

        try:
            initialize = module.initialize
        except AttributeError:
            logger.exception('Could not find "initialize" function in '
                             'extension %s' % name)
            raise

        # Called outside the try blocks so errors raised by the extension
        # itself are not mistaken for a missing module or hook.
        initialize(server)
        logger.info('Loaded %s extension.' % name)


if __name__ == '__main__':
    # Command-line entry point; positional arguments are parsed but not used.
    options, args = get_option_parser().parse_args()

    if options.use_gevent:
        # Apply gevent's monkey patches before the server is created.
        from gevent import monkey; monkey.patch_all()

    configure_logger(options)
    server = QueueServer(host=options.host, port=options.port,
                         max_clients=options.max_clients,
                         use_gevent=options.use_gevent)
    # ``options.extensions`` is None when no -x flag was given.
    load_extensions(server, options.extensions or ())
    server.run()

以上是关于python 用Python实现的微型redis-like服务器的主要内容,如果未能解决你的问题,请参考以下文章

python中,用Redis构建分布式锁

Python开发桌面微型计算器

微型 Python Web 框架: Bottle

python学习笔记4-redis multi watch实现锁库存

markdown 使用Python的启动微型的HttpServer

Redis(二十二)-秒杀案例的基本实现以及用ab工具模拟并发