# This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation; either version 2 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Library General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. # # Copyright 2005 Dan Williams and Red Hat, Inc. import os, sys import socket import time import SocketServer import xmlrpclib import SimpleXMLRPCServer import SSLCommon import threading import OpenSSL """ So we need a way of getting data from the RequestHandler instance, which is unique to each request, into the actual handler instance that we've registered. To do this without overriding a bunch of member functions, we have a global dict of data that maps authorization info to a particular thread, since each request is handled in its own thread. Therefore, the handler instance can access the data that our request handler has set. """ __authinfos = {} def _add_authinfo(authinfo): __authinfos[threading.currentThread()] = authinfo def get_authinfo(): return __authinfos.get(threading.currentThread()) def _del_authinfo(): del __authinfos[threading.currentThread()] class AuthedSimpleXMLRPCRequestHandler(SimpleXMLRPCServer.SimpleXMLRPCRequestHandler): # For some reason, httplib closes the connection right after headers # have been sent if the connection is _not_ HTTP/1.1, which results in # a "Bad file descriptor" error when the client tries to read from the socket protocol_version = "HTTP/1.1" def do_POST(self): authinfo = self.server.get_authinfo(self.request, self.client_address) _add_authinfo(authinfo) try: SimpleXMLRPCServer.SimpleXMLRPCRequestHandler.do_POST(self) except socket.timeout: pass except (socket.error, OpenSSL.SSL.SysCallError), e: print "Error (%s): socket error - '%s'" % (self.client_address, e) _del_authinfo() class BaseAuthedXMLRPCServer: def __init__(self, address, authinfo_callback=None): self.allow_reuse_address = 1 self.logRequests = 0 self.authinfo_callback = authinfo_callback if sys.version_info[:3] > (2, 2, 3): if sys.version_info[:2] < (2, 5): SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self) else: SimpleXMLRPCServer.SimpleXMLRPCDispatcher.__init__(self, False, None) else: self.funcs = {} self.instance = None def get_authinfo(self, request, client_address): if self.authinfo_callback: return self.authinfo_callback(request, client_address) return None class AuthedSSLXMLRPCServer(BaseAuthedXMLRPCServer, SSLCommon.PlgBaseSSLServer, SimpleXMLRPCServer.SimpleXMLRPCServer): """ Extension to allow more fine-tuned SSL handling """ def __init__(self, address, authinfo_callback=None, certs=None, timeout=None): BaseAuthedXMLRPCServer.__init__(self, address, authinfo_callback) SSLCommon.PlgBaseSSLServer.__init__(self, address, AuthedSimpleXMLRPCRequestHandler, certs, timeout=timeout) class AuthedXMLRPCServer(BaseAuthedXMLRPCServer, SSLCommon.PlgBaseServer, SimpleXMLRPCServer.SimpleXMLRPCServer): def __init__(self, address, authinfo_callback=None): BaseAuthedXMLRPCServer.__init__(self, address, authinfo_callback) SSLCommon.PlgBaseServer.__init__(self, address, AuthedSimpleXMLRPCRequestHandler) ########################################################### # Testing stuff ########################################################### class ReqHandler: def ping(self, callerid, trynum): authinfo = get_authinfo() print "AUTHINFO(%d / %d): %s" % (callerid, trynum, authinfo) return "pong %d / %d" % (callerid, trynum) class TestServer(AuthedSSLXMLRPCServer): """ SSL XMLRPC server that authenticates clients based on their certificate. """ def __init__(self, address, certs): AuthedSSLXMLRPCServer.__init__(self, address, self.auth_cb, certs) def auth_cb(self, request, client_address): import random peer_cert = request.get_peer_certificate() return peer_cert.get_subject().emailAddress if __name__ == '__main__': if len(sys.argv) < 4: print "Usage: python AuthdXMLRPCServer.py key_and_cert ca_cert peer_ca_cert" sys.exit(1) certs = {} certs['key_and_cert'] = sys.argv[1] certs['ca_cert'] = sys.argv[2] certs['peer_ca_cert'] = sys.argv[3] print "Starting the server." server = TestServer(('localhost', 8886), certs) h = ReqHandler() server.register_instance(h) server.serve_forever()