123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369 |
- """\
- A library of useful helper classes to the SAX classes, for the
- convenience of application and driver writers.
- """
- import os, urllib.parse, urllib.request
- import io
- import codecs
- from . import handler
- from . import xmlreader
def __dict_replace(s, d):
    """Return *s* with every key of *d* replaced by its value.

    The substitutions run one after another in the dict's iteration order,
    so each one operates on the text produced by the previous ones.
    """
    result = s
    for old, new in d.items():
        result = result.replace(old, new)
    return result
def escape(data, entities={}):
    """Escape &, <, and > in a string of data.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.
    The extra replacements are applied after the three built-in ones,
    so they see (and their values are not re-escaped by) the already
    escaped text.

    The default ``entities`` dict is only read, never mutated, so sharing
    it between calls is safe.
    """
    # must do ampersand first; otherwise the "&" introduced by "&gt;" and
    # "&lt;" would itself be escaped a second time
    data = data.replace("&", "&amp;")
    data = data.replace(">", "&gt;")
    data = data.replace("<", "&lt;")
    if entities:
        data = __dict_replace(data, entities)
    return data
def unescape(data, entities={}):
    """Unescape &amp;, &lt;, and &gt; in a string of data.

    You can unescape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.

    The default ``entities`` dict is only read, never mutated.
    """
    data = data.replace("&lt;", "<")
    data = data.replace("&gt;", ">")
    if entities:
        data = __dict_replace(data, entities)
    # must do ampersand last, so that "&amp;lt;" decodes to the literal
    # text "&lt;" rather than being decoded twice into "<"
    return data.replace("&amp;", "&")
def quoteattr(data, entities={}):
    """Escape and quote an attribute value.

    Escape &, <, and > in a string of data, then quote it for use as
    an attribute value.  The \" character will be escaped as well, if
    necessary.

    You can escape other strings of data by passing a dictionary as
    the optional entities parameter.  The keys and values must all be
    strings; each key will be replaced with its corresponding value.

    Newline, carriage return and tab are always written as the character
    references &#10;, &#13; and &#9; (overriding any caller mapping for
    them), because an XML parser's attribute-value normalization would
    otherwise turn the literal characters into plain spaces.
    """
    # Build a new dict rather than updating the caller's (or the shared
    # default) in place.
    entities = {**entities, '\n': '&#10;', '\r': '&#13;', '\t': '&#9;'}
    data = escape(data, entities)
    if '"' in data:
        if "'" in data:
            # Both quote kinds present: wrap in double quotes and escape
            # the inner double quotes.
            data = '"%s"' % data.replace('"', "&quot;")
        else:
            # Only double quotes present: single-quoting avoids escaping.
            data = "'%s'" % data
    else:
        data = '"%s"' % data
    return data
def _gettextwriter(out, encoding):
    """Return an object with a text ``write`` method that targets *out*.

    * ``None`` selects ``sys.stdout``.
    * Objects that already accept ``str`` -- ``io.TextIOBase`` instances and
      ``codecs`` stream writers -- are handed back unchanged.
    * Anything else is treated as a byte sink and wrapped in an
      ``io.TextIOWrapper`` that encodes with *encoding*, turns unencodable
      characters into XML character references, writes ``"\\n"`` line
      endings untranslated and forwards every write immediately.
    """
    if out is None:
        import sys
        return sys.stdout

    text_sinks = (io.TextIOBase, codecs.StreamWriter, codecs.StreamReaderWriter)
    if isinstance(out, text_sinks):
        return out

    if not isinstance(out, io.RawIOBase):
        # Not part of the io hierarchy (or already buffered): present only
        # its write method through a BufferedIOBase shell.
        sink = io.BufferedIOBase()
        sink.writable = lambda: True
        sink.write = out.write
        try:
            # TextIOWrapper probes seekable()/tell() to decide whether a
            # BOM (UTF-16 and similar) must be emitted.  If out lacks
            # tell, seekable has still been copied, matching prior
            # behaviour.
            sink.seekable = out.seekable
            sink.tell = out.tell
        except AttributeError:
            pass
    else:
        # Raw binary stream: proxy every attribute lookup to it and report
        # its class via __class__, but neutralise close() so the original
        # stream stays open when the TextIOWrapper is closed or collected.
        class _wrapper:
            __class__ = out.__class__

            def __getattr__(self, name):
                return getattr(out, name)

        sink = _wrapper()
        sink.close = lambda: None

    return io.TextIOWrapper(
        sink,
        encoding=encoding,
        errors='xmlcharrefreplace',
        newline='\n',
        write_through=True,
    )
