Browse Source

[mhtml, cleanup] Use imghdr

pukkandan 2 years ago
parent
commit
b4daacb4ec
2 changed files with 9 additions and 14 deletions
  1. yt_dlp/compat/imghdr.py (+7 -5)
  2. yt_dlp/downloader/mhtml.py (+2 -9)

+ 7 - 5
yt_dlp/compat/imghdr.py

@@ -2,13 +2,15 @@ tests = {
     'webp': lambda h: h[0:4] == b'RIFF' and h[8:] == b'WEBP',
     'webp': lambda h: h[0:4] == b'RIFF' and h[8:] == b'WEBP',
     'png': lambda h: h[:8] == b'\211PNG\r\n\032\n',
     'png': lambda h: h[:8] == b'\211PNG\r\n\032\n',
     'jpeg': lambda h: h[6:10] in (b'JFIF', b'Exif'),
     'jpeg': lambda h: h[6:10] in (b'JFIF', b'Exif'),
+    'gif': lambda h: h[:6] in (b'GIF87a', b'GIF89a'),
 }
 }
 
 
 
 
-def what(path):
-    """Detect format of image (Currently supports jpeg, png, webp only)
+def what(file=None, h=None):
+    """Detect format of image (Currently supports jpeg, png, webp, gif only)
     Ref: https://github.com/python/cpython/blob/3.10/Lib/imghdr.py
     Ref: https://github.com/python/cpython/blob/3.10/Lib/imghdr.py
     """
     """
-    with open(path, 'rb') as f:
-        head = f.read(12)
-    return next((type_ for type_, test in tests.items() if test(head)), None)
+    if h is None:
+        with open(file, 'rb') as f:
+            h = f.read(12)
+    return next((type_ for type_, test in tests.items() if test(h)), None)

+ 2 - 9
yt_dlp/downloader/mhtml.py

@@ -4,6 +4,7 @@ import re
 import uuid
 import uuid
 
 
 from .fragment import FragmentFD
 from .fragment import FragmentFD
+from ..compat import imghdr
 from ..utils import escapeHTML, formatSeconds, srt_subtitles_timecode, urljoin
 from ..utils import escapeHTML, formatSeconds, srt_subtitles_timecode, urljoin
 from ..version import __version__ as YT_DLP_VERSION
 from ..version import __version__ as YT_DLP_VERSION
 
 
@@ -166,21 +167,13 @@ body > figure > img {
                 continue
                 continue
             frag_content = self._read_fragment(ctx)
             frag_content = self._read_fragment(ctx)
 
 
-            mime_type = b'image/jpeg'
-            if frag_content.startswith(b'\x89PNG\r\n\x1a\n'):
-                mime_type = b'image/png'
-            if frag_content.startswith((b'GIF87a', b'GIF89a')):
-                mime_type = b'image/gif'
-            if frag_content.startswith(b'RIFF') and frag_content[8:12] == b'WEBP':
-                mime_type = b'image/webp'
-
             frag_header = io.BytesIO()
             frag_header = io.BytesIO()
             frag_header.write(
             frag_header.write(
                 b'--%b\r\n' % frag_boundary.encode('us-ascii'))
                 b'--%b\r\n' % frag_boundary.encode('us-ascii'))
             frag_header.write(
             frag_header.write(
                 b'Content-ID: <%b>\r\n' % self._gen_cid(i, fragment, frag_boundary).encode('us-ascii'))
                 b'Content-ID: <%b>\r\n' % self._gen_cid(i, fragment, frag_boundary).encode('us-ascii'))
             frag_header.write(
             frag_header.write(
-                b'Content-type: %b\r\n' % mime_type)
+                b'Content-type: %b\r\n' % f'image/{imghdr.what(h=frag_content) or "jpeg"}'.encode())
             frag_header.write(
             frag_header.write(
                 b'Content-length: %u\r\n' % len(frag_content))
                 b'Content-length: %u\r\n' % len(frag_content))
             frag_header.write(
             frag_header.write(