尹舟 2025-03-16 14:11:13 +08:00
commit 813f1488e0
31 changed files with 2382 additions and 0 deletions

.dockerignore Normal file

@@ -0,0 +1,5 @@
.venv/
.idea/
.deploy/
logs/
./mp4/*

.gitignore vendored Normal file

@@ -0,0 +1,6 @@
.idea/
build/
dist/
m3u8_To_MP4.egg-info/
__pycache__/
*.pyc

config/config.json Normal file

@@ -0,0 +1,10 @@
{
"yz_mysql": {
"host": "mysql",
"port": 3306,
"userName": "root",
"password": "mysql",
"dbName": "movie",
"charsets": "UTF8"
}
}
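A minimal sketch of how this config block could be consumed. The MySqlUtil helper that actually reads it is not included in this commit, so the loader below is an assumption for illustration only.

import json

# hypothetical reader for config/config.json; key names are taken from the block above
with open('config/config.json', 'r', encoding='utf8') as f:
    db = json.load(f)['yz_mysql']
print(db['host'], db['port'], db['userName'], db['dbName'])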

docker-compose.yaml Normal file

@@ -0,0 +1,15 @@
version: '3.4'
services:
sql-runner:
build:
context: .
dockerfile: Dockerfile
restart: always
container_name: m3u8_download
image: registry.cn-hangzhou.aliyuncs.com/yinzhou_docker_hub/m3u8_download:latest
ports:
- "1314:1314"
volumes:
- ./mp4:/opt/m3u8_download/mp4
# docker-compose up --build

dockerfile Normal file

@@ -0,0 +1,20 @@
# Use the Python 3.11 image from the Aliyun registry
FROM registry.cn-hangzhou.aliyuncs.com/yinzhou_docker_hub/python:3.11-alpine
# Set the working directory
WORKDIR /opt/m3u8_download
# Set the timezone to Asia/Shanghai
ENV TZ=Asia/Shanghai
# Copy requirements.txt into the container
COPY requirements.txt .
# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
# Copy the remaining files into the container
COPY . .
# Run the application
ENTRYPOINT ["python3", "m3u8_download.py"]

m3u8_To_MP4/__init__.py Normal file

@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""
m3u8ToMP4
~~~~~~~~~~~~
Basic usage:
import m3u8_to_mp4
m3u8_to_mp4.download("https://xxx.com/xxx/index.m3u8")
"""
import logging
import subprocess
from m3u8_To_MP4.helpers import printer_helper
printer_helper.config_logging()
# verify ffmpeg
def verify_ffmpeg():
    test_has_ffmpeg_cmd = "ffmpeg -version"
    proc = subprocess.Popen(test_has_ffmpeg_cmd, shell=True,
                            stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    outs, errs = proc.communicate()
    output_text = outs.decode('utf8')
    if 'version' not in output_text:
        logging.warning('FFMPEG NOT FOUND!')
        logging.info('Only compressing into tar.bz2 is supported')
verify_ffmpeg()
# define API
import m3u8_To_MP4.multithreads_processor
from m3u8_To_MP4.v2_async_processor import AsynchronousFileCrawler
from m3u8_To_MP4.v2_async_processor import AsynchronousUriCrawler
from m3u8_To_MP4.v2_multithreads_processor import MultiThreadsFileCrawler
from m3u8_To_MP4.v2_multithreads_processor import MultiThreadsUriCrawler
__all__ = (
"MultiThreadsFileCrawler",
"MultiThreadsUriCrawler",
"AsynchronousFileCrawler",
"AsynchronousUriCrawler",
"async_download",
"async_file_download",
"async_uri_download",
"multithread_download",
"multithread_file_download",
"multithread_uri_download",
"download"
)
# ================ Async ===================
def async_download(m3u8_uri, file_path='./m3u8_To_MP4.ts', customized_http_header=None, max_retry_times=3,
num_concurrent=50, tmpdir=None):
    '''
    Download a video from the given m3u8 uri using asyncio.
    :param m3u8_uri: m3u8 uri
    :param file_path: output file path (merged segments, ts format)
    :param customized_http_header: extra HTTP headers sent with each request
    :param max_retry_times: max retry times
    :param num_concurrent: number of concurrent requests
    :param tmpdir: directory for temporary segment files
    :return:
    '''
with m3u8_To_MP4.v2_async_processor.AsynchronousUriCrawler(m3u8_uri,
file_path,
customized_http_header,
max_retry_times,
num_concurrent,
tmpdir) as crawler:
crawler.fetch_mp4_by_m3u8_uri('ts')
def async_uri_download(m3u8_uri, file_path='./m3u8_To_MP4.mp4', customized_http_header=None,
max_retry_times=3, num_concurrent=50, tmpdir=None):
with m3u8_To_MP4.v2_async_processor.AsynchronousUriCrawler(m3u8_uri,
file_path,
customized_http_header,
max_retry_times,
num_concurrent,
tmpdir) as crawler:
crawler.fetch_mp4_by_m3u8_uri('ts')
def async_file_download(m3u8_uri, m3u8_file_path, file_path='./m3u8_To_MP4.ts', customized_http_header=None,
max_retry_times=3, num_concurrent=50, tmpdir=None):
with m3u8_To_MP4.v2_async_processor.AsynchronousFileCrawler(m3u8_uri,
m3u8_file_path,
file_path,
customized_http_header,
max_retry_times,
num_concurrent,
tmpdir) as crawler:
crawler.fetch_mp4_by_m3u8_uri('ts')
# ================ MultiThread ===================
def multithread_download(m3u8_uri, file_path='./m3u8_To_MP4.ts', customized_http_header=None,
max_retry_times=3, max_num_workers=100, tmpdir=None):
    '''
    Download a video from the given m3u8 uri using a thread pool.
    :param m3u8_uri: m3u8 uri
    :param file_path: output file path (merged segments, ts format)
    :param customized_http_header: extra HTTP headers sent with each request
    :param max_retry_times: max retry times
    :param max_num_workers: number of download threads
    :param tmpdir: directory for temporary segment files
    :return:
    '''
with m3u8_To_MP4.v2_multithreads_processor.MultiThreadsUriCrawler(m3u8_uri,
file_path,
customized_http_header,
max_retry_times,
max_num_workers,
tmpdir) as crawler:
crawler.fetch_mp4_by_m3u8_uri('ts')
def multithread_uri_download(m3u8_uri, file_path='./m3u8_To_MP4.ts', customized_http_header=None,
                             max_retry_times=3, max_num_workers=100, tmpdir=None):
    with m3u8_To_MP4.v2_multithreads_processor.MultiThreadsUriCrawler(m3u8_uri,
                                                                      file_path,
                                                                      customized_http_header,
max_retry_times,
max_num_workers,
tmpdir) as crawler:
crawler.fetch_mp4_by_m3u8_uri('ts')
def multithread_file_download(m3u8_uri, m3u8_file_path, file_path,
customized_http_header=None, max_retry_times=3,
max_num_workers=100, tmpdir=None):
with m3u8_To_MP4.v2_multithreads_processor.MultiThreadsFileCrawler(
m3u8_uri, m3u8_file_path, file_path, customized_http_header, max_retry_times,
max_num_workers, tmpdir) as crawler:
        crawler.fetch_mp4_by_m3u8_uri('ts')
# ================ Deprecated Function ===================
import warnings
def download(m3u8_uri, max_retry_times=3, max_num_workers=100,
mp4_file_dir='./', mp4_file_name='m3u8_To_MP4', tmpdir=None):
'''
Download mp4 video from given m3u uri.
:param m3u8_uri: m3u8 uri
:param max_retry_times: max retry times
:param max_num_workers: number of download threads
:param mp4_file_dir: folder path where mp4 file is stored
:param mp4_file_name: a mp4 file name with suffix ".mp4"
:return:
'''
warnings.warn(
'download function is deprecated, and please use multithread_download.',
DeprecationWarning)
with m3u8_To_MP4.multithreads_processor.Crawler(m3u8_uri, max_retry_times,
max_num_workers,
mp4_file_dir,
mp4_file_name,
tmpdir) as crawler:
crawler.fetch_mp4_by_m3u8_uri()
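A minimal usage sketch of the non-deprecated API exported above; the URL is a placeholder.

import m3u8_To_MP4

# thread-pool variant; segments are merged into a single .ts file by default
m3u8_To_MP4.multithread_download('https://example.com/index.m3u8',
                                 file_path='./video.ts')

# asyncio variant with the same call shape
m3u8_To_MP4.async_download('https://example.com/index.m3u8',
                           file_path='./video.ts')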

@@ -0,0 +1,303 @@
# -*- coding: utf-8 -*-
import collections
import logging
import os
import shutil
import subprocess
import tarfile
import tempfile
import time
import zlib
import m3u8
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.helpers import printer_helper
from m3u8_To_MP4.networks.asynchronous import async_producer_consumer
from m3u8_To_MP4.networks.synchronous import sync_DNS
from m3u8_To_MP4.networks.synchronous import sync_http
printer_helper.config_logging()
EncryptedKey = collections.namedtuple(typename='EncryptedKey',
field_names=['method', 'value', 'iv'])
class Crawler(object):
def __init__(self, m3u8_uri, max_retry_times=3, num_concurrent=50,
mp4_file_dir=None, mp4_file_name='m3u8-To-Mp4.mp4',
tmpdir=None):
self.m3u8_uri = m3u8_uri
self.max_retry_times = max_retry_times
self.num_concurrent = num_concurrent
self.tmpdir = tmpdir
self.mp4_file_dir = mp4_file_dir
self.mp4_file_name = mp4_file_name
self.use_ffmpeg = True
def __enter__(self):
if self.tmpdir is None:
self._apply_for_tmpdir()
self.segment_path_recipe = os.path.join(self.tmpdir, "ts_recipe.txt")
self._find_out_done_ts()
self._legalize_mp4_file_path()
self._imitate_tar_file_path()
print('\nsummary')
print(
'm3u8_uri: {};\nmax_retry_times: {};\ntmp_dir: {};\nmp4_file_path: {};\n'.format(
self.m3u8_uri, self.max_retry_times, self.tmpdir,
self.mp4_file_path))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._freeup_tmpdir()
def _apply_for_tmpdir(self):
os_tmp_dir = tempfile.gettempdir()
url_crc32_str = str(
zlib.crc32(self.m3u8_uri.encode())) # hash algorithm
self.tmpdir = os.path.join(os_tmp_dir, 'm3u8_' + url_crc32_str)
if not os.path.exists(self.tmpdir):
os.mkdir(self.tmpdir)
def _freeup_tmpdir(self):
os_tmp_dir = tempfile.gettempdir()
for file_symbol in os.listdir(os_tmp_dir):
if file_symbol.startswith('m3u8_'):
file_symbol_absolute_path = os.path.join(os_tmp_dir,
file_symbol)
if os.path.isdir(file_symbol_absolute_path):
shutil.rmtree(file_symbol_absolute_path)
def _find_out_done_ts(self):
file_names_in_tmpdir = os.listdir(self.tmpdir)
full_ts_file_names = list()
for file_name in reversed(file_names_in_tmpdir):
if file_name == 'ts_recipe.txt':
continue
absolute_file_path = os.path.join(self.tmpdir, file_name)
if os.path.getsize(absolute_file_path) > 0:
full_ts_file_names.append(file_name)
self.fetched_file_names = full_ts_file_names
    def _legalize_mp4_file_path(self):
        if self.mp4_file_dir is None:
            self.mp4_file_dir = os.getcwd()
        # calibrate_mp4_file_name() returns the sanitized name directly
        mp4_file_name = path_helper.calibrate_mp4_file_name(
            self.mp4_file_name)
        mp4_file_path = os.path.join(self.mp4_file_dir, mp4_file_name)
        if os.path.exists(mp4_file_path):
            mp4_file_name = path_helper.random_name()
            mp4_file_path = os.path.join(self.mp4_file_dir, mp4_file_name)
        self.mp4_file_path = mp4_file_path
def _imitate_tar_file_path(self):
self.tar_file_path = self.mp4_file_path[:-4] + '.tar.bz2'
def _resolve_DNS(self):
self.available_addr_info_pool = sync_DNS.available_addr_infos_of_url(
self.m3u8_uri)
self.best_addr_info = self.available_addr_info_pool[0]
logging.info('Resolved available hosts:')
for addr_info in self.available_addr_info_pool:
logging.info('{}:{}'.format(addr_info.host, addr_info.port))
def _request_m3u8_obj_from_url(self):
try:
response_code, m3u8_bytes = sync_http.retrieve_resource_from_url(
self.best_addr_info, self.m3u8_uri)
if response_code != 200:
raise Exception(
'DOWNLOAD KEY FAILED, URI IS {}'.format(self.m3u8_uri))
m3u8_str = m3u8_bytes.decode()
m3u8_obj = m3u8.loads(content=m3u8_str, uri=self.m3u8_uri)
except Exception as exc:
            logging.exception(
                'Failed to load m3u8 file, reason is {}'.format(exc))
raise Exception('FAILED TO LOAD M3U8 FILE!')
return m3u8_obj
    def _get_m3u8_obj_with_best_bandwidth(self):
m3u8_obj = self._request_m3u8_obj_from_url()
if m3u8_obj.is_variant:
best_bandwidth = -1
best_bandwidth_m3u8_uri = None
for playlist in m3u8_obj.playlists:
if playlist.stream_info.bandwidth > best_bandwidth:
best_bandwidth = playlist.stream_info.bandwidth
best_bandwidth_m3u8_uri = playlist.absolute_uri
logging.info("Choose the best bandwidth, which is {}".format(
best_bandwidth))
logging.info("Best m3u8 uri is {}".format(best_bandwidth_m3u8_uri))
self.m3u8_uri = best_bandwidth_m3u8_uri
m3u8_obj = self._request_m3u8_obj_from_url()
return m3u8_obj
def _construct_key_segment_pairs_by_m3u8(self, m3u8_obj):
key_segments_pairs = list()
for key in m3u8_obj.keys:
if key:
if key.method.lower() == 'none':
continue
response_code, encryped_value = sync_http.retrieve_resource_from_url(
self.best_addr_info, key.absolute_uri)
if response_code != 200:
raise Exception('DOWNLOAD KEY FAILED, URI IS {}'.format(
key.absolute_uri))
encryped_value = encryped_value.decode()
_encrypted_key = EncryptedKey(method=key.method,
value=encryped_value, iv=key.iv)
key_segments = m3u8_obj.segments.by_key(key)
segments_by_key = [(_encrypted_key, segment.absolute_uri) for
segment in key_segments]
key_segments_pairs.extend(segments_by_key)
if len(key_segments_pairs) == 0:
_encrypted_key = None
key_segments = m3u8_obj.segments
segments_by_key = [(_encrypted_key, segment.absolute_uri) for
segment in key_segments]
key_segments_pairs.extend(segments_by_key)
return key_segments_pairs
    def _is_ads(self, segment_uri):
        # segments that do NOT match the common URL pattern are treated as ads
        if not segment_uri.startswith(self.longest_common_subsequence):
            return True
        if not segment_uri.endswith('.ts'):
            return True
        return False
def _filter_ads_ts(self, key_segments_pairs):
self.longest_common_subsequence = path_helper.longest_common_subsequence(
[segment_uri for _, segment_uri in key_segments_pairs])
key_segments_pairs = [(_encrypted_key, segment_uri) for
_encrypted_key, segment_uri in key_segments_pairs
if not self._is_ads(segment_uri)]
return key_segments_pairs
def _is_fetched(self, segment_uri):
file_name = path_helper.resolve_file_name_by_uri(segment_uri)
if file_name in self.fetched_file_names:
return True
return False
def _filter_done_ts(self, key_segments_pairs):
num_ts_segments = len(key_segments_pairs)
key_segments_pairs = [(_encrypted_key, segment_uri) for
_encrypted_key, segment_uri in key_segments_pairs
if not self._is_fetched(segment_uri)]
self.num_fetched_ts_segments = num_ts_segments - len(
key_segments_pairs)
return key_segments_pairs
def _fetch_segments_to_local_tmpdir(self, key_segments_pairs):
async_producer_consumer.factory_pipeline(self.num_fetched_ts_segments,
key_segments_pairs,
self.available_addr_info_pool,
self.num_concurrent,
self.tmpdir)
def _construct_segment_path_recipe(self, key_segment_pairs):
with open(self.segment_path_recipe, 'w', encoding='utf8') as fw:
for _, segment in key_segment_pairs:
file_name = path_helper.resolve_file_name_by_uri(segment)
segment_file_path = os.path.join(self.tmpdir, file_name)
fw.write("file '{}'\n".format(segment_file_path))
def _merge_to_mp4_by_ffmpeg(self):
merge_cmd = "ffmpeg -y -f concat -safe 0 -i " + '"' + self.segment_path_recipe + '"' + " -c copy " + '"' + self.mp4_file_path + '"'
p = subprocess.Popen(merge_cmd, shell=True, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
logging.info("merging segments...")
p.communicate()
def _merge_to_mp4_by_os(self):
raise NotImplementedError
def _merge_to_tar_by_os(self):
with tarfile.open(self.tar_file_path, 'w:bz2') as targz:
targz.add(name=self.tmpdir, arcname=os.path.basename(self.tmpdir))
def fetch_mp4_by_m3u8_uri(self, as_mp4):
task_start_time = time.time()
# preparation
self._resolve_DNS()
# resolve ts segment uris
        m3u8_obj = self._get_m3u8_obj_with_best_bandwidth()
key_segments_pairs = self._construct_key_segment_pairs_by_m3u8(
m3u8_obj)
key_segments_pairs = self._filter_ads_ts(key_segments_pairs)
self._construct_segment_path_recipe(key_segments_pairs)
key_segments_pairs = self._filter_done_ts(key_segments_pairs)
# download
if len(key_segments_pairs) > 0:
self._fetch_segments_to_local_tmpdir(key_segments_pairs)
fetch_end_time = time.time()
if as_mp4:
self._merge_to_mp4_by_ffmpeg()
task_end_time = time.time()
printer_helper.display_speed(task_start_time, fetch_end_time,
task_end_time, self.mp4_file_path)
else:
self._merge_to_tar_by_os()
task_end_time = time.time()
printer_helper.display_speed(task_start_time, fetch_end_time,
task_end_time, self.tar_file_path)

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

@@ -0,0 +1,6 @@
# -*- coding: utf-8 -*-
import os
def get_core_count():
return os.cpu_count()

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
import datetime
import random
from collections import Counter
WINDOWS_BANNED_CHARACTERS = ['\\', '/', ':', '*', '?', '"', '<', '>', '|']
def updated_resource_path(path, query):
if query:
path = f'{path}?{query}'
return path
def resolve_file_name_by_uri(uri):
# pattern = r"\/+(.*)"
# file_name = re.findall(pattern=pattern, string=uri)[0]
name = uri.split('/')[-1]
return calibrate_name(name)
def calibrate_mp4_file_name(mp4_file_name):
mp4_file_name = calibrate_name(mp4_file_name)
return mp4_file_name
def random_5_char():
    # randint bounds are inclusive, so 0-9 yields exactly one digit per draw
    random_digits = [str(random.randint(0, 9)) for _ in range(5)]
    return ''.join(random_digits)
def random_name():
dt_str = datetime.datetime.now().strftime('%Y-%m-%d %H-%M-%S')
return 'm3u8_To_MP4' + dt_str + random_5_char()+'.mp4'
def calibrate_name(name):
if len(name.strip()) == 0:
return random_name()
for ch in WINDOWS_BANNED_CHARACTERS:
name = name.replace(ch, '')
return name
# def create_mp4_file_name():
# mp4_file_name = 'm3u8_To_Mp4_{}.mp4'.format(
# datetime.datetime.now().strftime('%Y-%m-%d %H-%M-%S'))
# return mp4_file_name
def longest_common_subsequence(segment_absolute_urls):
    # builds the per-position majority-character string across all segment URLs,
    # which the crawlers use as a common pattern to spot ad segments
num_shortest_segment_absolute_url_length = min(
len(url) for url in segment_absolute_urls)
common_subsequence = list()
for i in range(num_shortest_segment_absolute_url_length):
c = Counter(segment_absolute_url[i] for segment_absolute_url in
segment_absolute_urls)
common_subsequence.append(c.most_common(1)[0][0])
return ''.join(common_subsequence)
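A small worked example of longest_common_subsequence with assumed inputs: at every index the majority character wins, so the result tracks the two matching CDN URLs rather than the ad URL.

urls = ['https://cdn.example.com/v/0001.ts',
        'https://cdn.example.com/v/0002.ts',
        'https://ads.example.net/p/pop.ts']
pattern = longest_common_subsequence(urls)
print(pattern[:26])  # 'https://cdn.example.com/v/'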

@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
import logging
import os
import sys
def config_logging():
str_format = '%(asctime)s | %(levelname)s | %(message)s'
logging.basicConfig(format=str_format, level=logging.INFO)
class ProcessBar:
def __init__(self, progress, max_iter, prefix='Progress',
suffix='downloading', completed_suffix='completed',
bar_length=50):
self.progress = progress
self.max_iter = max_iter
self.bar_length = bar_length
self.prefix = prefix
self.suffix = suffix
self.completed_suffix = completed_suffix
def display(self):
progress_rate = self.progress / self.max_iter
percent = 100 * progress_rate
filled_length = round(self.bar_length * progress_rate)
bar = '#' * filled_length + '-' * (self.bar_length - filled_length)
sys.stdout.write(
'\r{}: |{}| {:.1f}% {}'.format(self.prefix, bar, percent,
self.suffix))
if self.progress == self.max_iter:
sys.stdout.write(
'\r{}: |{}| {:.1f}% {}'.format(self.prefix, bar, percent,
self.completed_suffix))
sys.stdout.write('\n')
sys.stdout.flush()
def update(self):
self.progress += 1
self.display()
def display_speed(start_time, fetch_end_time, task_end_time, target_mp4_file_path):
    download_time = fetch_end_time - start_time
    total_time = task_end_time - start_time
    if download_time < 0.01:
        download_speed = os.path.getsize(target_mp4_file_path) / 1024
    else:
        download_speed = os.path.getsize(target_mp4_file_path) / download_time / 1024
    logging.info(
        "download succeeded, took {:.2f}s, average download speed was {:.2f}KB/s".format(
            total_time, download_speed))
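A minimal ProcessBar sketch: ten simulated steps drive the bar to completion.

bar = ProcessBar(progress=0, max_iter=10)
for _ in range(10):
    bar.update()  # redraws the '#' bar and prints the completed suffix at 100%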

@@ -0,0 +1,249 @@
# -*- coding: utf-8 -*-
import collections
import concurrent.futures
import logging
import os
import shutil
import subprocess
import sys
import tempfile
import time
import m3u8
from Crypto.Cipher import AES
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.helpers import printer_helper
from m3u8_To_MP4.networks.synchronous.sync_http_requester import request_for
printer_helper.config_logging()
def download_segment(segment_url):
response_code, response_content = request_for(segment_url)
return response_code, response_content
EncryptedKey = collections.namedtuple(typename='EncryptedKey',
field_names=['method', 'value', 'iv'])
class Crawler(object):
def __init__(self, m3u8_uri, max_retry_times=3, max_num_workers=100,
mp4_file_dir='./', mp4_file_name='m3u8_To_Mp4.mp4',
tmpdir=None):
self.m3u8_uri = m3u8_uri
self.max_retry_times = max_retry_times
self.max_num_workers = max_num_workers
self.tmpdir = tmpdir
self.fetched_file_names = list()
self.mp4_file_dir = mp4_file_dir
self.mp4_file_name = mp4_file_name
self.mp4_file_path = None
def __enter__(self):
if self.tmpdir is None:
self._apply_for_tmpdir()
self.fetched_file_names = os.listdir(self.tmpdir)
self._legalize_valid_mp4_file_path()
print('\nsummary:')
print(
'm3u8_uri: {};\nmax_retry_times: {};\nmax_num_workers: {};\ntmp_dir: {};\nmp4_file_path: {};\n'.format(
self.m3u8_uri, self.max_retry_times, self.max_num_workers,
self.tmpdir, self.mp4_file_path))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._freeup_tmpdir()
def _apply_for_tmpdir(self):
self.tmpdir = tempfile.mkdtemp(prefix='m3u8_')
def _freeup_tmpdir(self):
if os.path.exists(self.tmpdir):
shutil.rmtree(self.tmpdir)
    def _legalize_valid_mp4_file_path(self):
        # calibrate_mp4_file_name() returns the sanitized name directly
        mp4_file_name = path_helper.calibrate_mp4_file_name(
            self.mp4_file_name)
        mp4_file_path = os.path.join(self.mp4_file_dir, mp4_file_name)
        if os.path.exists(mp4_file_path):
            mp4_file_name = path_helper.random_name()
            mp4_file_path = os.path.join(self.mp4_file_dir, mp4_file_name)
        self.mp4_file_path = mp4_file_path
def _get_m3u8_obj_by_uri(self, m3u8_uri):
try:
m3u8_obj = m3u8.load(uri=m3u8_uri)
except Exception as exc:
            logging.exception(
                'Failed to load m3u8 file, reason is {}'.format(exc))
raise Exception('FAILED TO LOAD M3U8 FILE!')
return m3u8_obj
    def _get_m3u8_obj_with_best_bandwidth(self, m3u8_uri):
m3u8_obj = self._get_m3u8_obj_by_uri(m3u8_uri)
if m3u8_obj.is_variant:
best_bandwidth = -1
best_bandwidth_m3u8_uri = None
for playlist in m3u8_obj.playlists:
if playlist.stream_info.bandwidth > best_bandwidth:
best_bandwidth = playlist.stream_info.bandwidth
best_bandwidth_m3u8_uri = playlist.absolute_uri
logging.info("choose the best bandwidth, which is {}".format(
best_bandwidth))
logging.info("m3u8 uri is {}".format(best_bandwidth_m3u8_uri))
m3u8_obj = self._get_m3u8_obj_by_uri(best_bandwidth_m3u8_uri)
return m3u8_obj
def _is_fetched(self, segment_uri):
file_name = path_helper.resolve_file_name_by_uri(segment_uri)
if file_name in self.fetched_file_names:
return True
return False
def _construct_key_segment_pairs_by_m3u8(self, m3u8_obj):
key_segments_pairs = list()
for key in m3u8_obj.keys:
if key:
if key.method.lower() == 'none':
continue
response_code, encryped_value = request_for(key.absolute_uri,
max_try_times=self.max_retry_times)
if response_code != 200:
raise Exception('DOWNLOAD KEY FAILED, URI IS {}'.format(
key.absolute_uri))
_encrypted_key = EncryptedKey(method=key.method,
value=encryped_value, iv=key.iv)
key_segments = m3u8_obj.segments.by_key(key)
segments_by_key = [segment.absolute_uri for segment in
key_segments if
not self._is_fetched(segment.absolute_uri)]
key_segments_pairs.append((_encrypted_key, segments_by_key))
if len(key_segments_pairs) == 0:
_encrypted_key = None
key_segments = m3u8_obj.segments
segments_by_key = [segment.absolute_uri for segment in key_segments
if not self._is_fetched(segment.absolute_uri)]
key_segments_pairs.append((_encrypted_key, segments_by_key))
return key_segments_pairs
def _fetch_segments_to_local_tmpdir(self, num_segments,
key_segments_pairs):
if len(self.fetched_file_names) >= num_segments:
return
progress_bar = printer_helper.ProcessBar(len(self.fetched_file_names),
num_segments, 'segment set',
'downloading...',
'downloaded segments successfully!')
for encrypted_key, segments_by_key in key_segments_pairs:
segment_url_to_encrypted_content = list()
with concurrent.futures.ThreadPoolExecutor(
max_workers=self.max_num_workers) as executor:
while len(segments_by_key) > 0:
future_2_segment_uri = {
executor.submit(download_segment, segment_url): segment_url
for segment_url in segments_by_key}
response_code, response_content = None, None
for future in concurrent.futures.as_completed(
future_2_segment_uri):
segment_uri = future_2_segment_uri[future]
try:
response_code, response_content = future.result()
except Exception as exc:
logging.exception(
'{} generated an exception: {}'.format(
segment_uri, exc))
if response_code == 200:
segment_url_to_encrypted_content.append(
(segment_uri, response_content))
segments_by_key.remove(segment_uri)
progress_bar.update()
if len(segments_by_key) > 0:
sys.stdout.write('\n')
logging.info(
'{} segments are failed to download, retry...'.format(
len(segments_by_key)))
logging.info('decrypt and dump segments...')
for segment_url, encrypted_content in segment_url_to_encrypted_content:
file_name = path_helper.resolve_file_name_by_uri(segment_url)
file_path = os.path.join(self.tmpdir, file_name)
            if encrypted_key is not None:
                crypt_ls = {"AES-128": AES}
                crypt_obj = crypt_ls[encrypted_key.method]
                # NOTE: the playlist IV (encrypted_key.iv) is not applied here,
                # so segments encrypted with an explicit IV will not decrypt correctly
                cryptor = crypt_obj.new(encrypted_key.value,
                                        crypt_obj.MODE_CBC)
                encrypted_content = cryptor.decrypt(encrypted_content)
            with open(file_path, 'wb') as fout:
                fout.write(encrypted_content)
def _merge_tmpdir_segments_to_mp4_by_ffmpeg(self, m3u8_obj):
order_segment_list_file_path = os.path.join(self.tmpdir, "ts_ls.txt")
with open(order_segment_list_file_path, 'w', encoding='utf8') as fin:
for segment in m3u8_obj.segments:
file_name = path_helper.resolve_file_name_by_uri(segment.uri)
segment_file_path = os.path.join(self.tmpdir, file_name)
fin.write("file '{}'\n".format(segment_file_path))
merge_cmd = "ffmpeg -y -f concat -safe 0 -i " + '"' + order_segment_list_file_path + '"' + " -c copy " + '"' + self.mp4_file_path + '"'
p = subprocess.Popen(merge_cmd, shell=True, stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
logging.info("merging segments...")
p.communicate()
def fetch_mp4_by_m3u8_uri(self):
        m3u8_obj = self._get_m3u8_obj_with_best_bandwidth(self.m3u8_uri)
key_segments_pairs = self._construct_key_segment_pairs_by_m3u8(
m3u8_obj)
start_time = time.time()
self._fetch_segments_to_local_tmpdir(len(m3u8_obj.segments),
key_segments_pairs)
fetch_end_time = time.time()
self._merge_tmpdir_segments_to_mp4_by_ffmpeg(m3u8_obj)
task_end_time = time.time()
if len(self.fetched_file_names) < len(m3u8_obj.segments):
printer_helper.display_speed(start_time, fetch_end_time,
task_end_time, self.mp4_file_path)

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
import asyncio
import socket
import urllib.parse
from m3u8_To_MP4.networks.http_base import AddressInfo
async def available_addr_infos_of_url(url):
loop = asyncio.get_event_loop()
scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url)
# todo:: support IPv6
addr_infos = await loop.getaddrinfo(host=netloc, port=scheme,
family=socket.AF_INET)
available_addr_info_pool = list()
for family, type, proto, canonname, sockaddr in addr_infos:
ai = AddressInfo(host=sockaddr[0], port=sockaddr[1], family=family,
proto=proto)
available_addr_info_pool.append(ai)
return available_addr_info_pool
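A short usage sketch for the resolver above; the URL is a placeholder.

import asyncio

infos = asyncio.run(available_addr_infos_of_url('https://example.com/index.m3u8'))
for ai in infos:
    print(ai.host, ai.port, ai.family, ai.proto)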

@@ -0,0 +1,120 @@
# -*- coding: utf-8 -*-
import asyncio
import logging
import urllib.error
import urllib.parse
import urllib.request
import urllib.response
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.networks import http_base
def http_get_header(domain_name, port, resource_path_at_server, is_keep_alive):
http_get_resource = http_base.statement_of_http_get_resource_str(
resource_path_at_server)
http_connect_address = http_base.statement_of_http_connect_address_str(
domain_name, port)
user_agent = http_base.random_user_agent_str()
http_connection = http_base.statement_of_http_connection_str(is_keep_alive)
# x_forward_for=random_x_forward_for()
# cookie=random_cookie()
request_header = '\r\n'.join((http_get_resource, http_connect_address,
user_agent, http_connection)) + '\r\n\r\n'
return request_header.encode()
async def handler_of_connection(address_info, default_ssl_context,
limit=256 * 1024):
loop = asyncio.get_event_loop()
host = address_info.host
port = address_info.port
family = address_info.family
proto = address_info.proto
reader = asyncio.StreamReader(limit=limit, loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, protocol = await loop.create_connection(lambda: protocol, host,
port,
ssl=default_ssl_context,
family=family,
proto=proto)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer
async def retrieve_resource_from_handler(reader, writer, request_header):
writer.write(data=request_header)
await writer.drain()
byted_response_header = await reader.readuntil(separator=b'\r\n\r\n')
response_header = byted_response_header.decode()
http_header_state = http_base.formatted_http_header(response_header)
    content_length = -1
    # formatted_http_header lowercases header names, so the key is 'content-length'
    if 'content-length' in http_header_state:
        content_length = int(http_header_state['content-length'])
byted_response_content = await reader.read(n=content_length)
return http_header_state, byted_response_content
async def retrieve_resource_from_url(address_info, url, ssl_context_,
is_keep_alive=False, max_retry_times=5,
limit=2 ** 10):
port = address_info.port
scheme, domain_name_with_suffix, path, query, fragment = urllib.parse.urlsplit(
url)
resource_path = path_helper.updated_resource_path(path, query)
reader, writer = None, None
response_header_state, byted_response_content = {'response_code': -1}, None
num_retry = 0
while num_retry < max_retry_times:
try:
request_header = http_get_header(domain_name_with_suffix, port,
resource_path, is_keep_alive)
reader, writer = await asyncio.wait_for(
handler_of_connection(address_info, ssl_context_, limit), 3)
response_header_state, byted_response_content = await asyncio.wait_for(
retrieve_resource_from_handler(reader, writer, request_header),
timeout=8)
except asyncio.TimeoutError as te:
logging.debug('request timeout: {}'.format(url))
except Exception as exc:
logging.debug(
'request failed: {}, and caused reason is {}'.format(url,
str(exc)))
try:
if not is_keep_alive:
if writer is not None:
writer.close()
await writer.wait_closed()
# if reader is not None:
# reader.feed_eof()
# assert reader.at_eof()
except Exception as exc:
logging.debug(
'request failed: {}, and caused reason is {}'.format(url,
str(exc)))
if response_header_state['response_code'] == 200:
return response_header_state, byted_response_content
num_retry += 1
else:
return response_header_state, None

@@ -0,0 +1,139 @@
# -*- coding: utf-8 -*-
import asyncio
import itertools
import logging
import os
import platform
# import random
import urllib.parse
from collections import Counter
from multiprocessing import JoinableQueue, Process
from Crypto.Cipher import AES
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.helpers import printer_helper
from m3u8_To_MP4.networks import http_base
from m3u8_To_MP4.networks.asynchronous import async_http
async def ts_request(concurrent_condition, ssl_context, addr_info,
_encrypted_key, segment_uri):
async with concurrent_condition:
response_header_state, response_content_bytes = await async_http.retrieve_resource_from_url(
addr_info, segment_uri, ssl_context, limit=256 * 1024)
return addr_info, segment_uri, response_header_state, _encrypted_key, response_content_bytes
async def ts_producer_scheduler(key_segment_pairs, addr_infos, ts_queue,
num_concurrent, addr_quantity_statistics):
exec_event_loop = asyncio.get_event_loop()
# concurrent_condition = asyncio.Semaphore(value=50, loop=exec_event_loop)# Python 3.10 does not recommend
concurrent_condition = asyncio.Semaphore(value=num_concurrent)
scheme, domain_name_with_suffix, path, query, fragment = urllib.parse.urlsplit(
key_segment_pairs[0][1])
default_ssl_context = http_base.ssl_under_scheme(scheme)
task_params = list()
for (_encrypted_key, segment_uri), addr_info in zip(key_segment_pairs,
itertools.cycle(
addr_infos)):
task_params.append((concurrent_condition, default_ssl_context,
addr_info, _encrypted_key, segment_uri))
awaitable_tasks = list()
for params in task_params:
awaitable_task = exec_event_loop.create_task(ts_request(*params))
awaitable_tasks.append(awaitable_task)
incompleted_tasks = list()
for task in asyncio.as_completed(awaitable_tasks):
addr_info, segment_uri, response_header_state, encrypted_key, encrypted_content_bytes = await task
if response_header_state['response_code'] == 200:
file_name = path_helper.resolve_file_name_by_uri(segment_uri)
ts_queue.put((encrypted_key, encrypted_content_bytes, file_name))
else:
incompleted_tasks.append((encrypted_key, segment_uri))
addr_quantity_statistics.update([addr_info.host])
return incompleted_tasks, addr_quantity_statistics
def producer_process(key_segment_uris, addr_infos, ts_queue, num_concurrent):
incompleted_tasks = key_segment_uris
num_efficient_addr_info = int(len(addr_infos) * 0.5)
num_efficient_addr_info = 1 if num_efficient_addr_info < 1 else num_efficient_addr_info
addr_quantity_statistics = Counter(
{addr_info.host: 0 for addr_info in addr_infos})
# solve error in windows: event loop is closed
if platform.system().lower() == 'windows':
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
while len(incompleted_tasks) > 0:
incompleted_tasks, addr_quantity_statistics = asyncio.run(
ts_producer_scheduler(incompleted_tasks, addr_infos, ts_queue,
num_concurrent, addr_quantity_statistics))
efficient_hosts = [host for host, _ in
addr_quantity_statistics.most_common()]
efficient_hosts = efficient_hosts[-num_efficient_addr_info:]
addr_infos = [addr_info for addr_info in addr_infos if
addr_info.host in efficient_hosts]
# random.shuffle(addr_infos)
if len(incompleted_tasks) > 0:
print()
logging.info(
'{} requests failed, retry ...'.format(len(incompleted_tasks)))
def consumer_process(ts_queue, tmpdir, progress_bar):
while True:
encrypted_key, encrypted_content, file_name = ts_queue.get()
if encrypted_key is not None:
crypt_ls = {"AES-128": AES}
crypt_obj = crypt_ls[encrypted_key.method]
cryptor = crypt_obj.new(encrypted_key.value.encode(),
crypt_obj.MODE_CBC)
encrypted_content = cryptor.decrypt(encrypted_content)
file_path = os.path.join(tmpdir, file_name)
with open(file_path, 'wb') as fin:
fin.write(encrypted_content)
ts_queue.task_done()
progress_bar.update()
def factory_pipeline(num_fetched_ts_segments, key_segments_pairs,
available_addr_info_pool, num_concurrent, tmpdir):
num_ts_segments = len(key_segments_pairs)
progress_bar = printer_helper.ProcessBar(num_fetched_ts_segments,
num_ts_segments + num_fetched_ts_segments,
'segment set', 'downloading...',
'downloaded segments successfully!')
# schedule tasks
ts_queue = JoinableQueue()
ts_producer = Process(target=producer_process, args=(
key_segments_pairs, available_addr_info_pool, ts_queue, num_concurrent))
ts_consumer = Process(target=consumer_process,
args=(ts_queue, tmpdir, progress_bar))
    ts_producer.start()
    ts_consumer.daemon = True
    ts_consumer.start()
    ts_producer.join()
    # block until every queued segment has been decrypted and written to disk
    ts_queue.join()

@@ -0,0 +1,132 @@
# -*- coding: utf-8 -*-
import collections
import random
import ssl
AddressInfo = collections.namedtuple(typename='AddressInfo',
field_names=['host', 'port', 'family',
'proto'])
def statement_of_http_get_resource(resource_path_at_server):
return f'{resource_path_at_server}'
def statement_of_http_get_resource_str(resource_path_at_server):
return f'GET {statement_of_http_get_resource(resource_path_at_server)} HTTP/1.1'
def statement_of_http_connect_address(domain_name, port):
return f'{domain_name}:{port}'
def statement_of_http_connect_address_str(domain_name, port):
return f'Host: {statement_of_http_connect_address(domain_name,port)}'
def random_user_agent():
user_agent_pool = [
'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/535.1 (KHTML, like Gecko) Chrome/14.0.835.163 Safari/535.1',
'Mozilla/5.0 (Windows NT 6.1; WOW64; rv:6.0) Gecko/20100101 Firefox/6.0',
'Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/534.50 (KHTML, like Gecko) Version/5.1 Safari/534.50',
'Opera/9.80 (Windows NT 6.1; U; zh-cn) Presto/2.9.168 Version/11.50',
'Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; Win64; x64; Trident/5.0; .NET CLR 2.0.50727; SLCC2; .NET CLR 3.5.30729; .NET CLR 3.0.30729; Media Center PC 6.0; InfoPath.3; .NET4.0C; Tablet PC 2.0; .NET4.0E)',
'Mozilla/4.0 (compatible; MSIE 8.0; Windows NT 6.1; WOW64; Trident/4.0; SLCC2; .NET CLR 2.0.50727; .NET CLR 3.5.30729; .NET CLR 3.0.30729; Media Center PC 6.0; .NET4.0C; InfoPath.3)',
'Mozilla/4.0 (compatible; MSIE 8.0; Windows NT 5.1; Trident/4.0; GTB7.0)',
'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 5.1)',
'Mozilla/4.0 (compatible; MSIE 6.0; Windows NT 5.1; SV1)',
'Mozilla/5.0 (Windows; U; Windows NT 6.1; ) AppleWebKit/534.12 (KHTML, like Gecko) Maxthon/3.0 Safari/534.12',
'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.1; WOW64; Trident/5.0; SLCC2; .NET CLR 2.0.50727; .NET CLR 3.5.30729; .NET CLR 3.0.30729; Media Center PC 6.0; InfoPath.3; .NET4.0C; .NET4.0E)',
'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.1; WOW64; Trident/5.0; SLCC2; .NET CLR 2.0.50727; .NET CLR 3.5.30729; .NET CLR 3.0.30729; Media Center PC 6.0; InfoPath.3; .NET4.0C; .NET4.0E; SE 2.X MetaSr 1.0)',
'Mozilla/5.0 (Windows; U; Windows NT 6.1; en-US) AppleWebKit/534.3 (KHTML, like Gecko) Chrome/6.0.472.33 Safari/534.3 SE 2.X MetaSr 1.0',
'Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; WOW64; Trident/5.0; SLCC2; .NET CLR 2.0.50727; .NET CLR 3.5.30729; .NET CLR 3.0.30729; Media Center PC 6.0; InfoPath.3; .NET4.0C; .NET4.0E)',
'Mozilla/5.0 (Windows NT 6.1) AppleWebKit/535.1 (KHTML, like Gecko) Chrome/13.0.782.41 Safari/535.1 QQBrowser/6.9.11079.201',
'Mozilla/4.0 (compatible; MSIE 7.0; Windows NT 6.1; WOW64; Trident/5.0; SLCC2; .NET CLR 2.0.50727; .NET CLR 3.5.30729; .NET CLR 3.0.30729; Media Center PC 6.0; InfoPath.3; .NET4.0C; .NET4.0E) QQBrowser/6.9.11079.201',
'Mozilla/5.0 (compatible; MSIE 9.0; Windows NT 6.1; WOW64; Trident/5.0)',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.107 Safari/537.36',
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36'
]
return random.choice(user_agent_pool)
def random_user_agent_str():
return 'User-Agent: ' + random_user_agent()
def statement_of_http_connection(is_keep_alive):
state = 'Keep-alive' if is_keep_alive else 'close'
return state
def statement_of_http_connection_str(is_keep_alive):
return 'Connection: ' + statement_of_http_connection(is_keep_alive)
# optional
def random_x_forward_for():
num_ips = random.randint(2, 5)
ip_segments = list()
    while len(ip_segments) < 4 * num_ips:
        ip_segments.append(str(random.randint(0, 255)))
ips = list()
for ip_index in range(num_ips):
ips.append('.'.join(ip_segments[ip_index * 4:(ip_index + 1) * 4]))
return ','.join(ips)
def random_x_forward_for_str():
    return 'X-Forwarded-For: ' + random_x_forward_for()
# optional
def random_cookie():
return 'Cookie: ' + ''
# ssl
def ssl_context():
_ssl_context = ssl.create_default_context()
_ssl_context.check_hostname = False
_ssl_context.verify_mode = ssl.CERT_NONE
return _ssl_context
def ssl_under_scheme(scheme):
if scheme == 'https':
return ssl_context()
elif scheme == 'http':
return False
else:
raise ValueError('{} is not supported now.'.format(scheme))
def formatted_http_header(http_header_str):
http_header_state = dict()
http_header_lines = http_header_str.strip().split('\n')
response_fragments = http_header_lines[0].strip().split()
response_code = response_fragments[1]
response_description = ' '.join(response_fragments[2:])
http_header_state['response_code'] = int(response_code)
http_header_state['response_description'] = response_description
for line in http_header_lines[1:]:
line = line.strip()
i = line.find(':')
if i == -1:
continue
key = line[:i]
value = line[i + 1:]
http_header_state[key.lower()] = value
return http_header_state
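A quick check of formatted_http_header against a typical raw response header (the sample string is illustrative); note the lowercased, hyphenated key.

sample = ('HTTP/1.1 200 OK\r\n'
          'Content-Length: 1024\r\n'
          'Connection: close\r\n\r\n')
state = formatted_http_header(sample)
assert state['response_code'] == 200
assert state['content-length'].strip() == '1024'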

@@ -0,0 +1 @@
# -*- coding: utf-8 -*-

@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
import re
import socket
import urllib.parse
from m3u8_To_MP4.networks.http_base import AddressInfo
def available_addr_infos_of_url(url):
scheme, netloc, path, query, fragment = urllib.parse.urlsplit(url)
specific_port_pattern = re.compile(r':(\d+)')
specific_ports = re.findall(specific_port_pattern, netloc)
netloc = re.sub(specific_port_pattern, '', netloc)
# todo:: support IPv6
addr_infos = socket.getaddrinfo(host=netloc, port=scheme,
family=socket.AF_INET)
available_addr_info_pool = list()
for family, type, proto, canonname, sockaddr in addr_infos:
port = specific_ports[0] if len(specific_ports) > 0 else sockaddr[1]
ai = AddressInfo(host=sockaddr[0], port=port, family=family,
proto=proto)
available_addr_info_pool.append(ai)
return available_addr_info_pool

@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
import urllib.error
import urllib.parse
import urllib.request
import urllib.response
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.networks import http_base
def http_get_header(domain_name, port, resource_path_at_server, is_keep_alive,
customized_http_header):
request_header = dict()
# http_get_resource = http_base.statement_of_http_get_resource(resource_path_at_server)
# http_connect_address = http_base.statement_of_http_connect_address(domain_name, port)
user_agent = http_base.random_user_agent()
request_header['User-Agent'] = user_agent
if customized_http_header is not None:
request_header.update(customized_http_header)
# http_connection = http_base.statement_of_http_connection(is_keep_alive)
# x_forward_for=random_x_forward_for()
# cookie=random_cookie()
return request_header
def retrieve_resource_from_url(address_info, url, is_keep_alive=False,
max_retry_times=5, timeout=30,
customized_http_header=None):
port = address_info.port
scheme, domain_name_with_suffix, path, query, fragment = urllib.parse.urlsplit(
url)
resource_path = path_helper.updated_resource_path(path, query)
response_code = -1
response_content = None
for num_retry in range(max_retry_times):
headers = http_get_header(domain_name_with_suffix, port, resource_path,
is_keep_alive, customized_http_header)
try:
request = urllib.request.Request(url=url, headers=headers)
with urllib.request.urlopen(url=request,
timeout=timeout) as response:
response_code = response.getcode()
response_content = response.read()
if response_code == 200:
break
except urllib.error.HTTPError as he:
response_code = he.code
except urllib.error.ContentTooShortError as ctse:
response_code = -2 # -2:ctse
except urllib.error.URLError as ue:
response_code = -3 # -3:URLError
except Exception as exc:
response_code = -4 # other error
finally:
timeout += 2
return response_code, response_content

@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
import urllib.error
import urllib.parse
import urllib.request
import urllib.response
def get_headers(customized_http_header):
request_header = dict()
request_header[
'User-Agent'] = 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.77 Safari/537.36'
if customized_http_header is not None:
request_header.update(customized_http_header)
return request_header
def request_for(url, max_try_times=1, headers=None, data=None, timeout=30,
proxy_ip=None, verify=False, customized_http_header=None):
response_code = -1
response_content = None
for num_retry in range(max_try_times):
if headers is None:
headers = get_headers(customized_http_header)
try:
request = urllib.request.Request(url=url, data=data,
headers=headers)
with urllib.request.urlopen(url=request,
timeout=timeout) as response:
response_code = response.getcode()
response_content = response.read()
if response_code == 200:
break
except urllib.error.HTTPError as he:
response_code = he.code
except urllib.error.ContentTooShortError as ctse:
response_code = -2 # -2:ctse
except urllib.error.URLError as ue:
response_code = -3 # -3:URLError
except Exception as exc:
response_code = -4 # other error
finally:
timeout += 2
return response_code, response_content

@@ -0,0 +1,247 @@
# -*- coding: utf-8 -*-
import logging
import os
import shutil
import subprocess
import tarfile
import tempfile
import time
import warnings
import zlib
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.helpers import printer_helper
from m3u8_To_MP4.helpers.os_helper import get_core_count
from m3u8_To_MP4.networks.synchronous import sync_DNS
printer_helper.config_logging()
class AbstractCrawler(object):
def __init__(self,
m3u8_uri,
file_path='./m3u8_To_MP4.mp4',
customized_http_header=None,
max_retry_times=3,
num_concurrent=50,
tmpdir=None
):
self.m3u8_uri = m3u8_uri
self.customized_http_header = customized_http_header
self.max_retry_times = max_retry_times
self.num_concurrent = num_concurrent
self.tmpdir = tmpdir
self.file_path = file_path
def __enter__(self):
if self.tmpdir is None:
self._apply_for_tmpdir()
self.segment_path_recipe = os.path.join(self.tmpdir, "ts_recipe.txt")
self._find_out_done_ts()
self._legalize_file_path()
# self._imitate_tar_file_path()
print('\nsummary')
        print('m3u8_uri: {};\nmax_retry_times: {};\ntmp_dir: {};\nfile_path: {};\n'.format(
            self.m3u8_uri, self.max_retry_times, self.tmpdir,
            self.file_path))
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self._freeup_tmpdir()
def _apply_for_tmpdir(self):
os_tmp_dir = tempfile.gettempdir()
url_crc32_str = str(zlib.crc32(self.m3u8_uri.encode())) # hash algorithm
self.tmpdir = os.path.join(os_tmp_dir, 'm3u8_' + url_crc32_str)
if not os.path.exists(self.tmpdir):
os.mkdir(self.tmpdir)
def _freeup_tmpdir(self):
os_tmp_dir = tempfile.gettempdir()
for file_symbol in os.listdir(os_tmp_dir):
if file_symbol.startswith('m3u8_'):
file_symbol_absolute_path = os.path.join(os_tmp_dir, file_symbol)
if os.path.isdir(file_symbol_absolute_path):
shutil.rmtree(file_symbol_absolute_path)
def _find_out_done_ts(self):
file_names_in_tmpdir = os.listdir(self.tmpdir)
full_ts_file_names = list()
for file_name in reversed(file_names_in_tmpdir):
if file_name == 'ts_recipe.txt':
continue
absolute_file_path = os.path.join(self.tmpdir, file_name)
if os.path.getsize(absolute_file_path) > 0:
full_ts_file_names.append(file_name)
self.fetched_file_names = full_ts_file_names
def _legalize_file_path(self):
        parent = os.path.dirname(self.file_path)
        if not os.path.exists(parent):
            print('{} does not exist, remapping to the current directory.'.format(parent))
            parent = os.getcwd()
name = os.path.basename(self.file_path)
name = path_helper.calibrate_mp4_file_name(name)
# if not is_valid:
# mp4_file_name = path_helper.create_mp4_file_name()
self.file_path = os.path.join(parent, name)
if os.path.exists(self.file_path):
mp4_file_name = path_helper.random_name()
self.file_path = os.path.join(parent, mp4_file_name)
def _resolve_DNS(self):
self.available_addr_info_pool = sync_DNS.available_addr_infos_of_url(self.m3u8_uri)
self.best_addr_info = self.available_addr_info_pool[0]
logging.info('Resolved available hosts:')
for addr_info in self.available_addr_info_pool:
logging.info('{}:{}'.format(addr_info.host, addr_info.port))
def _create_tasks(self):
raise NotImplementedError
    def _is_ads(self, segment_uri):
        # segments that do NOT match the common URL pattern are treated as ads
        if not segment_uri.startswith(self.longest_common_subsequence):
            return True
        # if not segment_uri.endswith('.ts'):
        #     return True
        return False
def _filter_ads_ts(self, key_segments_pairs):
self.longest_common_subsequence = path_helper.longest_common_subsequence([segment_uri for _, segment_uri in key_segments_pairs])
key_segments_pairs = [(_encrypted_key, segment_uri) for
_encrypted_key, segment_uri in key_segments_pairs
if not self._is_ads(segment_uri)]
return key_segments_pairs
def _is_fetched(self, segment_uri):
file_name = path_helper.resolve_file_name_by_uri(segment_uri)
if file_name in self.fetched_file_names:
return True
return False
def _filter_done_ts(self, key_segments_pairs):
num_ts_segments = len(key_segments_pairs)
key_segments_pairs = [(_encrypted_key, segment_uri) for
_encrypted_key, segment_uri in key_segments_pairs
if not self._is_fetched(segment_uri)]
self.num_fetched_ts_segments = num_ts_segments - len(key_segments_pairs)
return key_segments_pairs
def _fetch_segments_to_local_tmpdir(self, key_segments_pairs):
raise NotImplementedError
def _construct_segment_path_recipe(self, key_segment_pairs):
with open(self.segment_path_recipe, 'w', encoding='utf8') as fw:
for _, segment in key_segment_pairs:
file_name = path_helper.resolve_file_name_by_uri(segment)
segment_file_path = os.path.join(self.tmpdir, file_name)
fw.write("file '{}'\n".format(segment_file_path))
def _merge_to_mp4(self):
if not self.file_path.endswith('mp4'):
warnings.warn('{} does not end with .mp4'.format(self.file_path))
logging.info("merging segments...")
# copy mode
merge_cmd = "ffmpeg " + \
"-y -f concat -threads {} -safe 0 ".format(get_core_count()) + \
"-i " + '"' + self.segment_path_recipe + '" ' + \
"-c copy " + \
'"' + self.file_path + '"'
p = subprocess.Popen(merge_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p.communicate()
# change codec
        if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) < 1:
            logging.info("merge failed.")
            logging.info("changing codec and re-merging segments (this may take a long time)")
merge_cmd = "ffmpeg " + \
"-y -f concat -threads {} -safe 0 ".format(get_core_count()) + \
"-i " + '"' + self.segment_path_recipe + '" ' + \
'"' + self.file_path + '"'
p = subprocess.Popen(merge_cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
p.communicate()
def _merge_to_ts(self):
        if not self.file_path.endswith('ts'):
            warnings.warn('{} does not end with .ts'.format(self.file_path))
ts_paths = list()
with open(self.segment_path_recipe, 'r', encoding='utf8') as fr:
for line in fr:
line = line.strip()
if len(line) < 1:
continue
ts_paths.append(line[6:-1])
with open(self.file_path, 'ab') as fw:
for ts_path in ts_paths:
with open(ts_path, 'rb') as fr:
fw.write(fr.read())
def _merge_to_tar(self):
        if not self.file_path.endswith('tar'):
            warnings.warn('{} does not end with .tar'.format(self.file_path))
with tarfile.open(self.file_path, 'w:bz2') as targz:
targz.add(name=self.tmpdir, arcname=os.path.basename(self.tmpdir))
def fetch_mp4_by_m3u8_uri(self, format='ts'):
task_start_time = time.time()
# preparation
self._resolve_DNS()
# resolve ts segment uris
key_segments_pairs = self._create_tasks()
        if len(key_segments_pairs) < 1:
            raise ValueError('NO TASKS FOUND!\nPlease check the m3u8 url.')
key_segments_pairs = self._filter_ads_ts(key_segments_pairs)
self._construct_segment_path_recipe(key_segments_pairs)
key_segments_pairs = self._filter_done_ts(key_segments_pairs)
# download
if len(key_segments_pairs) > 0:
self._fetch_segments_to_local_tmpdir(key_segments_pairs)
fetch_end_time = time.time()
# merge
if format == 'ts':
self._merge_to_ts()
elif format == 'mp4':
self._merge_to_mp4()
elif format == 'tar':
self._merge_to_tar()
task_end_time = time.time()
printer_helper.display_speed(task_start_time, fetch_end_time, task_end_time, self.file_path)

@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
import collections
import logging
import os.path
import m3u8
from m3u8_To_MP4.networks.synchronous import sync_http
from m3u8_To_MP4.v2_abstract_crawler_processor import AbstractCrawler
EncryptedKey = collections.namedtuple(typename='EncryptedKey',
field_names=['method', 'value', 'iv'])
class M3u8FileIsVariantException(Exception):
def __init__(self, name, reason):
self.name = name
self.reason = reason
class M3u8PlaylistIsNoneException(Exception):
def __init__(self, name, reason):
self.name = name
self.reason = reason
class AbstractFileCrawler(AbstractCrawler):
    def __init__(self, m3u8_uri, m3u8_file_path, customized_http_header=None,
                 max_retry_times=3, num_concurrent=50, mp4_file_dir=None,
                 mp4_file_name='m3u8-To-Mp4.mp4', tmpdir=None):
        # mp4_file_dir defaults to None, so fall back to the working directory
        if mp4_file_dir is None:
            mp4_file_dir = os.getcwd()
        file_path = os.path.join(mp4_file_dir, mp4_file_name)
super(AbstractFileCrawler, self).__init__(m3u8_uri,
file_path,
customized_http_header,
max_retry_times,
num_concurrent,
tmpdir)
self.m3u8_file_path = m3u8_file_path
def _read_m3u8_file(self):
m3u8_str = ''
with open(self.m3u8_file_path, 'r', encoding='utf8') as fin:
m3u8_str = fin.read().strip()
return m3u8_str
def _construct_key_segment_pairs_by_m3u8(self, m3u8_obj):
key_segments_pairs = list()
for key in m3u8_obj.keys:
if key:
if key.method.lower() == 'none':
continue
response_code, encryped_value = sync_http.retrieve_resource_from_url(
self.best_addr_info, key.absolute_uri,
customized_http_header=self.customized_http_header)
if response_code != 200:
raise Exception('DOWNLOAD KEY FAILED, URI IS {}'.format(
key.absolute_uri))
encryped_value = encryped_value.decode()
_encrypted_key = EncryptedKey(method=key.method,
value=encryped_value, iv=key.iv)
key_segments = m3u8_obj.segments.by_key(key)
segments_by_key = [(_encrypted_key, segment.absolute_uri) for
segment in key_segments]
key_segments_pairs.extend(segments_by_key)
if len(key_segments_pairs) == 0:
_encrypted_key = None
key_segments = m3u8_obj.segments
segments_by_key = [(_encrypted_key, segment.absolute_uri) for
segment in key_segments]
key_segments_pairs.extend(segments_by_key)
        if len(key_segments_pairs) == 0:
            raise M3u8PlaylistIsNoneException(name=self.m3u8_file_path,
                                              reason='M3u8 playlist is empty!')
return key_segments_pairs
def _create_tasks(self):
m3u8_str = self._read_m3u8_file()
m3u8_obj = m3u8.loads(content=m3u8_str, uri=self.m3u8_uri)
        if m3u8_obj.is_variant:
            raise M3u8FileIsVariantException(self.m3u8_file_path,
                                             'M3u8 file is a variant playlist; retrieving nested m3u8 files is not supported in this mode!')
logging.info("Read m3u8 file from {}".format(self.m3u8_file_path))
key_segments_pairs = self._construct_key_segment_pairs_by_m3u8(
m3u8_obj)
return key_segments_pairs
class AbstractUriCrawler(AbstractCrawler):
def _request_m3u8_obj_from_url(self):
try:
response_code, m3u8_bytes = sync_http.retrieve_resource_from_url(
self.best_addr_info, self.m3u8_uri,
customized_http_header=self.customized_http_header)
if response_code != 200:
raise Exception(
'DOWNLOAD KEY FAILED, URI IS {}'.format(self.m3u8_uri))
m3u8_str = m3u8_bytes.decode()
m3u8_obj = m3u8.loads(content=m3u8_str, uri=self.m3u8_uri)
except Exception as exc:
            logging.exception(
                'Failed to load m3u8 file, reason is {}'.format(exc))
raise Exception('FAILED TO LOAD M3U8 FILE!')
return m3u8_obj
    def _get_m3u8_obj_with_best_bandwidth(self):
m3u8_obj = self._request_m3u8_obj_from_url()
if m3u8_obj.is_variant:
best_bandwidth = -1
best_bandwidth_m3u8_uri = None
for playlist in m3u8_obj.playlists:
if playlist.stream_info.bandwidth > best_bandwidth:
best_bandwidth = playlist.stream_info.bandwidth
best_bandwidth_m3u8_uri = playlist.absolute_uri
logging.info("Choose the best bandwidth, which is {}".format(
best_bandwidth))
logging.info("Best m3u8 uri is {}".format(best_bandwidth_m3u8_uri))
self.m3u8_uri = best_bandwidth_m3u8_uri
m3u8_obj = self._request_m3u8_obj_from_url()
return m3u8_obj
def _construct_key_segment_pairs_by_m3u8(self, m3u8_obj):
key_segments_pairs = list()
for key in m3u8_obj.keys:
if key:
if key.method.lower() == 'none':
continue
response_code, encryped_value = sync_http.retrieve_resource_from_url(
self.best_addr_info, key.absolute_uri,
customized_http_header=self.customized_http_header)
if response_code != 200:
raise Exception('DOWNLOAD KEY FAILED, URI IS {}'.format(
key.absolute_uri))
encryped_value = encryped_value.decode()
_encrypted_key = EncryptedKey(method=key.method,
value=encryped_value, iv=key.iv)
key_segments = m3u8_obj.segments.by_key(key)
segments_by_key = [(_encrypted_key, segment.absolute_uri) for
segment in key_segments]
key_segments_pairs.extend(segments_by_key)
if len(key_segments_pairs) == 0:
_encrypted_key = None
key_segments = m3u8_obj.segments
segments_by_key = [(_encrypted_key, segment.absolute_uri) for
segment in key_segments]
key_segments_pairs.extend(segments_by_key)
return key_segments_pairs
def _create_tasks(self):
# resolve ts segment uris
        m3u8_obj = self._get_m3u8_obj_with_best_bandwidth()
key_segments_pairs = self._construct_key_segment_pairs_by_m3u8(
m3u8_obj)
return key_segments_pairs

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
from m3u8_To_MP4 import v2_abstract_task_processor
from m3u8_To_MP4.networks.asynchronous import async_producer_consumer
class AsynchronousFileCrawler(v2_abstract_task_processor.AbstractFileCrawler):
def _fetch_segments_to_local_tmpdir(self, key_segments_pairs):
async_producer_consumer.factory_pipeline(self.num_fetched_ts_segments,
key_segments_pairs,
self.available_addr_info_pool,
self.num_concurrent,
self.tmpdir)
class AsynchronousUriCrawler(v2_abstract_task_processor.AbstractUriCrawler):
def _fetch_segments_to_local_tmpdir(self, key_segments_pairs):
async_producer_consumer.factory_pipeline(self.num_fetched_ts_segments,
key_segments_pairs,
self.available_addr_info_pool,
self.num_concurrent,
self.tmpdir)

@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-
import concurrent.futures
import logging
import os
import sys
from Crypto.Cipher import AES
from m3u8_To_MP4 import v2_abstract_task_processor
from m3u8_To_MP4.helpers import path_helper
from m3u8_To_MP4.helpers import printer_helper
from m3u8_To_MP4.networks.synchronous.sync_http_requester import request_for
def download_segment(segment_url, customized_http_header):
response_code, response_content = request_for(segment_url,
customized_http_header=customized_http_header)
return response_code, response_content


class MultiThreadsFileCrawler(v2_abstract_task_processor.AbstractFileCrawler):
    def _fetch_segments_to_local_tmpdir(self, key_segments_pairs):
        if len(key_segments_pairs) < 1:
            return

        progress_bar = printer_helper.ProcessBar(self.num_fetched_ts_segments,
                                                 self.num_fetched_ts_segments + len(key_segments_pairs),
                                                 'segment set',
                                                 'downloading...',
                                                 'downloaded segments successfully!')

        key_url_encrypted_data_triple = list()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self.num_concurrent) as executor:
            # failed segments stay in key_segments_pairs and are re-submitted
            # on the next pass, so this loop retries until every download succeeds
            while len(key_segments_pairs) > 0:
                future_2_key_and_url = {
                    executor.submit(download_segment, segment_url,
                                    self.customized_http_header): (key, segment_url)
                    for key, segment_url in key_segments_pairs}

                for future in concurrent.futures.as_completed(
                        future_2_key_and_url):
                    key, segment_url = future_2_key_and_url[future]
                    # reset per future, so a raised exception cannot reuse the
                    # previous segment's response
                    response_code, response_data = None, None
                    try:
                        response_code, response_data = future.result()
                    except Exception as exc:
                        logging.exception(
                            '{} generated an exception: {}'.format(segment_url,
                                                                   exc))

                    if response_code == 200:
                        key_url_encrypted_data_triple.append(
                            (key, segment_url, response_data))

                        key_segments_pairs.remove((key, segment_url))
                        progress_bar.update()

                if len(key_segments_pairs) > 0:
                    sys.stdout.write('\n')
                    logging.info(
                        '{} segments failed to download, retrying...'.format(
                            len(key_segments_pairs)))

        logging.info('decrypting and dumping segments...')
        for key, segment_url, encrypted_data in key_url_encrypted_data_triple:
            file_name = path_helper.resolve_file_name_by_uri(segment_url)
            file_path = os.path.join(self.tmpdir, file_name)

            if key is not None:
                crypt_ls = {"AES-128": AES}
                crypt_obj = crypt_ls[key.method]
                # pass the declared IV explicitly; without one, PyCryptodome
                # generates a random IV and the first block decrypts wrong
                if key.iv:
                    iv_hex = key.iv[2:] if key.iv.lower().startswith('0x') else key.iv
                    iv = bytes.fromhex(iv_hex)
                else:
                    iv = bytes(16)
                cryptor = crypt_obj.new(key.value.encode(),
                                        crypt_obj.MODE_CBC, iv=iv)
                encrypted_data = cryptor.decrypt(encrypted_data)

            with open(file_path, 'wb') as fout:
                fout.write(encrypted_data)


class MultiThreadsUriCrawler(v2_abstract_task_processor.AbstractUriCrawler):
    def _fetch_segments_to_local_tmpdir(self, key_segments_pairs):
        if len(key_segments_pairs) < 1:
            return

        progress_bar = printer_helper.ProcessBar(self.num_fetched_ts_segments,
                                                 self.num_fetched_ts_segments + len(key_segments_pairs),
                                                 'segment set',
                                                 'downloading...',
                                                 'downloaded segments successfully!')

        key_url_encrypted_data_triple = list()
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=self.num_concurrent) as executor:
            # same retry-until-success loop as MultiThreadsFileCrawler
            while len(key_segments_pairs) > 0:
                future_2_key_and_url = {
                    executor.submit(download_segment, segment_url,
                                    self.customized_http_header): (key, segment_url)
                    for key, segment_url in key_segments_pairs}

                for future in concurrent.futures.as_completed(
                        future_2_key_and_url):
                    key, segment_url = future_2_key_and_url[future]
                    # reset per future, so a raised exception cannot reuse the
                    # previous segment's response
                    response_code, response_data = None, None
                    try:
                        response_code, response_data = future.result()
                    except Exception as exc:
                        logging.exception(
                            '{} generated an exception: {}'.format(segment_url,
                                                                   exc))

                    if response_code == 200:
                        key_url_encrypted_data_triple.append(
                            (key, segment_url, response_data))

                        key_segments_pairs.remove((key, segment_url))
                        progress_bar.update()

                if len(key_segments_pairs) > 0:
                    sys.stdout.write('\n')
                    logging.info(
                        '{} segments failed to download, retrying...'.format(
                            len(key_segments_pairs)))

        logging.info('decrypting and dumping segments...')
        for key, segment_url, encrypted_data in key_url_encrypted_data_triple:
            file_name = path_helper.resolve_file_name_by_uri(segment_url)
            file_path = os.path.join(self.tmpdir, file_name)

            if key is not None:
                crypt_ls = {"AES-128": AES}
                crypt_obj = crypt_ls[key.method]
                # pass the declared IV explicitly; without one, PyCryptodome
                # generates a random IV and the first block decrypts wrong
                if key.iv:
                    iv_hex = key.iv[2:] if key.iv.lower().startswith('0x') else key.iv
                    iv = bytes.fromhex(iv_hex)
                else:
                    iv = bytes(16)
                cryptor = crypt_obj.new(key.value.encode(),
                                        crypt_obj.MODE_CBC, iv=iv)
                encrypted_data = cryptor.decrypt(encrypted_data)

            with open(file_path, 'wb') as fout:
                fout.write(encrypted_data)
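Taken on its own, the AES-128 branch above amounts to the following helper; a sketch with assumed inputs (key_bytes fetched from the EXT-X-KEY URI, iv_hex its optional IV attribute):

from Crypto.Cipher import AES

def decrypt_ts_segment(encrypted: bytes, key_bytes: bytes,
                       iv_hex: str | None = None) -> bytes:
    # HLS AES-128 is AES-CBC with a 16-byte key; when no IV attribute is
    # declared, the spec derives it from the media sequence number
    # (approximated here as zero, since this is only a sketch).
    if iv_hex:
        iv = bytes.fromhex(iv_hex[2:] if iv_hex.lower().startswith('0x') else iv_hex)
    else:
        iv = bytes(16)
    cipher = AES.new(key_bytes, AES.MODE_CBC, iv=iv)
    return cipher.decrypt(encrypted)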

46
m3u8_download.py Normal file

@ -0,0 +1,46 @@
import time
from pathlib import Path

from apscheduler.schedulers.blocking import BlockingScheduler

import m3u8_To_MP4
from utils.log import Log
from utils.MySqlUtil import MySqlUtil


def download_m3u8(download_path='./mp4/'):
    log = Log().getlog()
    try:
        # initialize the database helper
        movie_config = MySqlUtil("movie")

        # fetch one movie record that has not been processed yet
        movie_message = movie_config.get_one(
            'SELECT * FROM `movie` WHERE is_ok=0 LIMIT 1')
        if not movie_message or len(movie_message) < 3:  # validate the row
            log.info("No pending movie record found.")
            return

        id, name, url = movie_message[0], movie_message[1], movie_message[2]

        # build the target file path
        file_path = Path(download_path).joinpath(f"{name}.mp4")

        # mark the record as taken up front so a rerun does not pick it again;
        # id comes from the database itself, but see the parameterized variant below
        sql = f'UPDATE `movie`.`movie` SET `is_ok` = 1 WHERE `id` = {id}'
        movie_config.update(sql=sql)

        log.info(f"Download task started: downloading {name}...")

        # download the m3u8 stream and convert it to MP4
        m3u8_To_MP4.multithread_download(url, file_path=str(file_path))
        log.info(f"Successfully downloaded and converted {name} to {file_path}.")
    except Exception as e:
        log.error(f"Error during download: {e}")


if __name__ == '__main__':
    download_m3u8()
    # str_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    # sch = BlockingScheduler(timezone='Asia/Shanghai')
    # sch.add_job(download_m3u8, 'cron', minute='*/2')
    # sch.start()
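For reference, the status update above can avoid string interpolation entirely, since MySqlUtil.insert_parameter already forwards bound parameters to pymysql; a minimal sketch:

# Sketch: the same status update with a bound parameter instead of an f-string.
sql = 'UPDATE `movie`.`movie` SET `is_ok` = 1 WHERE `id` = %s'
movie_config.insert_parameter(sql, (id,))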

6
requirements.txt Normal file

@ -0,0 +1,6 @@
iso8601>=0.1.14
m3u8>=0.9.0
pycryptodome>=3.10.1
pandas
pymysql
apscheduler

26
utils/LoadConfig.py Normal file

@ -0,0 +1,26 @@
import json
import os

from utils.log import Log

log = Log()

# project root is the parent directory of utils/
current_directory = os.path.dirname(os.path.abspath(__file__))
project_root_path = os.path.dirname(current_directory)


def loadconfig(config_key):
    json_file = os.path.join(project_root_path, "config", "config.json")
    with open(json_file, encoding='utf-8') as f:
        cfg = json.load(f)[config_key]
    log.info('loading config section: ' + config_key)
    log.info(cfg)  # note: this also logs the credentials
    return cfg


if __name__ == '__main__':
    yz_mysql_config = loadconfig('yz_mysql')
    print(yz_mysql_config['host'])

159
utils/MySqlUtil.py Normal file

@ -0,0 +1,159 @@
#!/usr/bin/python
# -*- coding:utf-8 -*-
import platform

import pandas as pd
import pymysql

from utils.LoadConfig import loadconfig
from utils.log import Log

log = Log()


class MySQLError(Exception):
    def __init__(self, message):
        self.message = message


class MySqlUtil:
    """MySQL helper"""
    db = None
    cursor = None

    @staticmethod
    def get_section(db_name):
        """Pick the config section based on the operating system."""
        platform_ = platform.system()
        if platform_ == "Windows" or platform_ == "Darwin":
            section = db_name + '_test'
        else:
            section = db_name + '_pro'
        log.info("MYSQL environment is {}".format(section))
        return section

    def __init__(self, db_name):
        # section is currently only logged; config.json ships a single
        # yz_mysql section, which is loaded below
        section = self.get_section(db_name)

        # load the connection settings
        conf = loadconfig('yz_mysql')
        self.host = conf.get('host')
        self.port = int(conf.get('port'))
        self.userName = conf.get('userName')
        self.password = conf.get('password')
        self.dbName = conf.get("dbName")
        self.charset = conf.get('charsets', 'UTF8').lower()
        mysql_config = {"host": self.host,
                        "port": self.port,
                        "user": self.userName,
                        "passwd": self.password,
                        "db": self.dbName}
        log.info(mysql_config)

    # open a connection
    def get_con(self):
        """Create the connection and cursor."""
        self.db = pymysql.Connect(
            host=self.host,
            port=self.port,
            user=self.userName,
            passwd=self.password,
            db=self.dbName,
            charset=self.charset
        )
        self.cursor = self.db.cursor()

    # close the connection
    def close(self):
        self.cursor.close()
        self.db.close()

    # fetch a single row
    def get_one(self, sql):
        res = None
        try:
            self.get_con()
            self.cursor.execute(sql)
            res = self.cursor.fetchone()
            self.close()
            log.info("query: " + sql)
        except Exception as e:
            log.error("query failed! " + str(e))
        return res

    # fetch all rows
    def get_all(self, sql):
        res = None
        try:
            self.get_con()
            self.cursor.execute(sql)
            res = self.cursor.fetchall()
            self.close()
        except Exception as e:
            log.error("query failed! " + str(e))
        return res

    def get_data_frame_sql(self, sql, columns):
        self.get_con()
        self.cursor.execute(sql)
        res = self.cursor.fetchall()
        df = pd.DataFrame(res, columns=columns)
        self.close()
        return df

    # execute one parameterized statement
    def insert_parameter(self, sql, parameter):
        count = 0
        try:
            self.get_con()
            count = self.cursor.execute(sql, parameter)
            self.db.commit()
            self.close()
        except Exception as e:
            log.error("operation failed! " + str(e))
            if self.db:
                self.db.rollback()
        return count

    # execute one parameterized statement for many rows
    def insert_parameters(self, sql, parameters):
        count = 0
        try:
            self.get_con()
            count = self.cursor.executemany(sql, parameters)
            self.db.commit()
            self.close()
        except Exception as e:
            log.error("operation failed! " + str(e))
            if self.db:
                self.db.rollback()
        return count

    # execute a raw statement
    def __insert(self, sql):
        count = 0
        try:
            self.get_con()
            log.info('executing sql:\r\n' + sql)
            count = self.cursor.execute(sql)
            self.db.commit()
            self.close()
            log.info('sql executed')
        except Exception as e:
            if self.db:
                self.db.rollback()
            raise MySQLError("mysql operation failed") from e
        return count

    # insert data
    def save(self, sql):
        return self.__insert(sql)

    # update data
    def update(self, sql):
        return self.__insert(sql)

    # delete data
    def delete(self, sql):
        return self.__insert(sql)


if __name__ == '__main__':
    print('')
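Usage sketch, assuming the yz_mysql section from config/config.json and a reachable MySQL instance; the table and values are illustrative only:

from utils.MySqlUtil import MySqlUtil

db = MySqlUtil('movie')
row = db.get_one('SELECT * FROM `movie` WHERE is_ok=0 LIMIT 1')
db.insert_parameter(
    'INSERT INTO `movie` (`name`, `url`) VALUES (%s, %s)',
    ('demo', 'https://example.com/index.m3u8'))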

69
utils/log.py Normal file

@ -0,0 +1,69 @@
import logging
import os
from datetime import datetime

# module-level log directory: <project root>/logs
cur_path = os.path.dirname(os.path.realpath(__file__))
log_path = os.path.join(os.path.dirname(cur_path), 'logs')


class Log:
    def __init__(self, logger_name='my_logger'):
        self.logger = logging.getLogger(logger_name)
        if self.logger.hasHandlers():
            self.logger.handlers.clear()
        self.logger.setLevel(logging.INFO)
        if not os.path.exists(log_path):
            os.makedirs(log_path)
        self.update_log_file()

    def update_log_file(self):
        """(Re)attach handlers pointing at today's log file."""
        current_date = datetime.now().strftime("%Y_%m_%d")
        self.log_name = os.path.join(log_path, f'{current_date}.log')
        for handler in self.logger.handlers[:]:
            self.logger.removeHandler(handler)
        fh = logging.FileHandler(self.log_name, 'a', encoding='utf-8')
        fh.setLevel(logging.INFO)
        ch = logging.StreamHandler()
        ch.setLevel(logging.INFO)
        formatter = logging.Formatter(
            '[%(asctime)s] %(filename)s line:%(lineno)d [%(levelname)s] %(message)s',
            datefmt="%Y-%m-%d %H:%M:%S"
        )
        fh.setFormatter(formatter)
        ch.setFormatter(formatter)
        self.logger.addHandler(fh)
        self.logger.addHandler(ch)

    def getlog(self):
        # roll over to a new file when the date has changed
        current_date = datetime.now().strftime("%Y_%m_%d")
        log_date = os.path.basename(self.log_name).split('.')[0]
        if current_date != log_date:
            self.update_log_file()
        return self.logger

    def info(self, msg, *args, **kwargs):
        logger = self.getlog()
        logger.info(msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        logger = self.getlog()
        logger.error(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        logger = self.getlog()
        logger.warning(msg, *args, **kwargs)


if __name__ == "__main__":
    log = Log()
    log.info("---- test start ----")
    log.error("steps 1, 2, 3")
    log.warning("---- test end ----")
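The manual date check in getlog() can also be handled by the standard library's rotating handler; a sketch of that alternative (the file path is an illustrative assumption):

import logging
from logging.handlers import TimedRotatingFileHandler

# rotates at midnight, keeping one week of files, with no manual date check
handler = TimedRotatingFileHandler('logs/app.log', when='midnight',
                                   backupCount=7, encoding='utf-8')
handler.setFormatter(logging.Formatter(
    '[%(asctime)s] %(filename)s line:%(lineno)d [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'))
logger = logging.getLogger('my_logger')
logger.setLevel(logging.INFO)
logger.addHandler(handler)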