class XMLGenerator(handler.ContentHandler):
    """ContentHandler that serializes the SAX events it receives as XML text.

    *out* is the destination (``None`` means ``sys.stdout``; byte streams and
    write-only objects are wrapped by ``_gettextwriter``).  *encoding* is
    both the encoding named in the declaration written by ``startDocument``
    and the codec used when *out* has to be wrapped.  When
    *short_empty_elements* is true, an element that receives no content is
    written as ``<tag/>`` instead of ``<tag></tag>``.
    """

    def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
        handler.ContentHandler.__init__(self)
        out = _gettextwriter(out, encoding)
        self._write = out.write
        self._flush = out.flush
        self._ns_contexts = [{}]  # stack of saved uri -> prefix dicts
        self._current_context = self._ns_contexts[-1]
        # (prefix, uri) pairs announced but not yet written as xmlns attributes
        self._undeclared_ns_maps = []
        self._encoding = encoding
        self._short_empty_elements = short_empty_elements
        # True while a start tag has been written without its closing '>'
        self._pending_start_element = False

    def _qname(self, name):
        """Builds a qualified name from a (ns_url, localname) pair"""
        if name[0]:
            # Per http://www.w3.org/XML/1998/namespace, The 'xml' prefix is
            # bound by definition to http://www.w3.org/XML/1998/namespace. It
            # does not need to be declared and will not usually be found in
            # self._current_context.
            if 'http://www.w3.org/XML/1998/namespace' == name[0]:
                return 'xml:' + name[1]
            # The name is in a non-empty namespace
            prefix = self._current_context[name[0]]
            if prefix:
                # If it is not the default namespace, prepend the prefix
                return prefix + ":" + name[1]
        # Return the unqualified name
        return name[1]

    def _finish_pending_start_element(self, endElement=False):
        """Close a start tag left open for possible ``<tag/>`` shortening."""
        if self._pending_start_element:
            self._write('>')
            self._pending_start_element = False

    # ContentHandler methods

    def startDocument(self):
        self._write('<?xml version="1.0" encoding="%s"?>\n' %
                    self._encoding)

    def endDocument(self):
        self._flush()

    def startPrefixMapping(self, prefix, uri):
        # Save a snapshot of the current mapping so endPrefixMapping can
        # restore it, then record the new binding in the live mapping.
        self._ns_contexts.append(self._current_context.copy())
        self._current_context[uri] = prefix
        self._undeclared_ns_maps.append((prefix, uri))

    def endPrefixMapping(self, prefix):
        self._current_context = self._ns_contexts[-1]
        del self._ns_contexts[-1]

    def startElement(self, name, attrs):
        self._finish_pending_start_element()
        self._write('<' + name)
        for attr_name, value in attrs.items():
            self._write(' %s=%s' % (attr_name, quoteattr(value)))
        if self._short_empty_elements:
            # Defer '>' so endElement can emit '/>' if no content follows.
            self._pending_start_element = True
        else:
            self._write(">")

    def endElement(self, name):
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % name)

    def startElementNS(self, name, qname, attrs):
        self._finish_pending_start_element()
        self._write('<' + self._qname(name))
        # Namespace URIs are attribute values too: quote and escape them so
        # a URI containing '&', '<' or '"' cannot produce malformed XML.
        # For ordinary URIs quoteattr yields exactly '"uri"'.
        for prefix, uri in self._undeclared_ns_maps:
            if prefix:
                self._write(' xmlns:%s=%s' % (prefix, quoteattr(uri)))
            else:
                self._write(' xmlns=%s' % quoteattr(uri))
        self._undeclared_ns_maps = []
        for attr_name, value in attrs.items():
            self._write(' %s=%s' % (self._qname(attr_name), quoteattr(value)))
        if self._short_empty_elements:
            self._pending_start_element = True
        else:
            self._write(">")

    def endElementNS(self, name, qname):
        if self._pending_start_element:
            self._write('/>')
            self._pending_start_element = False
        else:
            self._write('</%s>' % self._qname(name))

    def characters(self, content):
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                # bytes-like content is decoded with the output encoding
                content = str(content, self._encoding)
            self._write(escape(content))

    def ignorableWhitespace(self, content):
        if content:
            self._finish_pending_start_element()
            if not isinstance(content, str):
                content = str(content, self._encoding)
            self._write(content)

    def processingInstruction(self, target, data):
        self._finish_pending_start_element()
        self._write('<?%s %s?>' % (target, data))
