# This program is free software; you can redistribute it and/or modify # it under the terms of the GNU General Public License as published by # the Free Software Foundation; either version 2 of the License, or # (at your option) any later version. # # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Library General Public License for more details. # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. # # Copyright 2005 Dan Williams and Red Hat, Inc. import urllib import os import socket import CommonErrors import exceptions import time import FileTransfer class FileNameException(exceptions.Exception): pass def get_base_filename_from_url(url, legal_exts): """ Safely unquotes a URL and gets the base file name from it. We're not using urlparse here because it doesn't un-escape stuff """ unquoted = url last_unquoted = None count = 5 # Keep unquoting the string until the last two unquote operations # produce the same string while (unquoted != last_unquoted) and (count > 0): last_unquoted = unquoted unquoted = urllib.unquote_plus(unquoted) count = count - 1 # If after 5 iterations of unquoting, the strings still aren't the same, # something is wrong. if (count == 0) and (unquoted != last_unquoted): raise FileNameException("URL quoting level too deep.") # Try to grab the filename off the end of the URL index = url.rfind('/') if index is -1: raise FileNameException("No separator in URL.") filename = url[index+1:] # Only accept certain file extensions ext_ok = False for ext in legal_exts: if filename.endswith(ext): ext_ok = True break if not ext_ok: raise FileNameException("Extension was not allowed.") # FIXME: what other validation can we do here? safe_list = ['_', '-', '.', '+'] for char in filename: # For now, legal characters are '_-.' plus alphanumeric if char in safe_list or char.isalnum(): pass else: raise FileNameException("Illegal character '%s' encountered." % char) return filename class FileDownloader(FileTransfer.FileTransfer): def __init__(self, urls, target_dir, legal_exts, certs=None): FileTransfer.FileTransfer.__init__(self, certs) if not target_dir: raise Exception("Require a target directory to download to.") self._target_dir = target_dir self._files = {} if type(urls) == type(""): urls = [urls] if type(urls) is not type([]): raise ValueError("urls argument must be a list of URLs.") for url in urls: fname = get_base_filename_from_url(url, legal_exts) if not fname: raise FileNameException("Bad file name from url %s" % url) self._files[url] = fname if len(self._files.keys()) == 0: raise ValueError("Need at least one file to download.") def _action(self, (url, fname)): result = None err_msg = None if not os.path.exists(self._target_dir): os.makedirs(self._target_dir) target_file = os.path.join(self._target_dir, fname) try: result = self._opener.retrieve(url, target_file) except socket.error, exc: if not CommonErrors.canIgnoreSocketError(exc): err_msg = "Socket Error: %s" % exc except IOError, exc: if not CommonErrors.canIgnoreSocketError(exc): err_msg = "IOError Error: %s" % exc return (result, err_msg) def run(self): final_result = FileTransfer.FT_RESULT_SUCCESS msg = None for url in self._files.keys(): (result, msg) = self._process_one_transfer((url, self._files[url])) if result == FileTransfer.FT_RESULT_FAILED: final_result = FileTransfer.FT_RESULT_FAILED msg = "Download of %s failed because: %s" % (url, msg) break if self._cancel: final_result = FileTransfer.FT_RESULT_CANCELED break if self._callback: self._callback(final_result, self._cb_data, msg) ########################################################### # Testing stuff ########################################################### import sys class DlCallbackData: def __init__(self, num, tracker): self.num = num self.tracker = tracker self.fdl = None def set_dl(self, fdl): self.fdl = fdl class DlTracker: def __init__(self): self.lst = [] def add(self, dlcb): self.lst.append(dlcbdata) def remove(self, dlcbdata): self.lst.remove(dlcbdata) def num(self): return len(self.lst) def dl_callback(status, dlcbdata, msg=""): print "Finished with %d (%s: %s)" % (dlcbdata.num, status, msg) dlcbdata.tracker.remove(dlcbdata) def main(): if len(sys.argv) < 4: print "Usage: python FileDownloader.py key_and_cert ca_cert peer_ca_cert" sys.exit(1) certs = {} certs['key_and_cert'] = sys.argv[1] certs['ca_cert'] = sys.argv[2] certs['peer_ca_cert'] = sys.argv[3] print "Starting..." dlt = DlTracker() count = 0 while count < 100: dlcb = DlCallbackData(count, dlt) dstdir = os.path.join("/tmp", "client_dir", "%s" % count) if not os.path.exists(dstdir): os.makedirs(dstdir) time.sleep(0.25) fdl = FileDownloader("https://localhost:8886/testfile.dat", dstdir, ['.dat'], certs) fdl.set_callback(dl_callback, dlcb) dlt.add(dlcb) fdl.start() count = count + 1 while dlt.num() > 0: try: time.sleep(1) except KeyboardInterrupt: print "Quitting..." os._exit(0) if __name__ == '__main__': main()