diff --git a/update-poms.py b/update-poms.py index 70369e4..d148d69 100755 --- a/update-poms.py +++ b/update-poms.py @@ -4,16 +4,18 @@ import re import argparse import logging import asyncio -import subprocess -import copy -import aiohttp +from typing import Union, Optional +from enum import Enum, auto +from aiohttp import ClientSession from pathlib import Path from xml.etree import ElementTree as ET +logger = logging.getLogger(__name__) ns = {'': 'http://maven.apache.org/POM/4.0.0'} ET.register_namespace('', ns['']) +output_dir = Path('output') baseurl = 'https://search.maven.org' base_pom_path = Path('poms') mirrors = [ @@ -25,82 +27,330 @@ mirrors = [ "https://plugins.gradle.org/m2", ] -done: set[str] = set() -done_lock = asyncio.Lock() +java_version = 11 num_workers = 50 +queue: 'PackageQueue' +session: ClientSession + + +class RequiresPackage(Exception): + def __init__(self, package: 'Package'): + super().__init__(f'Requires package {package}') + self.package = package + + +class PackageState(Enum): + PENDING = auto() + DONE = auto() + + +class PackageQueue: + def __init__(self): + self.lock = asyncio.Lock() + self.queue = asyncio.Queue() + self.packages: dict[str, tuple['Package', 'PackageState']] = {} + + async def put(self, package: 'Package', timeout: int | None = None) -> None: + key = str(package) + async with self.lock: + if key not in self.packages: + logger.debug(f'{key}: Added to list') + self.packages[key] = (package, PackageState.PENDING) + else: + return + + await asyncio.wait_for(self.queue.put(package), timeout) + + async def requeue(self, package: 'Package', timeout: int | None = None) -> None: + key = str(package) + async with self.lock: + exists = key in self.packages + + if exists: + await asyncio.wait_for(self.queue.put(package), timeout) + self.queue.task_done() + + async def get(self, timeout: int | None = None) -> 'Package': + while True: + package = await asyncio.wait_for(self.queue.get(), timeout) + key = str(package) + + async with self.lock: + if key in self.packages: + state = self.packages[key][1] + match state: + case PackageState.DONE: + logger.debug(f'{package}: Already downloaded. Skipping.') + case PackageState.PENDING: + return package + case _: + logger.warning(f'{package}: Unknown state {state}') + else: + logger.warning(f'{package}: Package is in queue but not in package list') + + async def get_done(self, package: Union['Package', str]) -> Optional['Package']: + key = str(package) + + async with self.lock: + if key in self.packages: + pack, state = self.packages[key] + if state == PackageState.DONE: + return pack + + return None + + async def done(self, package: 'Package') -> None: + key = str(package) + async with self.lock: + if key not in self.packages: + logger.warning(f'{package}: Unknown package marked as done') + + self.packages[key] = (package, PackageState.DONE) + + self.queue.task_done() + + async def set_version(self, package): + key = str(package) + + if key.count(':') == 2: + a, b, c = key.split(':') + x = ':'.join([a, b, '----']) + else: + raise RuntimeError(f'Malformed package specifier {package}') + + async def join(self) -> None: + await self.queue.join() + + def empty(self) -> bool: + return self.queue.empty() + + + +class MavenVersion: + qualifiers = ( + ('alpha', 'a'), + ('beta', 'b'), + ('milestone', 'mc'), + ('rc'), + ('sp'), + ('ga'), + ('final'), + ) + + def __init__(self, version: str): + self.raw = version + self.parts = [x for x in re.split(r'[\.-]|((?<=\d)(?=\w))', version) if x] + + def __str__(self) -> str: + return self.raw + + def _compare(self, other): + match other: + case MavenVersion(): + pass + case str(): + other = MavenVersion(other) + case _: + return False + + for x, y in zip(self.parts, other.parts): + if x == y: + return 0 + elif x.isnumeric(): + if y.isnumeric(): + if int(x) > int(y): + return 1 + elif int(x) < int(y): + return -1 + else: + return 1 + elif y.isnumeric(): + return -1 + else: + def qualifier_index(qualifier: str) -> int | None: + for i, q in enumerate(self.qualifiers): + for alias in q: + if alias == qualifier: + return i + + return None + + xi = qualifier_index(x) + yi = qualifier_index(y) + + if xi is not None and yi is not None: + if xi < yi: + return 1 + elif xi > yi: + return -1 + + raise RuntimeError(f"Can't compare qualifier {x} and {y} from version {self} and {other}") + + return 0 + + def __eq__(self, other) -> bool: + try: + return self._compare(other) == 0 + except RuntimeError: + return False + + def __ne__(self, other) -> bool: + try: + return self._compare(other) != 0 + except RuntimeError: + return True + + def __gt__(self, other) -> bool: + return self._compare(other) == 1 + + def __ge__(self, other) -> bool: + return self._compare(other) in [0, 1] + + def __lt__(self, other) -> bool: + return self._compare(other) == -1 + + def __le__(self, other) -> bool: + return self._compare(other) in [-1, 0] + + +class MavenVersionRange: + def __init__(self, range_str: str): + self.range_str = range_str + + def is_version_in_range(self, version: str) -> bool: + is_match = False + + for _, low_bracket, content, high_bracket in re.findall(r'((\(|\[)([^\(\[\]\)]+)(\]|\)))', self.range_str): + bounds = content.split(',') + + if len(bounds) == 2: + low_bound, high_bound = [ + MavenVersion(bound) if bound else None + for bound in bounds + ] + elif len(bounds) == 1: + low_bound = high_bound = MavenVersion(bounds[0]) if bounds[0] else None + else: + raise RuntimeError(f'Invalid version range {self.range_str}') + + match low_bracket: + case '(': + if low_bound is not None: + if low_bound >= self: + continue + case '[': + if low_bound is not None: + if low_bound > self: + continue + else: + raise RuntimeError(f"Error in version range {self.range_str}. [ can't be unbounded") + + match high_bracket: + case ')': + if high_bound is not None: + if high_bound <= self: + continue + case ']': + if high_bound is not None: + if high_bound > self: + continue + else: + raise RuntimeError(f"Error in version range {self.range_str}. ] can't be unbounded") + + is_match = True + + return is_match + + +def get_maven_version_or_range(version_or_range: str) -> MavenVersion | MavenVersionRange: + if re.fullmatch(r'([^\(\[\]\)]+)', version_or_range): + return MavenVersion(version_or_range) + else: + return MavenVersionRange(version_or_range) + class PackagePOM: def __init__(self, package: 'Package', pom: str): logger.debug(f'{package}: Parsing POM') + self.package = package self.raw_root = ET.fromstring(pom) + async def parse_pom(self): + parent_tag = self.raw_root.find('parent', ns) + parent = await self._package_from_xml(parent_tag) if parent_tag else None + + if parent is not None: + if (parent := await queue.get_done(parent)) is not None: + logger.debug(f'{self.package}: Using parent {parent}') + self.parent = parent + else: + logger.debug(f'{self.package}: Requires parent {parent}') + raise RequiresPackage(parent) + else: + self.parent = None + if (packaging := self.raw_root.find('packaging', ns)) is not None: - self.packaging = packaging.text + self.packaging = await self._format_with_props(packaging.text) else: - self.packaging = '??' + self.packaging = 'jar' self.is_bom = self.packaging == 'pom' - if self.packaging == 'pom': - root_copy = copy.deepcopy(self.raw_root) - depman = root_copy.find('dependencyManagement', ns) - if depman is not None: - root_copy.extend(depman.findall('*')) - root_copy.remove(depman) - self.generated_root = root_copy - else: - self.generated_root = ET.fromstring( - f""" - - - 4.0.0 - tmp.{package.groupId} - placeholder-{package.artifactId} - {package.version} - Package {package.artifactId} - - - - {package.groupId} - {package.artifactId} - {package.version} - - - - """ - ) + logger.debug(f'{self.package}: POM parsed') + + async def get_property(self, prop: str) -> str | None: + match prop: + case 'project.groupId' | 'pom.groupId': + return self.package.groupId + case 'project.artifactId' | 'pom.artifactId': + return self.package.artifactId + case 'project.version' | 'pom.version' | 'version': + return str(self.package.version) + case 'java.version': + return str(java_version) + case _: + elem = self.raw_root.find(f'.//properties/{prop}', ns) + if elem is not None: + return await self._format_with_props(elem.text) + elif self.parent: + parent_pom = self.parent.pom + if parent_pom is not None: + parent_prop = prop.replace('parent.', '') + if prop == parent_prop: + logger.debug(f'{self.package}: Passing property {prop} to parent as {parent_prop}') + else: + logger.debug(f'{self.package}: Passing property {prop} to parent') + + return await parent_pom.get_property(parent_prop) + else: + logger.warning(f'{self.package}: Parent {self.parent} does not have a pom file') + import pdb; pdb.set_trace() + return None + else: + return None - logger.debug(f'{package}: POM parsed') + async def _format_with_props(self, text): + arr = re.split(r'\$\{([^\}]*)\}', text) - def write(self, f): - tree = ET.ElementTree(self.generated_root) - ET.indent(tree) - tree.write(f) + for i in range(1, len(arr), 2): + prop = arr[i] + value = await self.get_property(prop) - def get_property(self, prop: str): - elem = self.raw_root.find(f'.//properties/{prop}', ns) - if elem is not None: - return elem.text - else: - return None + if value is None: + logger.warning(f'{self.package}: Property {prop} not found. Defaulting to an empty string') + value = '' - def _package_from_xml_dep(self, dep: ET.Element): - def prop_replace(match): - prop = match.group(1) - value = self.get_property(match.group(1)) - logger.debug(f'Replacing property {prop} with {value}') - return value + arr[i] = value + logger.debug(f'{self.package}: Replacing property "{prop}" with "{value}"') - return Package( + return ''.join(arr) + + async def _package_from_xml(self, dep: ET.Element): + package = Package( *[ - re.sub( - r'\$\{([^\}]*)\}', - prop_replace, - elem.text or '' if (elem := dep.find(tag, ns)) is not None else '', + await self._format_with_props( + elem.text or '' if (elem := dep.find(tag, ns)) is not None else '' ) for tag in [ @@ -110,29 +360,56 @@ class PackagePOM: ] ] ) + return package @property - def dependency_management(self) -> list['Package']: + async def dependency_management(self) -> list['Package']: dependencies: list[Package] = [] for dep in self.raw_root.find('dependencyManagement/dependencies', ns) or []: - package = self._package_from_xml_dep(dep) + package = await self._package_from_xml(dep) + logger.debug(f'{self.package}: Adding dependency management {package}') + dependencies.append(package) + + return dependencies + + @property + async def dependencies(self) -> list['Package']: + dependencies: list[Package] = [] + + for dep in self.raw_root.find('dependencies', ns) or []: + package = await self._package_from_xml(dep) + logger.debug(f'{self.package}: Adding dependency {package}') dependencies.append(package) return dependencies class Package: - _pom: PackagePOM | None = None _verified: bool = False - - def __init__(self, groupId: str, artifactId: str, version: str = None): - self.groupId = groupId - self.artifactId = artifactId - self.version = version + pom: PackagePOM | None = None + groupId: str + artifactId: str + version: MavenVersion | None + version_range: MavenVersionRange | None + + def __init__(self, groupId: str, artifactId: str, version: str = ''): + self.groupId = groupId.strip() + self.artifactId = artifactId.strip() + + match (val := get_maven_version_or_range(version.strip())): + case MavenVersion(): + self.version = val + self.version_range = None + case MavenVersionRange(): + self.version = None + self.version_range = val + + if not self.groupId or not self.artifactId: + logger.warning(f'{self}: groupId or artifactId is empty') def __str__(self) -> str: - return f'{self.groupId}:{self.artifactId}:{self.version or "----"}' + return f'{self.groupId or "----"}:{self.artifactId or "----"}:{self.version or "----"}' def __eq__(self, other) -> bool: return ( @@ -145,44 +422,84 @@ class Package: return hash((self.groupId, self.artifactId, self.version)) @property - def dir_path(self): + def package_dir_path(self): group_path = self.groupId.replace(".", "/") - return f'{group_path}/{self.artifactId}/{self.version}' + return f'{group_path}/{self.artifactId}' + + @property + def version_dir_path(self): + return self.package_dir_path + f'/{self.version}' @property def base_filename(self): return f'{self.artifactId}-{self.version}' - async def download_file(self, extension): - filepath = f'{self.dir_path}/{self.base_filename}.{extension}' + async def download_file_bytes(self, extension) -> bytes | None: + filepath = f'{self.version_dir_path}/{self.base_filename}.{extension}' - async with aiohttp.ClientSession() as session: - for mirror in mirrors: - pom_url = f'{mirror}/{filepath}' - logger.debug(f'{self}: Downloading {extension} from {pom_url}') + for mirror in mirrors: + pom_url = f'{mirror}/{filepath}' + logger.debug(f'{self}: Downloading {extension} from {pom_url}') - async with session.get(pom_url) as response: - if response.status == 200: - logger.debug(f'{self}: {extension} downloaded') - return await response.text() - break - else: - logger.debug(f'{self}: HTTP error {response.status} from mirror {mirror}') - else: - logger.warning(f'{self}: File download of {extension} failed for all mirrors') - return None + async with session.get(pom_url) as response: + if response.status == 200: + logger.debug(f'{self}: {extension} downloaded') + return await response.read() + else: + logger.debug(f'{self}: HTTP error {response.status} from mirror {mirror}') + else: + logger.warning(f'{self}: File download of {extension} failed for all mirrors') + return None - @property - async def pom(self) -> PackagePOM: - if self._pom is not None: - return self._pom + async def download_file_text(self, extension: str) -> str | None: + b = await self.download_file_bytes(extension) + if b is not None: + return b.decode('utf-8') + else: + return None + + async def download_all(self, out_dir: Path) -> bool: + basedir = out_dir / self.version_dir_path + basedir.mkdir(exist_ok=True, parents=True) + basepath = basedir / self.base_filename + + pom = self.pom + if not pom: + return False + + match pom.packaging: + case 'pom': + return True + case 'jar' | 'maven-plugin' | 'eclipse-plugin' | 'bundle': + logger.debug(f'{self}: Downloading JAR') + return True # TODO: Remove to test JAR download + # TODO: Handle checksums + jar = await self.download_file_bytes('jar') + if jar: + with basepath.with_suffix('.jar').open('wb') as f: + f.write(jar) + return True + else: + logger.warning(f'{self}: JAR not found') + return False + case _: + logger.warning(f'{self}: Unknown packaging {pom.packaging}') + return False + + return False + + async def fetch_pom(self) -> None: + if self.pom is not None: + await self.pom.parse_pom() + return if self.version is None: await self._query_maven() - self._pom = PackagePOM(self, await self.download_file('pom')) - - return self._pom + xml = await self.download_file_text('pom') + if xml is not None: + self.pom = PackagePOM(self, xml) + await self.pom.parse_pom() @property def _urlquery(self) -> str: @@ -194,28 +511,58 @@ class Package: return q async def _query_maven(self) -> None: - url = f'{baseurl}/solrsearch/select?q={self._urlquery}&rows=1&wt=json' - logger.debug(f'{self}: Querying maven at url {url}') + for mirror in mirrors: + url = f'{mirror}/{self.package_dir_path}/maven-metadata.xml' + logger.debug(f'{self}: Querying maven metadata at url {url}') - async with aiohttp.ClientSession() as session: async with session.get(url) as response: if response.status == 200: - message = await response.json() - num = message['response']['numFound'] - - if num: - logger.debug(f'{self}: Query successful') - self._verified = True - if self.version is None: - version = message['response']['docs'][0]['latestVersion'] - logger.debug(f'{self}: Using newest version {version}') - self.version = version + message = await response.text() + xml = ET.fromstring(message) + version_tags = xml.findall('./versioning/versions/version') + versions = [MavenVersion(version.text) for version in version_tags if version.text is not None] + + logger.debug(f'{self}: Query successful') + + if not self.version: + elem = xml.find('./versioning/latest') + if elem is not None and (text := elem.text): + self._verified = True + self.version = MavenVersion(text) + logger.debug(f'{self}: Using newest version {self.version}') + break + else: + logger.debug(f'{self}: No latest version marked in metadata') + + if not versions: + logger.warning(f'{self}: No versions available in metadata') + else: + self._verified = True + self.version = max(versions) + logger.debug(f'{self}: Using version {self.version}') + break + elif self.version is not None: + if self.version in versions: + self._verified = True + break + elif self.version_range is not None: + valid_versions = [v for v in versions if self.version_range.matches(v)] + + if valid_versions: + self.version = max(valid_versions) + self._verified = True + logger.debug(f'{self}: Picked version {self.version} from range {self.version_range}') + break + else: + logger.debug(f"{self}: Available versions '{versions}' doesn't match range {self.version_range}") + continue else: - logger.warning(f'{self}: No matching packages found') - self._verified = False + raise RuntimeError(f'{self}: Unknown version type {self.version}') else: self._verified = False - logger.warning(f'{self}: HTTP error {response.status} downloading pom') + logger.debug(f'{self}: Package not found in mirror') + else: + logger.warning(f'{self}: Package not found') async def verify(self) -> bool: if not self._verified: @@ -223,7 +570,7 @@ class Package: return self._verified -def load_package_list(list_path: Path, queue: asyncio.Queue) -> None: +async def load_package_list(list_path: Path) -> None: logger.info(f'Parsing {list_path}') with list_path.open('r') as f: @@ -237,77 +584,140 @@ def load_package_list(list_path: Path, queue: asyncio.Queue) -> None: package = Package( sections[0], sections[1], - sections[2] if len(sections) == 3 else None, + sections[2] if len(sections) == 3 else '', ) - queue.put_nowait(package) - - -async def download(package: Package, queue: asyncio.Queue) -> None: - async with done_lock: - skip = str(package) in done - - if skip: - logger.info(f'{package}: Already downloaded. Skipping.') - elif await package.verify(): - async with done_lock: - done.add(str(package)) - - pom_dir = base_pom_path / str(package) - pom_path = pom_dir / 'pom.xml' + await package.verify() + await queue.put(package, timeout=10) - pom_dir.mkdir(exist_ok=True) - pom = await package.pom +async def download(package: Package) -> None: + if await package.verify(): + await package.fetch_pom() + success = await package.download_all(output_dir) - if not pom: + if success: + logger.info(f'{package}: Downloaded') + else: + logger.warning(f'{package}: Download failed') return - pom.write(pom_path) - logger.info(f'{package}: Downloaded') + pom = package.pom - if not pom.is_bom: - for dep in pom.dependency_management: - logger.info(f'{package}: Handling transitive dependency {dep}') + if pom is not None: + deps = await pom.dependencies + depman = await pom.dependency_management + + for dep in [*deps, *depman]: await queue.put(dep) + else: + logger.warning(f'{package}: POM not found. Skipping') else: - logger.warning(f'{package}: Package not found. Check package name and internet connection') + logger.warning(f"{package}: Can't verify package") -async def worker(queue: asyncio.Queue) -> None: +async def file_parse_worker(id: str) -> None: while True: - package = await queue.get() - await download(package, queue) - queue.task_done() - - -async def main() -> None: - queue: asyncio.Queue = asyncio.Queue() + try: + package = await queue.get(timeout=5) + except asyncio.TimeoutError: + logger.error(f'Worker {id} timed out waiting for queue. Stopping worker') + break + + try: + await asyncio.wait_for(package.fetch_pom(), timeout=60) + except asyncio.TimeoutError as e: + logger.error(f'{package}: Timeout out waiting for download', exc_info=e) + except RequiresPackage as e: + await queue.put(e.package) + await queue.requeue(package) + except Exception as e: + logger.error(f'{package}:', exc_info=e) + finally: + await queue.done(package) + + +async def download_worker(id: str) -> None: + while True: + try: + package = await queue.get(timeout=5) + except asyncio.TimeoutError: + if queue.empty(): + logger.error(f'Worker {id} timed out waiting for empty queue. Stopping worker') + else: + logger.error(f'Worker {id} timed out waiting for queue lock. Stopping worker') + break + + try: + if package: + await asyncio.wait_for(download(package), timeout=60) + except asyncio.TimeoutError as e: + logger.error(f'{package}: Timeout out waiting for download', exc_info=e) + except RequiresPackage as e: + await queue.put(e.package) + await queue.requeue(package) + except Exception as e: + logger.error(f'{package}:', exc_info=e) + finally: + await queue.done(package) + + +async def parse_list_tasks() -> None: + #tasks = [] + + #logger.debug(f'Starting {num_workers} download workers') + #for i in range(num_workers): + # tasks.append( + # asyncio.create_task( + # file_parse_worker(str(i)) + # ) + # ) + + await load_package_list(Path('package-list.txt')) + #exceptions = await asyncio.gather(*tasks, return_exceptions=True) + + #for e in exceptions: + #logger.debug('Worker exception', exc_info=e) + + +async def download_tasks() -> None: tasks = [] - load_package_list(Path('package-list.txt'), queue) - - logger.debug(f'Starting {num_workers} workers') + logger.debug(f'Starting {num_workers} download workers') for i in range(num_workers): tasks.append( asyncio.create_task( - worker(queue) + download_worker(str(i)) ) ) - await queue.join() + try: + await queue.join() + logger.debug('Queue is empty. Cancelling workers') + except asyncio.CancelledError: + logger.info('Tasks cancelled') - logger.debug('Queue is empty. Cancelling workers') for task in tasks: task.cancel() - await asyncio.gather(*tasks, return_exceptions=True) + exceptions = await asyncio.gather(*tasks, return_exceptions=True) - logger.info('Generating master POM') - subprocess.call(['sh', 'generate_master_pom.sh']) + for e in exceptions: + logger.debug(f'Worker exception {e}', exc_info=e) -logger = logging.getLogger(__name__) +async def main() -> None: + global queue + global session + queue = PackageQueue() + session = ClientSession() + + try: + await parse_list_tasks() + await download_tasks() + finally: + await session.close() + if __name__ == '__main__': parser = argparse.ArgumentParser()