class XMLFilterBase(xmlreader.XMLReader):
    """Pass-through filter that sits between a parent XMLReader and the
    application's handlers.

    Unless overridden, every configuration request (features, properties,
    locale, parse) is relayed to the parent reader, and every event the
    parent produces is relayed, unchanged, to the handler registered on this
    filter.  Subclasses override individual methods to alter the requests or
    the event stream on the way through."""

    def __init__(self, parent=None):
        xmlreader.XMLReader.__init__(self)
        self._parent = parent

    # --- ErrorHandler methods: relayed to the registered error handler

    def error(self, exception):
        """Relay a recoverable error."""
        self._err_handler.error(exception)

    def fatalError(self, exception):
        """Relay a non-recoverable error."""
        self._err_handler.fatalError(exception)

    def warning(self, exception):
        """Relay a warning."""
        self._err_handler.warning(exception)

    # --- ContentHandler methods: relayed to the registered content handler

    def setDocumentLocator(self, locator):
        """Relay the document locator."""
        self._cont_handler.setDocumentLocator(locator)

    def startDocument(self):
        """Relay the start-of-document event."""
        self._cont_handler.startDocument()

    def endDocument(self):
        """Relay the end-of-document event."""
        self._cont_handler.endDocument()

    def startPrefixMapping(self, prefix, uri):
        """Relay the beginning of a prefix -> URI mapping."""
        self._cont_handler.startPrefixMapping(prefix, uri)

    def endPrefixMapping(self, prefix):
        """Relay the end of a prefix mapping."""
        self._cont_handler.endPrefixMapping(prefix)

    def startElement(self, name, attrs):
        """Relay a non-namespace start tag."""
        self._cont_handler.startElement(name, attrs)

    def endElement(self, name):
        """Relay a non-namespace end tag."""
        self._cont_handler.endElement(name)

    def startElementNS(self, name, qname, attrs):
        """Relay a namespace-aware start tag."""
        self._cont_handler.startElementNS(name, qname, attrs)

    def endElementNS(self, name, qname):
        """Relay a namespace-aware end tag."""
        self._cont_handler.endElementNS(name, qname)

    def characters(self, content):
        """Relay a chunk of character data."""
        self._cont_handler.characters(content)

    def ignorableWhitespace(self, chars):
        """Relay ignorable whitespace."""
        self._cont_handler.ignorableWhitespace(chars)

    def processingInstruction(self, target, data):
        """Relay a processing instruction."""
        self._cont_handler.processingInstruction(target, data)

    def skippedEntity(self, name):
        """Relay a skipped entity."""
        self._cont_handler.skippedEntity(name)

    # --- DTDHandler methods: relayed to the registered DTD handler

    def notationDecl(self, name, publicId, systemId):
        """Relay a notation declaration."""
        self._dtd_handler.notationDecl(name, publicId, systemId)

    def unparsedEntityDecl(self, name, publicId, systemId, ndata):
        """Relay an unparsed entity declaration."""
        self._dtd_handler.unparsedEntityDecl(name, publicId, systemId, ndata)

    # --- EntityResolver methods

    def resolveEntity(self, publicId, systemId):
        """Ask the registered entity resolver and return its answer."""
        return self._ent_handler.resolveEntity(publicId, systemId)

    # --- XMLReader methods: relayed to the parent reader

    def parse(self, source):
        """Install this filter as every handler of the parent, then let the
        parent parse *source*."""
        parent = self._parent
        parent.setContentHandler(self)
        parent.setErrorHandler(self)
        parent.setEntityResolver(self)
        parent.setDTDHandler(self)
        parent.parse(source)

    def setLocale(self, locale):
        """Forward the locale setting to the parent."""
        self._parent.setLocale(locale)

    def getFeature(self, name):
        """Return the parent's value for feature *name*."""
        return self._parent.getFeature(name)

    def setFeature(self, name, state):
        """Set feature *name* on the parent."""
        self._parent.setFeature(name, state)

    def getProperty(self, name):
        """Return the parent's value for property *name*."""
        return self._parent.getProperty(name)

    def setProperty(self, name, value):
        """Set property *name* on the parent."""
        self._parent.setProperty(name, value)

    # --- XMLFilter methods

    def getParent(self):
        """Return the reader this filter wraps."""
        return self._parent

    def setParent(self, parent):
        """Replace the reader this filter wraps."""
        self._parent = parent
- # --- Utility functions
def prepare_input_source(source, base=""):
    """Normalize *source* into an InputSource that has a stream to read.

    *source* may be a path (``str`` or ``os.PathLike``), a file-like object
    with ``read``, or an InputSource.  A file-like object becomes the
    character stream if ``read(0)`` returns ``str`` and the byte stream
    otherwise; a ``str`` ``name`` attribute on it becomes the system id.
    If the result still has no stream, its system id is resolved against
    *base*: an existing local file is opened in binary mode, otherwise the
    id is URL-joined with *base* and fetched with ``urllib.request.urlopen``.
    """
    if isinstance(source, os.PathLike):
        source = os.fspath(source)

    if isinstance(source, str):
        source = xmlreader.InputSource(source)
    elif hasattr(source, "read"):
        stream = source
        source = xmlreader.InputSource()
        # A zero-length read reveals text vs. binary without consuming data.
        if isinstance(stream.read(0), str):
            source.setCharacterStream(stream)
        else:
            source.setByteStream(stream)
        if hasattr(stream, "name") and isinstance(stream.name, str):
            source.setSystemId(stream.name)

    if (source.getCharacterStream() is not None
            or source.getByteStream() is not None):
        return source

    system_id = source.getSystemId()
    base_dir = os.path.dirname(os.path.normpath(base))
    candidate = os.path.join(base_dir, system_id)
    if os.path.isfile(candidate):
        source.setSystemId(candidate)
        byte_stream = open(candidate, "rb")
    else:
        # NOTE(review): the system id is fetched as a URL with no
        # restrictions; callers handling untrusted documents should vet it.
        source.setSystemId(urllib.parse.urljoin(base, system_id))
        byte_stream = urllib.request.urlopen(source.getSystemId())
    source.setByteStream(byte_stream)
    return source
|