| @@ -4,16 +4,18 @@ import re | |||
| import argparse | |||
| import logging | |||
| import asyncio | |||
| import subprocess | |||
| import copy | |||
| import aiohttp | |||
| from typing import Union, Optional | |||
| from enum import Enum, auto | |||
| from aiohttp import ClientSession | |||
| from pathlib import Path | |||
| from xml.etree import ElementTree as ET | |||
| logger = logging.getLogger(__name__) | |||
| ns = {'': 'http://maven.apache.org/POM/4.0.0'} | |||
| ET.register_namespace('', ns['']) | |||
| output_dir = Path('output') | |||
| baseurl = 'https://search.maven.org' | |||
| base_pom_path = Path('poms') | |||
| mirrors = [ | |||
| @@ -25,82 +27,330 @@ mirrors = [ | |||
| "https://plugins.gradle.org/m2", | |||
| ] | |||
| done: set[str] = set() | |||
| done_lock = asyncio.Lock() | |||
| java_version = 11 | |||
| num_workers = 50 | |||
| queue: 'PackageQueue' | |||
| session: ClientSession | |||
| class RequiresPackage(Exception): | |||
| def __init__(self, package: 'Package'): | |||
| super().__init__(f'Requires package {package}') | |||
| self.package = package | |||
| class PackageState(Enum): | |||
| PENDING = auto() | |||
| DONE = auto() | |||
| class PackageQueue: | |||
| def __init__(self): | |||
| self.lock = asyncio.Lock() | |||
| self.queue = asyncio.Queue() | |||
| self.packages: dict[str, tuple['Package', 'PackageState']] = {} | |||
| async def put(self, package: 'Package', timeout: int | None = None) -> None: | |||
| key = str(package) | |||
| async with self.lock: | |||
| if key not in self.packages: | |||
| logger.debug(f'{key}: Added to list') | |||
| self.packages[key] = (package, PackageState.PENDING) | |||
| else: | |||
| return | |||
| await asyncio.wait_for(self.queue.put(package), timeout) | |||
| async def requeue(self, package: 'Package', timeout: int | None = None) -> None: | |||
| key = str(package) | |||
| async with self.lock: | |||
| exists = key in self.packages | |||
| if exists: | |||
| await asyncio.wait_for(self.queue.put(package), timeout) | |||
| self.queue.task_done() | |||
| async def get(self, timeout: int | None = None) -> 'Package': | |||
| while True: | |||
| package = await asyncio.wait_for(self.queue.get(), timeout) | |||
| key = str(package) | |||
| async with self.lock: | |||
| if key in self.packages: | |||
| state = self.packages[key][1] | |||
| match state: | |||
| case PackageState.DONE: | |||
| logger.debug(f'{package}: Already downloaded. Skipping.') | |||
| case PackageState.PENDING: | |||
| return package | |||
| case _: | |||
| logger.warning(f'{package}: Unknown state {state}') | |||
| else: | |||
| logger.warning(f'{package}: Package is in queue but not in package list') | |||
| async def get_done(self, package: Union['Package', str]) -> Optional['Package']: | |||
| key = str(package) | |||
| async with self.lock: | |||
| if key in self.packages: | |||
| pack, state = self.packages[key] | |||
| if state == PackageState.DONE: | |||
| return pack | |||
| return None | |||
| async def done(self, package: 'Package') -> None: | |||
| key = str(package) | |||
| async with self.lock: | |||
| if key not in self.packages: | |||
| logger.warning(f'{package}: Unknown package marked as done') | |||
| self.packages[key] = (package, PackageState.DONE) | |||
| self.queue.task_done() | |||
| async def set_version(self, package): | |||
| key = str(package) | |||
| if key.count(':') == 2: | |||
| a, b, c = key.split(':') | |||
| x = ':'.join([a, b, '----']) | |||
| else: | |||
| raise RuntimeError(f'Malformed package specifier {package}') | |||
| async def join(self) -> None: | |||
| await self.queue.join() | |||
| def empty(self) -> bool: | |||
| return self.queue.empty() | |||
| class MavenVersion: | |||
| qualifiers = ( | |||
| ('alpha', 'a'), | |||
| ('beta', 'b'), | |||
| ('milestone', 'mc'), | |||
| ('rc'), | |||
| ('sp'), | |||
| ('ga'), | |||
| ('final'), | |||
| ) | |||
| def __init__(self, version: str): | |||
| self.raw = version | |||
| self.parts = [x for x in re.split(r'[\.-]|((?<=\d)(?=\w))', version) if x] | |||
| def __str__(self) -> str: | |||
| return self.raw | |||
| def _compare(self, other): | |||
| match other: | |||
| case MavenVersion(): | |||
| pass | |||
| case str(): | |||
| other = MavenVersion(other) | |||
| case _: | |||
| return False | |||
| for x, y in zip(self.parts, other.parts): | |||
| if x == y: | |||
| return 0 | |||
| elif x.isnumeric(): | |||
| if y.isnumeric(): | |||
| if int(x) > int(y): | |||
| return 1 | |||
| elif int(x) < int(y): | |||
| return -1 | |||
| else: | |||
| return 1 | |||
| elif y.isnumeric(): | |||
| return -1 | |||
| else: | |||
| def qualifier_index(qualifier: str) -> int | None: | |||
| for i, q in enumerate(self.qualifiers): | |||
| for alias in q: | |||
| if alias == qualifier: | |||
| return i | |||
| return None | |||
| xi = qualifier_index(x) | |||
| yi = qualifier_index(y) | |||
| if xi is not None and yi is not None: | |||
| if xi < yi: | |||
| return 1 | |||
| elif xi > yi: | |||
| return -1 | |||
| raise RuntimeError(f"Can't compare qualifier {x} and {y} from version {self} and {other}") | |||
| return 0 | |||
| def __eq__(self, other) -> bool: | |||
| try: | |||
| return self._compare(other) == 0 | |||
| except RuntimeError: | |||
| return False | |||
| def __ne__(self, other) -> bool: | |||
| try: | |||
| return self._compare(other) != 0 | |||
| except RuntimeError: | |||
| return True | |||
| def __gt__(self, other) -> bool: | |||
| return self._compare(other) == 1 | |||
| def __ge__(self, other) -> bool: | |||
| return self._compare(other) in [0, 1] | |||
| def __lt__(self, other) -> bool: | |||
| return self._compare(other) == -1 | |||
| def __le__(self, other) -> bool: | |||
| return self._compare(other) in [-1, 0] | |||
| class MavenVersionRange: | |||
| def __init__(self, range_str: str): | |||
| self.range_str = range_str | |||
| def is_version_in_range(self, version: str) -> bool: | |||
| is_match = False | |||
| for _, low_bracket, content, high_bracket in re.findall(r'((\(|\[)([^\(\[\]\)]+)(\]|\)))', self.range_str): | |||
| bounds = content.split(',') | |||
| if len(bounds) == 2: | |||
| low_bound, high_bound = [ | |||
| MavenVersion(bound) if bound else None | |||
| for bound in bounds | |||
| ] | |||
| elif len(bounds) == 1: | |||
| low_bound = high_bound = MavenVersion(bounds[0]) if bounds[0] else None | |||
| else: | |||
| raise RuntimeError(f'Invalid version range {self.range_str}') | |||
| match low_bracket: | |||
| case '(': | |||
| if low_bound is not None: | |||
| if low_bound >= self: | |||
| continue | |||
| case '[': | |||
| if low_bound is not None: | |||
| if low_bound > self: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error in version range {self.range_str}. [ can't be unbounded") | |||
| match high_bracket: | |||
| case ')': | |||
| if high_bound is not None: | |||
| if high_bound <= self: | |||
| continue | |||
| case ']': | |||
| if high_bound is not None: | |||
| if high_bound > self: | |||
| continue | |||
| else: | |||
| raise RuntimeError(f"Error in version range {self.range_str}. ] can't be unbounded") | |||
| is_match = True | |||
| return is_match | |||
| def get_maven_version_or_range(version_or_range: str) -> MavenVersion | MavenVersionRange: | |||
| if re.fullmatch(r'([^\(\[\]\)]+)', version_or_range): | |||
| return MavenVersion(version_or_range) | |||
| else: | |||
| return MavenVersionRange(version_or_range) | |||
| class PackagePOM: | |||
| def __init__(self, package: 'Package', pom: str): | |||
| logger.debug(f'{package}: Parsing POM') | |||
| self.package = package | |||
| self.raw_root = ET.fromstring(pom) | |||
| async def parse_pom(self): | |||
| parent_tag = self.raw_root.find('parent', ns) | |||
| parent = await self._package_from_xml(parent_tag) if parent_tag else None | |||
| if parent is not None: | |||
| if (parent := await queue.get_done(parent)) is not None: | |||
| logger.debug(f'{self.package}: Using parent {parent}') | |||
| self.parent = parent | |||
| else: | |||
| logger.debug(f'{self.package}: Requires parent {parent}') | |||
| raise RequiresPackage(parent) | |||
| else: | |||
| self.parent = None | |||
| if (packaging := self.raw_root.find('packaging', ns)) is not None: | |||
| self.packaging = packaging.text | |||
| self.packaging = await self._format_with_props(packaging.text) | |||
| else: | |||
| self.packaging = '??' | |||
| self.packaging = 'jar' | |||
| self.is_bom = self.packaging == 'pom' | |||
| if self.packaging == 'pom': | |||
| root_copy = copy.deepcopy(self.raw_root) | |||
| depman = root_copy.find('dependencyManagement', ns) | |||
| if depman is not None: | |||
| root_copy.extend(depman.findall('*')) | |||
| root_copy.remove(depman) | |||
| self.generated_root = root_copy | |||
| else: | |||
| self.generated_root = ET.fromstring( | |||
| f""" | |||
| <project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 | |||
| https://maven.apache.org/xsd/maven-4.0.0.xsd" | |||
| xmlns="http://maven.apache.org/POM/4.0.0" | |||
| xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"> | |||
| <modelVersion>4.0.0</modelVersion> | |||
| <groupId>tmp.{package.groupId}</groupId> | |||
| <artifactId>placeholder-{package.artifactId}</artifactId> | |||
| <version>{package.version}</version> | |||
| <name>Package {package.artifactId}</name> | |||
| <dependencies> | |||
| <dependency> | |||
| <groupId>{package.groupId}</groupId> | |||
| <artifactId>{package.artifactId}</artifactId> | |||
| <version>{package.version}</version> | |||
| </dependency> | |||
| </dependencies> | |||
| </project> | |||
| """ | |||
| ) | |||
| logger.debug(f'{self.package}: POM parsed') | |||
| async def get_property(self, prop: str) -> str | None: | |||
| match prop: | |||
| case 'project.groupId' | 'pom.groupId': | |||
| return self.package.groupId | |||
| case 'project.artifactId' | 'pom.artifactId': | |||
| return self.package.artifactId | |||
| case 'project.version' | 'pom.version' | 'version': | |||
| return str(self.package.version) | |||
| case 'java.version': | |||
| return str(java_version) | |||
| case _: | |||
| elem = self.raw_root.find(f'.//properties/{prop}', ns) | |||
| if elem is not None: | |||
| return await self._format_with_props(elem.text) | |||
| elif self.parent: | |||
| parent_pom = self.parent.pom | |||
| if parent_pom is not None: | |||
| parent_prop = prop.replace('parent.', '') | |||
| if prop == parent_prop: | |||
| logger.debug(f'{self.package}: Passing property {prop} to parent as {parent_prop}') | |||
| else: | |||
| logger.debug(f'{self.package}: Passing property {prop} to parent') | |||
| return await parent_pom.get_property(parent_prop) | |||
| else: | |||
| logger.warning(f'{self.package}: Parent {self.parent} does not have a pom file') | |||
| import pdb; pdb.set_trace() | |||
| return None | |||
| else: | |||
| return None | |||
| logger.debug(f'{package}: POM parsed') | |||
| async def _format_with_props(self, text): | |||
| arr = re.split(r'\$\{([^\}]*)\}', text) | |||
| def write(self, f): | |||
| tree = ET.ElementTree(self.generated_root) | |||
| ET.indent(tree) | |||
| tree.write(f) | |||
| for i in range(1, len(arr), 2): | |||
| prop = arr[i] | |||
| value = await self.get_property(prop) | |||
| def get_property(self, prop: str): | |||
| elem = self.raw_root.find(f'.//properties/{prop}', ns) | |||
| if elem is not None: | |||
| return elem.text | |||
| else: | |||
| return None | |||
| if value is None: | |||
| logger.warning(f'{self.package}: Property {prop} not found. Defaulting to an empty string') | |||
| value = '' | |||
| def _package_from_xml_dep(self, dep: ET.Element): | |||
| def prop_replace(match): | |||
| prop = match.group(1) | |||
| value = self.get_property(match.group(1)) | |||
| logger.debug(f'Replacing property {prop} with {value}') | |||
| return value | |||
| arr[i] = value | |||
| logger.debug(f'{self.package}: Replacing property "{prop}" with "{value}"') | |||
| return Package( | |||
| return ''.join(arr) | |||
| async def _package_from_xml(self, dep: ET.Element): | |||
| package = Package( | |||
| *[ | |||
| re.sub( | |||
| r'\$\{([^\}]*)\}', | |||
| prop_replace, | |||
| elem.text or '' if (elem := dep.find(tag, ns)) is not None else '', | |||
| await self._format_with_props( | |||
| elem.text or '' if (elem := dep.find(tag, ns)) is not None else '' | |||
| ) | |||
| for tag in [ | |||
| @@ -110,29 +360,56 @@ class PackagePOM: | |||
| ] | |||
| ] | |||
| ) | |||
| return package | |||
| @property | |||
| def dependency_management(self) -> list['Package']: | |||
| async def dependency_management(self) -> list['Package']: | |||
| dependencies: list[Package] = [] | |||
| for dep in self.raw_root.find('dependencyManagement/dependencies', ns) or []: | |||
| package = self._package_from_xml_dep(dep) | |||
| package = await self._package_from_xml(dep) | |||
| logger.debug(f'{self.package}: Adding dependency management {package}') | |||
| dependencies.append(package) | |||
| return dependencies | |||
| @property | |||
| async def dependencies(self) -> list['Package']: | |||
| dependencies: list[Package] = [] | |||
| for dep in self.raw_root.find('dependencies', ns) or []: | |||
| package = await self._package_from_xml(dep) | |||
| logger.debug(f'{self.package}: Adding dependency {package}') | |||
| dependencies.append(package) | |||
| return dependencies | |||
| class Package: | |||
| _pom: PackagePOM | None = None | |||
| _verified: bool = False | |||
| def __init__(self, groupId: str, artifactId: str, version: str = None): | |||
| self.groupId = groupId | |||
| self.artifactId = artifactId | |||
| self.version = version | |||
| pom: PackagePOM | None = None | |||
| groupId: str | |||
| artifactId: str | |||
| version: MavenVersion | None | |||
| version_range: MavenVersionRange | None | |||
| def __init__(self, groupId: str, artifactId: str, version: str = ''): | |||
| self.groupId = groupId.strip() | |||
| self.artifactId = artifactId.strip() | |||
| match (val := get_maven_version_or_range(version.strip())): | |||
| case MavenVersion(): | |||
| self.version = val | |||
| self.version_range = None | |||
| case MavenVersionRange(): | |||
| self.version = None | |||
| self.version_range = val | |||
| if not self.groupId or not self.artifactId: | |||
| logger.warning(f'{self}: groupId or artifactId is empty') | |||
| def __str__(self) -> str: | |||
| return f'{self.groupId}:{self.artifactId}:{self.version or "----"}' | |||
| return f'{self.groupId or "----"}:{self.artifactId or "----"}:{self.version or "----"}' | |||
| def __eq__(self, other) -> bool: | |||
| return ( | |||
| @@ -145,44 +422,84 @@ class Package: | |||
| return hash((self.groupId, self.artifactId, self.version)) | |||
| @property | |||
| def dir_path(self): | |||
| def package_dir_path(self): | |||
| group_path = self.groupId.replace(".", "/") | |||
| return f'{group_path}/{self.artifactId}/{self.version}' | |||
| return f'{group_path}/{self.artifactId}' | |||
| @property | |||
| def version_dir_path(self): | |||
| return self.package_dir_path + f'/{self.version}' | |||
| @property | |||
| def base_filename(self): | |||
| return f'{self.artifactId}-{self.version}' | |||
| async def download_file(self, extension): | |||
| filepath = f'{self.dir_path}/{self.base_filename}.{extension}' | |||
| async def download_file_bytes(self, extension) -> bytes | None: | |||
| filepath = f'{self.version_dir_path}/{self.base_filename}.{extension}' | |||
| async with aiohttp.ClientSession() as session: | |||
| for mirror in mirrors: | |||
| pom_url = f'{mirror}/{filepath}' | |||
| logger.debug(f'{self}: Downloading {extension} from {pom_url}') | |||
| for mirror in mirrors: | |||
| pom_url = f'{mirror}/{filepath}' | |||
| logger.debug(f'{self}: Downloading {extension} from {pom_url}') | |||
| async with session.get(pom_url) as response: | |||
| if response.status == 200: | |||
| logger.debug(f'{self}: {extension} downloaded') | |||
| return await response.text() | |||
| break | |||
| else: | |||
| logger.debug(f'{self}: HTTP error {response.status} from mirror {mirror}') | |||
| else: | |||
| logger.warning(f'{self}: File download of {extension} failed for all mirrors') | |||
| return None | |||
| async with session.get(pom_url) as response: | |||
| if response.status == 200: | |||
| logger.debug(f'{self}: {extension} downloaded') | |||
| return await response.read() | |||
| else: | |||
| logger.debug(f'{self}: HTTP error {response.status} from mirror {mirror}') | |||
| else: | |||
| logger.warning(f'{self}: File download of {extension} failed for all mirrors') | |||
| return None | |||
| @property | |||
| async def pom(self) -> PackagePOM: | |||
| if self._pom is not None: | |||
| return self._pom | |||
| async def download_file_text(self, extension: str) -> str | None: | |||
| b = await self.download_file_bytes(extension) | |||
| if b is not None: | |||
| return b.decode('utf-8') | |||
| else: | |||
| return None | |||
| async def download_all(self, out_dir: Path) -> bool: | |||
| basedir = out_dir / self.version_dir_path | |||
| basedir.mkdir(exist_ok=True, parents=True) | |||
| basepath = basedir / self.base_filename | |||
| pom = self.pom | |||
| if not pom: | |||
| return False | |||
| match pom.packaging: | |||
| case 'pom': | |||
| return True | |||
| case 'jar' | 'maven-plugin' | 'eclipse-plugin' | 'bundle': | |||
| logger.debug(f'{self}: Downloading JAR') | |||
| return True # TODO: Remove to test JAR download | |||
| # TODO: Handle checksums | |||
| jar = await self.download_file_bytes('jar') | |||
| if jar: | |||
| with basepath.with_suffix('.jar').open('wb') as f: | |||
| f.write(jar) | |||
| return True | |||
| else: | |||
| logger.warning(f'{self}: JAR not found') | |||
| return False | |||
| case _: | |||
| logger.warning(f'{self}: Unknown packaging {pom.packaging}') | |||
| return False | |||
| return False | |||
| async def fetch_pom(self) -> None: | |||
| if self.pom is not None: | |||
| await self.pom.parse_pom() | |||
| return | |||
| if self.version is None: | |||
| await self._query_maven() | |||
| self._pom = PackagePOM(self, await self.download_file('pom')) | |||
| return self._pom | |||
| xml = await self.download_file_text('pom') | |||
| if xml is not None: | |||
| self.pom = PackagePOM(self, xml) | |||
| await self.pom.parse_pom() | |||
| @property | |||
| def _urlquery(self) -> str: | |||
| @@ -194,28 +511,58 @@ class Package: | |||
| return q | |||
| async def _query_maven(self) -> None: | |||
| url = f'{baseurl}/solrsearch/select?q={self._urlquery}&rows=1&wt=json' | |||
| logger.debug(f'{self}: Querying maven at url {url}') | |||
| for mirror in mirrors: | |||
| url = f'{mirror}/{self.package_dir_path}/maven-metadata.xml' | |||
| logger.debug(f'{self}: Querying maven metadata at url {url}') | |||
| async with aiohttp.ClientSession() as session: | |||
| async with session.get(url) as response: | |||
| if response.status == 200: | |||
| message = await response.json() | |||
| num = message['response']['numFound'] | |||
| if num: | |||
| logger.debug(f'{self}: Query successful') | |||
| self._verified = True | |||
| if self.version is None: | |||
| version = message['response']['docs'][0]['latestVersion'] | |||
| logger.debug(f'{self}: Using newest version {version}') | |||
| self.version = version | |||
| message = await response.text() | |||
| xml = ET.fromstring(message) | |||
| version_tags = xml.findall('./versioning/versions/version') | |||
| versions = [MavenVersion(version.text) for version in version_tags if version.text is not None] | |||
| logger.debug(f'{self}: Query successful') | |||
| if not self.version: | |||
| elem = xml.find('./versioning/latest') | |||
| if elem is not None and (text := elem.text): | |||
| self._verified = True | |||
| self.version = MavenVersion(text) | |||
| logger.debug(f'{self}: Using newest version {self.version}') | |||
| break | |||
| else: | |||
| logger.debug(f'{self}: No latest version marked in metadata') | |||
| if not versions: | |||
| logger.warning(f'{self}: No versions available in metadata') | |||
| else: | |||
| self._verified = True | |||
| self.version = max(versions) | |||
| logger.debug(f'{self}: Using version {self.version}') | |||
| break | |||
| elif self.version is not None: | |||
| if self.version in versions: | |||
| self._verified = True | |||
| break | |||
| elif self.version_range is not None: | |||
| valid_versions = [v for v in versions if self.version_range.matches(v)] | |||
| if valid_versions: | |||
| self.version = max(valid_versions) | |||
| self._verified = True | |||
| logger.debug(f'{self}: Picked version {self.version} from range {self.version_range}') | |||
| break | |||
| else: | |||
| logger.debug(f"{self}: Available versions '{versions}' doesn't match range {self.version_range}") | |||
| continue | |||
| else: | |||
| logger.warning(f'{self}: No matching packages found') | |||
| self._verified = False | |||
| raise RuntimeError(f'{self}: Unknown version type {self.version}') | |||
| else: | |||
| self._verified = False | |||
| logger.warning(f'{self}: HTTP error {response.status} downloading pom') | |||
| logger.debug(f'{self}: Package not found in mirror') | |||
| else: | |||
| logger.warning(f'{self}: Package not found') | |||
| async def verify(self) -> bool: | |||
| if not self._verified: | |||
| @@ -223,7 +570,7 @@ class Package: | |||
| return self._verified | |||
| def load_package_list(list_path: Path, queue: asyncio.Queue) -> None: | |||
| async def load_package_list(list_path: Path) -> None: | |||
| logger.info(f'Parsing {list_path}') | |||
| with list_path.open('r') as f: | |||
| @@ -237,77 +584,140 @@ def load_package_list(list_path: Path, queue: asyncio.Queue) -> None: | |||
| package = Package( | |||
| sections[0], | |||
| sections[1], | |||
| sections[2] if len(sections) == 3 else None, | |||
| sections[2] if len(sections) == 3 else '', | |||
| ) | |||
| queue.put_nowait(package) | |||
| async def download(package: Package, queue: asyncio.Queue) -> None: | |||
| async with done_lock: | |||
| skip = str(package) in done | |||
| if skip: | |||
| logger.info(f'{package}: Already downloaded. Skipping.') | |||
| elif await package.verify(): | |||
| async with done_lock: | |||
| done.add(str(package)) | |||
| pom_dir = base_pom_path / str(package) | |||
| pom_path = pom_dir / 'pom.xml' | |||
| await package.verify() | |||
| await queue.put(package, timeout=10) | |||
| pom_dir.mkdir(exist_ok=True) | |||
| pom = await package.pom | |||
| async def download(package: Package) -> None: | |||
| if await package.verify(): | |||
| await package.fetch_pom() | |||
| success = await package.download_all(output_dir) | |||
| if not pom: | |||
| if success: | |||
| logger.info(f'{package}: Downloaded') | |||
| else: | |||
| logger.warning(f'{package}: Download failed') | |||
| return | |||
| pom.write(pom_path) | |||
| logger.info(f'{package}: Downloaded') | |||
| pom = package.pom | |||
| if not pom.is_bom: | |||
| for dep in pom.dependency_management: | |||
| logger.info(f'{package}: Handling transitive dependency {dep}') | |||
| if pom is not None: | |||
| deps = await pom.dependencies | |||
| depman = await pom.dependency_management | |||
| for dep in [*deps, *depman]: | |||
| await queue.put(dep) | |||
| else: | |||
| logger.warning(f'{package}: POM not found. Skipping') | |||
| else: | |||
| logger.warning(f'{package}: Package not found. Check package name and internet connection') | |||
| logger.warning(f"{package}: Can't verify package") | |||
| async def worker(queue: asyncio.Queue) -> None: | |||
| async def file_parse_worker(id: str) -> None: | |||
| while True: | |||
| package = await queue.get() | |||
| await download(package, queue) | |||
| queue.task_done() | |||
| async def main() -> None: | |||
| queue: asyncio.Queue = asyncio.Queue() | |||
| try: | |||
| package = await queue.get(timeout=5) | |||
| except asyncio.TimeoutError: | |||
| logger.error(f'Worker {id} timed out waiting for queue. Stopping worker') | |||
| break | |||
| try: | |||
| await asyncio.wait_for(package.fetch_pom(), timeout=60) | |||
| except asyncio.TimeoutError as e: | |||
| logger.error(f'{package}: Timeout out waiting for download', exc_info=e) | |||
| except RequiresPackage as e: | |||
| await queue.put(e.package) | |||
| await queue.requeue(package) | |||
| except Exception as e: | |||
| logger.error(f'{package}:', exc_info=e) | |||
| finally: | |||
| await queue.done(package) | |||
| async def download_worker(id: str) -> None: | |||
| while True: | |||
| try: | |||
| package = await queue.get(timeout=5) | |||
| except asyncio.TimeoutError: | |||
| if queue.empty(): | |||
| logger.error(f'Worker {id} timed out waiting for empty queue. Stopping worker') | |||
| else: | |||
| logger.error(f'Worker {id} timed out waiting for queue lock. Stopping worker') | |||
| break | |||
| try: | |||
| if package: | |||
| await asyncio.wait_for(download(package), timeout=60) | |||
| except asyncio.TimeoutError as e: | |||
| logger.error(f'{package}: Timeout out waiting for download', exc_info=e) | |||
| except RequiresPackage as e: | |||
| await queue.put(e.package) | |||
| await queue.requeue(package) | |||
| except Exception as e: | |||
| logger.error(f'{package}:', exc_info=e) | |||
| finally: | |||
| await queue.done(package) | |||
| async def parse_list_tasks() -> None: | |||
| #tasks = [] | |||
| #logger.debug(f'Starting {num_workers} download workers') | |||
| #for i in range(num_workers): | |||
| # tasks.append( | |||
| # asyncio.create_task( | |||
| # file_parse_worker(str(i)) | |||
| # ) | |||
| # ) | |||
| await load_package_list(Path('package-list.txt')) | |||
| #exceptions = await asyncio.gather(*tasks, return_exceptions=True) | |||
| #for e in exceptions: | |||
| #logger.debug('Worker exception', exc_info=e) | |||
| async def download_tasks() -> None: | |||
| tasks = [] | |||
| load_package_list(Path('package-list.txt'), queue) | |||
| logger.debug(f'Starting {num_workers} workers') | |||
| logger.debug(f'Starting {num_workers} download workers') | |||
| for i in range(num_workers): | |||
| tasks.append( | |||
| asyncio.create_task( | |||
| worker(queue) | |||
| download_worker(str(i)) | |||
| ) | |||
| ) | |||
| await queue.join() | |||
| try: | |||
| await queue.join() | |||
| logger.debug('Queue is empty. Cancelling workers') | |||
| except asyncio.CancelledError: | |||
| logger.info('Tasks cancelled') | |||
| logger.debug('Queue is empty. Cancelling workers') | |||
| for task in tasks: | |||
| task.cancel() | |||
| await asyncio.gather(*tasks, return_exceptions=True) | |||
| exceptions = await asyncio.gather(*tasks, return_exceptions=True) | |||
| logger.info('Generating master POM') | |||
| subprocess.call(['sh', 'generate_master_pom.sh']) | |||
| for e in exceptions: | |||
| logger.debug(f'Worker exception {e}', exc_info=e) | |||
| logger = logging.getLogger(__name__) | |||
| async def main() -> None: | |||
| global queue | |||
| global session | |||
| queue = PackageQueue() | |||
| session = ClientSession() | |||
| try: | |||
| await parse_list_tasks() | |||
| await download_tasks() | |||
| finally: | |||
| await session.close() | |||
| if __name__ == '__main__': | |||
| parser = argparse.ArgumentParser() | |||