#!/bin/python3 import re import argparse import logging import asyncio from typing import Union, Optional from enum import Enum, auto from aiohttp import ClientSession from pathlib import Path from xml.etree import ElementTree as ET logger = logging.getLogger(__name__) ns = {'': 'http://maven.apache.org/POM/4.0.0'} ET.register_namespace('', ns['']) output_dir = Path('output') baseurl = 'https://search.maven.org' base_pom_path = Path('poms') mirrors = [ "https://repo.maven.apache.org/maven2", "https://repo1.maven.org/maven2", "https://oss.sonatype.org/content/repositories/snapshots", "https://packages.confluent.io/maven", "https://registry.quarkus.io/maven", "https://plugins.gradle.org/m2", ] java_version = 11 num_workers = 50 queue: 'PackageQueue' session: ClientSession class RequiresPackage(Exception): def __init__(self, package: 'Package'): super().__init__(f'Requires package {package}') self.package = package class PackageState(Enum): PENDING = auto() DONE = auto() class PackageQueue: def __init__(self): self.lock = asyncio.Lock() self.queue = asyncio.Queue() self.packages: dict[str, tuple['Package', 'PackageState']] = {} async def put(self, package: 'Package', timeout: int | None = None) -> None: key = str(package) async with self.lock: if key not in self.packages: logger.debug(f'{key}: Added to list') self.packages[key] = (package, PackageState.PENDING) else: return await asyncio.wait_for(self.queue.put(package), timeout) async def requeue(self, package: 'Package', timeout: int | None = None) -> None: key = str(package) async with self.lock: exists = key in self.packages if exists: await asyncio.wait_for(self.queue.put(package), timeout) self.queue.task_done() async def get(self, timeout: int | None = None) -> 'Package': while True: package = await asyncio.wait_for(self.queue.get(), timeout) key = str(package) async with self.lock: if key in self.packages: state = self.packages[key][1] match state: case PackageState.DONE: logger.debug(f'{package}: Already downloaded. Skipping.') case PackageState.PENDING: return package case _: logger.warning(f'{package}: Unknown state {state}') else: logger.warning(f'{package}: Package is in queue but not in package list') async def get_done(self, package: Union['Package', str]) -> Optional['Package']: key = str(package) async with self.lock: if key in self.packages: pack, state = self.packages[key] if state == PackageState.DONE: return pack return None async def done(self, package: 'Package') -> None: key = str(package) async with self.lock: if key not in self.packages: logger.warning(f'{package}: Unknown package marked as done') self.packages[key] = (package, PackageState.DONE) self.queue.task_done() async def set_version(self, package): key = str(package) if key.count(':') == 2: a, b, c = key.split(':') x = ':'.join([a, b, '----']) else: raise RuntimeError(f'Malformed package specifier {package}') async def join(self) -> None: await self.queue.join() def empty(self) -> bool: return self.queue.empty() class MavenVersion: qualifiers = ( ('alpha', 'a'), ('beta', 'b'), ('milestone', 'mc'), ('rc'), ('sp'), ('ga'), ('final'), ) def __init__(self, version: str): self.raw = version self.parts = [x for x in re.split(r'[\.-]|((?<=\d)(?=\w))', version) if x] def __str__(self) -> str: return self.raw def _compare(self, other): match other: case MavenVersion(): pass case str(): other = MavenVersion(other) case _: return False for x, y in zip(self.parts, other.parts): if x == y: return 0 elif x.isnumeric(): if y.isnumeric(): if int(x) > int(y): return 1 elif int(x) < int(y): return -1 else: return 1 elif y.isnumeric(): return -1 else: def qualifier_index(qualifier: str) -> int | None: for i, q in enumerate(self.qualifiers): for alias in q: if alias == qualifier: return i return None xi = qualifier_index(x) yi = qualifier_index(y) if xi is not None and yi is not None: if xi < yi: return 1 elif xi > yi: return -1 raise RuntimeError(f"Can't compare qualifier {x} and {y} from version {self} and {other}") return 0 def __eq__(self, other) -> bool: try: return self._compare(other) == 0 except RuntimeError: return False def __ne__(self, other) -> bool: try: return self._compare(other) != 0 except RuntimeError: return True def __gt__(self, other) -> bool: return self._compare(other) == 1 def __ge__(self, other) -> bool: return self._compare(other) in [0, 1] def __lt__(self, other) -> bool: return self._compare(other) == -1 def __le__(self, other) -> bool: return self._compare(other) in [-1, 0] class MavenVersionRange: def __init__(self, range_str: str): self.range_str = range_str def is_version_in_range(self, version: str) -> bool: is_match = False for _, low_bracket, content, high_bracket in re.findall(r'((\(|\[)([^\(\[\]\)]+)(\]|\)))', self.range_str): bounds = content.split(',') if len(bounds) == 2: low_bound, high_bound = [ MavenVersion(bound) if bound else None for bound in bounds ] elif len(bounds) == 1: low_bound = high_bound = MavenVersion(bounds[0]) if bounds[0] else None else: raise RuntimeError(f'Invalid version range {self.range_str}') match low_bracket: case '(': if low_bound is not None: if low_bound >= self: continue case '[': if low_bound is not None: if low_bound > self: continue else: raise RuntimeError(f"Error in version range {self.range_str}. [ can't be unbounded") match high_bracket: case ')': if high_bound is not None: if high_bound <= self: continue case ']': if high_bound is not None: if high_bound > self: continue else: raise RuntimeError(f"Error in version range {self.range_str}. ] can't be unbounded") is_match = True return is_match def get_maven_version_or_range(version_or_range: str) -> MavenVersion | MavenVersionRange: if re.fullmatch(r'([^\(\[\]\)]+)', version_or_range): return MavenVersion(version_or_range) else: return MavenVersionRange(version_or_range) class PackagePOM: def __init__(self, package: 'Package', pom: str): logger.debug(f'{package}: Parsing POM') self.package = package self.raw_root = ET.fromstring(pom) async def parse_pom(self): parent_tag = self.raw_root.find('parent', ns) parent = await self._package_from_xml(parent_tag) if parent_tag else None if parent is not None: if (parent := await queue.get_done(parent)) is not None: logger.debug(f'{self.package}: Using parent {parent}') self.parent = parent else: logger.debug(f'{self.package}: Requires parent {parent}') raise RequiresPackage(parent) else: self.parent = None if (packaging := self.raw_root.find('packaging', ns)) is not None: self.packaging = await self._format_with_props(packaging.text) else: self.packaging = 'jar' self.is_bom = self.packaging == 'pom' logger.debug(f'{self.package}: POM parsed') async def get_property(self, prop: str) -> str | None: match prop: case 'project.groupId' | 'pom.groupId': return self.package.groupId case 'project.artifactId' | 'pom.artifactId': return self.package.artifactId case 'project.version' | 'pom.version' | 'version': return str(self.package.version) case 'java.version': return str(java_version) case _: elem = self.raw_root.find(f'.//properties/{prop}', ns) if elem is not None: return await self._format_with_props(elem.text) elif self.parent: parent_pom = self.parent.pom if parent_pom is not None: parent_prop = prop.replace('parent.', '') if prop == parent_prop: logger.debug(f'{self.package}: Passing property {prop} to parent as {parent_prop}') else: logger.debug(f'{self.package}: Passing property {prop} to parent') return await parent_pom.get_property(parent_prop) else: logger.warning(f'{self.package}: Parent {self.parent} does not have a pom file') import pdb; pdb.set_trace() return None else: return None async def _format_with_props(self, text): arr = re.split(r'\$\{([^\}]*)\}', text) for i in range(1, len(arr), 2): prop = arr[i] value = await self.get_property(prop) if value is None: logger.warning(f'{self.package}: Property {prop} not found. Defaulting to an empty string') value = '' arr[i] = value logger.debug(f'{self.package}: Replacing property "{prop}" with "{value}"') return ''.join(arr) async def _package_from_xml(self, dep: ET.Element): package = Package( *[ await self._format_with_props( elem.text or '' if (elem := dep.find(tag, ns)) is not None else '' ) for tag in [ 'groupId', 'artifactId', 'version', ] ] ) return package @property async def dependency_management(self) -> list['Package']: dependencies: list[Package] = [] for dep in self.raw_root.find('dependencyManagement/dependencies', ns) or []: package = await self._package_from_xml(dep) logger.debug(f'{self.package}: Adding dependency management {package}') dependencies.append(package) return dependencies @property async def dependencies(self) -> list['Package']: dependencies: list[Package] = [] for dep in self.raw_root.find('dependencies', ns) or []: package = await self._package_from_xml(dep) logger.debug(f'{self.package}: Adding dependency {package}') dependencies.append(package) return dependencies class Package: _verified: bool = False pom: PackagePOM | None = None groupId: str artifactId: str version: MavenVersion | None version_range: MavenVersionRange | None def __init__(self, groupId: str, artifactId: str, version: str = ''): self.groupId = groupId.strip() self.artifactId = artifactId.strip() match (val := get_maven_version_or_range(version.strip())): case MavenVersion(): self.version = val self.version_range = None case MavenVersionRange(): self.version = None self.version_range = val if not self.groupId or not self.artifactId: logger.warning(f'{self}: groupId or artifactId is empty') def __str__(self) -> str: return f'{self.groupId or "----"}:{self.artifactId or "----"}:{self.version or "----"}' def __eq__(self, other) -> bool: return ( self.groupId == other.groupId and self.artifactId == other.artifactId and self.version == other.version ) def __hash__(self) -> int: return hash((self.groupId, self.artifactId, self.version)) @property def package_dir_path(self): group_path = self.groupId.replace(".", "/") return f'{group_path}/{self.artifactId}' @property def version_dir_path(self): return self.package_dir_path + f'/{self.version}' @property def base_filename(self): return f'{self.artifactId}-{self.version}' async def download_file_bytes(self, extension) -> bytes | None: filepath = f'{self.version_dir_path}/{self.base_filename}.{extension}' for mirror in mirrors: pom_url = f'{mirror}/{filepath}' logger.debug(f'{self}: Downloading {extension} from {pom_url}') async with session.get(pom_url) as response: if response.status == 200: logger.debug(f'{self}: {extension} downloaded') return await response.read() else: logger.debug(f'{self}: HTTP error {response.status} from mirror {mirror}') else: logger.warning(f'{self}: File download of {extension} failed for all mirrors') return None async def download_file_text(self, extension: str) -> str | None: b = await self.download_file_bytes(extension) if b is not None: return b.decode('utf-8') else: return None async def download_all(self, out_dir: Path) -> bool: basedir = out_dir / self.version_dir_path basedir.mkdir(exist_ok=True, parents=True) basepath = basedir / self.base_filename pom = self.pom if not pom: return False match pom.packaging: case 'pom': return True case 'jar' | 'maven-plugin' | 'eclipse-plugin' | 'bundle': logger.debug(f'{self}: Downloading JAR') return True # TODO: Remove to test JAR download # TODO: Handle checksums jar = await self.download_file_bytes('jar') if jar: with basepath.with_suffix('.jar').open('wb') as f: f.write(jar) return True else: logger.warning(f'{self}: JAR not found') return False case _: logger.warning(f'{self}: Unknown packaging {pom.packaging}') return False return False async def fetch_pom(self) -> None: if self.pom is not None: await self.pom.parse_pom() return if self.version is None: await self._query_maven() xml = await self.download_file_text('pom') if xml is not None: self.pom = PackagePOM(self, xml) await self.pom.parse_pom() @property def _urlquery(self) -> str: q = f'g:{self.groupId}+AND+a:{self.artifactId}' if self.version is not None: q += f'+AND+v:{self.version}' return q async def _query_maven(self) -> None: for mirror in mirrors: url = f'{mirror}/{self.package_dir_path}/maven-metadata.xml' logger.debug(f'{self}: Querying maven metadata at url {url}') async with session.get(url) as response: if response.status == 200: message = await response.text() xml = ET.fromstring(message) version_tags = xml.findall('./versioning/versions/version') versions = [MavenVersion(version.text) for version in version_tags if version.text is not None] logger.debug(f'{self}: Query successful') if not self.version: elem = xml.find('./versioning/latest') if elem is not None and (text := elem.text): self._verified = True self.version = MavenVersion(text) logger.debug(f'{self}: Using newest version {self.version}') break else: logger.debug(f'{self}: No latest version marked in metadata') if not versions: logger.warning(f'{self}: No versions available in metadata') else: self._verified = True self.version = max(versions) logger.debug(f'{self}: Using version {self.version}') break elif self.version is not None: if self.version in versions: self._verified = True break elif self.version_range is not None: valid_versions = [v for v in versions if self.version_range.matches(v)] if valid_versions: self.version = max(valid_versions) self._verified = True logger.debug(f'{self}: Picked version {self.version} from range {self.version_range}') break else: logger.debug(f"{self}: Available versions '{versions}' doesn't match range {self.version_range}") continue else: raise RuntimeError(f'{self}: Unknown version type {self.version}') else: self._verified = False logger.debug(f'{self}: Package not found in mirror') else: logger.warning(f'{self}: Package not found') async def verify(self) -> bool: if not self._verified: await self._query_maven() return self._verified async def load_package_list(list_path: Path) -> None: logger.info(f'Parsing {list_path}') with list_path.open('r') as f: for line in f.readlines(): sections = line.strip().split(':') if len(sections) < 2 or len(sections) > 3: logger.warning(f'Invalid package format "{line}". It should be "groupID:artifactID" or "groupID:artifactID:version"') continue package = Package( sections[0], sections[1], sections[2] if len(sections) == 3 else '', ) await package.verify() await queue.put(package, timeout=10) async def download(package: Package) -> None: if await package.verify(): await package.fetch_pom() success = await package.download_all(output_dir) if success: logger.info(f'{package}: Downloaded') else: logger.warning(f'{package}: Download failed') return pom = package.pom if pom is not None: deps = await pom.dependencies depman = await pom.dependency_management for dep in [*deps, *depman]: await queue.put(dep) else: logger.warning(f'{package}: POM not found. Skipping') else: logger.warning(f"{package}: Can't verify package") async def file_parse_worker(id: str) -> None: while True: try: package = await queue.get(timeout=5) except asyncio.TimeoutError: logger.error(f'Worker {id} timed out waiting for queue. Stopping worker') break try: await asyncio.wait_for(package.fetch_pom(), timeout=60) except asyncio.TimeoutError as e: logger.error(f'{package}: Timeout out waiting for download', exc_info=e) except RequiresPackage as e: await queue.put(e.package) await queue.requeue(package) except Exception as e: logger.error(f'{package}:', exc_info=e) finally: await queue.done(package) async def download_worker(id: str) -> None: while True: try: package = await queue.get(timeout=5) except asyncio.TimeoutError: if queue.empty(): logger.error(f'Worker {id} timed out waiting for empty queue. Stopping worker') else: logger.error(f'Worker {id} timed out waiting for queue lock. Stopping worker') break try: if package: await asyncio.wait_for(download(package), timeout=60) except asyncio.TimeoutError as e: logger.error(f'{package}: Timeout out waiting for download', exc_info=e) except RequiresPackage as e: await queue.put(e.package) await queue.requeue(package) except Exception as e: logger.error(f'{package}:', exc_info=e) finally: await queue.done(package) async def parse_list_tasks() -> None: #tasks = [] #logger.debug(f'Starting {num_workers} download workers') #for i in range(num_workers): # tasks.append( # asyncio.create_task( # file_parse_worker(str(i)) # ) # ) await load_package_list(Path('package-list.txt')) #exceptions = await asyncio.gather(*tasks, return_exceptions=True) #for e in exceptions: #logger.debug('Worker exception', exc_info=e) async def download_tasks() -> None: tasks = [] logger.debug(f'Starting {num_workers} download workers') for i in range(num_workers): tasks.append( asyncio.create_task( download_worker(str(i)) ) ) try: await queue.join() logger.debug('Queue is empty. Cancelling workers') except asyncio.CancelledError: logger.info('Tasks cancelled') for task in tasks: task.cancel() exceptions = await asyncio.gather(*tasks, return_exceptions=True) for e in exceptions: logger.debug(f'Worker exception {e}', exc_info=e) async def main() -> None: global queue global session queue = PackageQueue() session = ClientSession() try: await parse_list_tasks() await download_tasks() finally: await session.close() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('-w', '--workers', type=int, default=num_workers) parser.add_argument('-v', '--verbose', dest='verbosity', action='count', default=0) args = parser.parse_args() if args.verbosity == 0: log_level = 'WARNING' elif args.verbosity == 1: log_level = 'INFO' else: log_level = 'DEBUG' logging.basicConfig(level=log_level) num_workers = args.workers asyncio.run(main())