|
- #!/bin/python3
-
- import re
- import argparse
- import logging
- import asyncio
- from typing import Union, Optional
- from enum import Enum, auto
- from aiohttp import ClientSession
- from pathlib import Path
- from xml.etree import ElementTree as ET
-
# Module-wide logger; its level is set by the command-line entry point.
logger = logging.getLogger(__name__)

# POM elements live in the Maven 4.0.0 namespace. It is used as the default
# namespace for every ElementTree lookup in this module.
ns = {'': 'http://maven.apache.org/POM/4.0.0'}
ET.register_namespace('', ns[''])

# Filesystem locations.
output_dir = Path('output')
base_pom_path = Path('poms')

# Remote locations. Mirrors are tried in the order listed.
baseurl = 'https://search.maven.org'
mirrors = [
    'https://repo.maven.apache.org/maven2',
    'https://repo1.maven.org/maven2',
    'https://oss.sonatype.org/content/repositories/snapshots',
    'https://packages.confluent.io/maven',
    'https://registry.quarkus.io/maven',
    'https://plugins.gradle.org/m2',
]

# Value substituted for the ``java.version`` POM property.
java_version = 11
# Default size of the worker pool; overridden by ``--workers``.
num_workers = 50

# Shared runtime state, assigned in main().
queue: 'PackageQueue'
session: ClientSession
-
-
class RequiresPackage(Exception):
    """Signal that another package must be processed before this one.

    The missing package is exposed as ``package`` so the caller can queue it.
    """

    def __init__(self, package: 'Package'):
        message = f'Requires package {package}'
        super().__init__(message)
        self.package = package
-
-
class PackageState(Enum):
    """Processing state of a package tracked by the queue."""

    PENDING = 1  # queued, not processed yet
    DONE = 2     # processing finished
-
-
- class PackageQueue:
- def __init__(self):
- self.lock = asyncio.Lock()
- self.queue = asyncio.Queue()
- self.packages: dict[str, tuple['Package', 'PackageState']] = {}
-
- async def put(self, package: 'Package', timeout: int | None = None) -> None:
- key = str(package)
- async with self.lock:
- if key not in self.packages:
- logger.debug(f'{key}: Added to list')
- self.packages[key] = (package, PackageState.PENDING)
- else:
- return
-
- await asyncio.wait_for(self.queue.put(package), timeout)
-
- async def requeue(self, package: 'Package', timeout: int | None = None) -> None:
- key = str(package)
- async with self.lock:
- exists = key in self.packages
-
- if exists:
- await asyncio.wait_for(self.queue.put(package), timeout)
- self.queue.task_done()
-
- async def get(self, timeout: int | None = None) -> 'Package':
- while True:
- package = await asyncio.wait_for(self.queue.get(), timeout)
- key = str(package)
-
- async with self.lock:
- if key in self.packages:
- state = self.packages[key][1]
- match state:
- case PackageState.DONE:
- logger.debug(f'{package}: Already downloaded. Skipping.')
- case PackageState.PENDING:
- return package
- case _:
- logger.warning(f'{package}: Unknown state {state}')
- else:
- logger.warning(f'{package}: Package is in queue but not in package list')
-
- async def get_done(self, package: Union['Package', str]) -> Optional['Package']:
- key = str(package)
-
- async with self.lock:
- if key in self.packages:
- pack, state = self.packages[key]
- if state == PackageState.DONE:
- return pack
-
- return None
-
- async def done(self, package: 'Package') -> None:
- key = str(package)
- async with self.lock:
- if key not in self.packages:
- logger.warning(f'{package}: Unknown package marked as done')
-
- self.packages[key] = (package, PackageState.DONE)
-
- self.queue.task_done()
-
- async def set_version(self, package):
- key = str(package)
-
- if key.count(':') == 2:
- a, b, c = key.split(':')
- x = ':'.join([a, b, '----'])
- else:
- raise RuntimeError(f'Malformed package specifier {package}')
-
- async def join(self) -> None:
- await self.queue.join()
-
- def empty(self) -> bool:
- return self.queue.empty()
-
-
-
- class MavenVersion:
- qualifiers = (
- ('alpha', 'a'),
- ('beta', 'b'),
- ('milestone', 'mc'),
- ('rc'),
- ('sp'),
- ('ga'),
- ('final'),
- )
-
- def __init__(self, version: str):
- self.raw = version
- self.parts = [x for x in re.split(r'[\.-]|((?<=\d)(?=\w))', version) if x]
-
- def __str__(self) -> str:
- return self.raw
-
- def _compare(self, other):
- match other:
- case MavenVersion():
- pass
- case str():
- other = MavenVersion(other)
- case _:
- return False
-
- for x, y in zip(self.parts, other.parts):
- if x == y:
- return 0
- elif x.isnumeric():
- if y.isnumeric():
- if int(x) > int(y):
- return 1
- elif int(x) < int(y):
- return -1
- else:
- return 1
- elif y.isnumeric():
- return -1
- else:
- def qualifier_index(qualifier: str) -> int | None:
- for i, q in enumerate(self.qualifiers):
- for alias in q:
- if alias == qualifier:
- return i
-
- return None
-
- xi = qualifier_index(x)
- yi = qualifier_index(y)
-
- if xi is not None and yi is not None:
- if xi < yi:
- return 1
- elif xi > yi:
- return -1
-
- raise RuntimeError(f"Can't compare qualifier {x} and {y} from version {self} and {other}")
-
- return 0
-
- def __eq__(self, other) -> bool:
- try:
- return self._compare(other) == 0
- except RuntimeError:
- return False
-
- def __ne__(self, other) -> bool:
- try:
- return self._compare(other) != 0
- except RuntimeError:
- return True
-
- def __gt__(self, other) -> bool:
- return self._compare(other) == 1
-
- def __ge__(self, other) -> bool:
- return self._compare(other) in [0, 1]
-
- def __lt__(self, other) -> bool:
- return self._compare(other) == -1
-
- def __le__(self, other) -> bool:
- return self._compare(other) in [-1, 0]
-
-
class MavenVersionRange:
    """A Maven version range such as ``[1.0,2.0)`` or ``(,1.0],[1.2,)``.

    A range is a sequence of bracketed intervals. ``[``/``]`` make a bound
    inclusive, ``(``/``)`` exclusive; an empty bound is unbounded and must be
    written with a round bracket. A single value (``[1.0]``) is used as both
    the lower and the upper bound.
    """

    # One bracketed interval: opening bracket, content, closing bracket.
    _interval = re.compile(r'(\(|\[)([^\(\[\]\)]+)(\]|\))')

    def __init__(self, range_str: str):
        self.range_str = range_str

    def __str__(self) -> str:
        return self.range_str

    def is_version_in_range(self, version: 'str | MavenVersion') -> bool:
        """Tell whether ``version`` falls inside any interval of the range.

        Args:
            version: the version to test, as a string or a ``MavenVersion``.

        Returns:
            True if at least one interval contains the version. A range
            string without any bracketed interval matches nothing.

        Raises:
            RuntimeError: if an interval has more than two bounds or uses a
                square bracket on an unbounded side.
        """
        intervals = self._interval.findall(self.range_str)
        if not intervals:
            return False

        candidate = MavenVersion(version) if isinstance(version, str) else version
        is_match = False

        for low_bracket, content, high_bracket in intervals:
            bounds = [bound.strip() for bound in content.split(',')]

            if len(bounds) == 2:
                low_bound, high_bound = [
                    MavenVersion(bound) if bound else None
                    for bound in bounds
                ]
            elif len(bounds) == 1:
                low_bound = high_bound = MavenVersion(bounds[0]) if bounds[0] else None
            else:
                raise RuntimeError(f'Invalid version range {self.range_str}')

            if low_bound is None:
                if low_bracket == '[':
                    raise RuntimeError(f"Error in version range {self.range_str}. [ can't be unbounded")
            elif candidate < low_bound or (low_bracket == '(' and candidate == low_bound):
                # Below the lower bound, or sitting on an exclusive one.
                continue

            if high_bound is None:
                if high_bracket == ']':
                    raise RuntimeError(f"Error in version range {self.range_str}. ] can't be unbounded")
            elif candidate > high_bound or (high_bracket == ')' and candidate == high_bound):
                # Above the upper bound, or sitting on an exclusive one.
                continue

            # Keep scanning after a match so every interval is still validated.
            is_match = True

        return is_match

    def matches(self, version: 'str | MavenVersion') -> bool:
        """Alias of ``is_version_in_range``."""
        return self.is_version_in_range(version)
-
-
def get_maven_version_or_range(version_or_range: str) -> MavenVersion | MavenVersionRange:
    """Wrap a version specifier in the matching class.

    A non-empty string without any bracket is a plain version; everything
    else, including the empty string, is handled as a version range.
    """
    is_plain_version = re.fullmatch(r'([^\(\[\]\)]+)', version_or_range) is not None

    if not is_plain_version:
        return MavenVersionRange(version_or_range)

    return MavenVersion(version_or_range)
-
-
- class PackagePOM:
- def __init__(self, package: 'Package', pom: str):
- logger.debug(f'{package}: Parsing POM')
- self.package = package
- self.raw_root = ET.fromstring(pom)
-
- async def parse_pom(self):
- parent_tag = self.raw_root.find('parent', ns)
- parent = await self._package_from_xml(parent_tag) if parent_tag else None
-
- if parent is not None:
- if (parent := await queue.get_done(parent)) is not None:
- logger.debug(f'{self.package}: Using parent {parent}')
- self.parent = parent
- else:
- logger.debug(f'{self.package}: Requires parent {parent}')
- raise RequiresPackage(parent)
- else:
- self.parent = None
-
- if (packaging := self.raw_root.find('packaging', ns)) is not None:
- self.packaging = await self._format_with_props(packaging.text)
- else:
- self.packaging = 'jar'
-
- self.is_bom = self.packaging == 'pom'
-
- logger.debug(f'{self.package}: POM parsed')
-
- async def get_property(self, prop: str) -> str | None:
- match prop:
- case 'project.groupId' | 'pom.groupId':
- return self.package.groupId
- case 'project.artifactId' | 'pom.artifactId':
- return self.package.artifactId
- case 'project.version' | 'pom.version' | 'version':
- return str(self.package.version)
- case 'java.version':
- return str(java_version)
- case _:
- elem = self.raw_root.find(f'.//properties/{prop}', ns)
- if elem is not None:
- return await self._format_with_props(elem.text)
- elif self.parent:
- parent_pom = self.parent.pom
- if parent_pom is not None:
- parent_prop = prop.replace('parent.', '')
- if prop == parent_prop:
- logger.debug(f'{self.package}: Passing property {prop} to parent as {parent_prop}')
- else:
- logger.debug(f'{self.package}: Passing property {prop} to parent')
-
- return await parent_pom.get_property(parent_prop)
- else:
- logger.warning(f'{self.package}: Parent {self.parent} does not have a pom file')
- import pdb; pdb.set_trace()
- return None
- else:
- return None
-
- async def _format_with_props(self, text):
- arr = re.split(r'\$\{([^\}]*)\}', text)
-
- for i in range(1, len(arr), 2):
- prop = arr[i]
- value = await self.get_property(prop)
-
- if value is None:
- logger.warning(f'{self.package}: Property {prop} not found. Defaulting to an empty string')
- value = ''
-
- arr[i] = value
- logger.debug(f'{self.package}: Replacing property "{prop}" with "{value}"')
-
- return ''.join(arr)
-
- async def _package_from_xml(self, dep: ET.Element):
- package = Package(
- *[
- await self._format_with_props(
- elem.text or '' if (elem := dep.find(tag, ns)) is not None else ''
- )
-
- for tag in [
- 'groupId',
- 'artifactId',
- 'version',
- ]
- ]
- )
- return package
-
- @property
- async def dependency_management(self) -> list['Package']:
- dependencies: list[Package] = []
-
- for dep in self.raw_root.find('dependencyManagement/dependencies', ns) or []:
- package = await self._package_from_xml(dep)
- logger.debug(f'{self.package}: Adding dependency management {package}')
- dependencies.append(package)
-
- return dependencies
-
- @property
- async def dependencies(self) -> list['Package']:
- dependencies: list[Package] = []
-
- for dep in self.raw_root.find('dependencies', ns) or []:
- package = await self._package_from_xml(dep)
- logger.debug(f'{self.package}: Adding dependency {package}')
- dependencies.append(package)
-
- return dependencies
-
-
- class Package:
- _verified: bool = False
- pom: PackagePOM | None = None
- groupId: str
- artifactId: str
- version: MavenVersion | None
- version_range: MavenVersionRange | None
-
- def __init__(self, groupId: str, artifactId: str, version: str = ''):
- self.groupId = groupId.strip()
- self.artifactId = artifactId.strip()
-
- match (val := get_maven_version_or_range(version.strip())):
- case MavenVersion():
- self.version = val
- self.version_range = None
- case MavenVersionRange():
- self.version = None
- self.version_range = val
-
- if not self.groupId or not self.artifactId:
- logger.warning(f'{self}: groupId or artifactId is empty')
-
- def __str__(self) -> str:
- return f'{self.groupId or "----"}:{self.artifactId or "----"}:{self.version or "----"}'
-
- def __eq__(self, other) -> bool:
- return (
- self.groupId == other.groupId
- and self.artifactId == other.artifactId
- and self.version == other.version
- )
-
- def __hash__(self) -> int:
- return hash((self.groupId, self.artifactId, self.version))
-
- @property
- def package_dir_path(self):
- group_path = self.groupId.replace(".", "/")
- return f'{group_path}/{self.artifactId}'
-
- @property
- def version_dir_path(self):
- return self.package_dir_path + f'/{self.version}'
-
- @property
- def base_filename(self):
- return f'{self.artifactId}-{self.version}'
-
- async def download_file_bytes(self, extension) -> bytes | None:
- filepath = f'{self.version_dir_path}/{self.base_filename}.{extension}'
-
- for mirror in mirrors:
- pom_url = f'{mirror}/{filepath}'
- logger.debug(f'{self}: Downloading {extension} from {pom_url}')
-
- async with session.get(pom_url) as response:
- if response.status == 200:
- logger.debug(f'{self}: {extension} downloaded')
- return await response.read()
- else:
- logger.debug(f'{self}: HTTP error {response.status} from mirror {mirror}')
- else:
- logger.warning(f'{self}: File download of {extension} failed for all mirrors')
- return None
-
- async def download_file_text(self, extension: str) -> str | None:
- b = await self.download_file_bytes(extension)
- if b is not None:
- return b.decode('utf-8')
- else:
- return None
-
- async def download_all(self, out_dir: Path) -> bool:
- basedir = out_dir / self.version_dir_path
- basedir.mkdir(exist_ok=True, parents=True)
- basepath = basedir / self.base_filename
-
- pom = self.pom
- if not pom:
- return False
-
- match pom.packaging:
- case 'pom':
- return True
- case 'jar' | 'maven-plugin' | 'eclipse-plugin' | 'bundle':
- logger.debug(f'{self}: Downloading JAR')
- return True # TODO: Remove to test JAR download
- # TODO: Handle checksums
- jar = await self.download_file_bytes('jar')
- if jar:
- with basepath.with_suffix('.jar').open('wb') as f:
- f.write(jar)
- return True
- else:
- logger.warning(f'{self}: JAR not found')
- return False
- case _:
- logger.warning(f'{self}: Unknown packaging {pom.packaging}')
- return False
-
- return False
-
- async def fetch_pom(self) -> None:
- if self.pom is not None:
- await self.pom.parse_pom()
- return
-
- if self.version is None:
- await self._query_maven()
-
- xml = await self.download_file_text('pom')
- if xml is not None:
- self.pom = PackagePOM(self, xml)
- await self.pom.parse_pom()
-
- @property
- def _urlquery(self) -> str:
- q = f'g:{self.groupId}+AND+a:{self.artifactId}'
-
- if self.version is not None:
- q += f'+AND+v:{self.version}'
-
- return q
-
- async def _query_maven(self) -> None:
- for mirror in mirrors:
- url = f'{mirror}/{self.package_dir_path}/maven-metadata.xml'
- logger.debug(f'{self}: Querying maven metadata at url {url}')
-
- async with session.get(url) as response:
- if response.status == 200:
- message = await response.text()
- xml = ET.fromstring(message)
- version_tags = xml.findall('./versioning/versions/version')
- versions = [MavenVersion(version.text) for version in version_tags if version.text is not None]
-
- logger.debug(f'{self}: Query successful')
-
- if not self.version:
- elem = xml.find('./versioning/latest')
- if elem is not None and (text := elem.text):
- self._verified = True
- self.version = MavenVersion(text)
- logger.debug(f'{self}: Using newest version {self.version}')
- break
- else:
- logger.debug(f'{self}: No latest version marked in metadata')
-
- if not versions:
- logger.warning(f'{self}: No versions available in metadata')
- else:
- self._verified = True
- self.version = max(versions)
- logger.debug(f'{self}: Using version {self.version}')
- break
- elif self.version is not None:
- if self.version in versions:
- self._verified = True
- break
- elif self.version_range is not None:
- valid_versions = [v for v in versions if self.version_range.matches(v)]
-
- if valid_versions:
- self.version = max(valid_versions)
- self._verified = True
- logger.debug(f'{self}: Picked version {self.version} from range {self.version_range}')
- break
- else:
- logger.debug(f"{self}: Available versions '{versions}' doesn't match range {self.version_range}")
- continue
- else:
- raise RuntimeError(f'{self}: Unknown version type {self.version}')
- else:
- self._verified = False
- logger.debug(f'{self}: Package not found in mirror')
- else:
- logger.warning(f'{self}: Package not found')
-
- async def verify(self) -> bool:
- if not self._verified:
- await self._query_maven()
- return self._verified
-
-
async def load_package_list(list_path: Path) -> None:
    """Read package specifiers from ``list_path`` and queue them.

    Each non-blank line must be ``groupID:artifactID`` or
    ``groupID:artifactID:version``. Blank lines are skipped silently; other
    malformed lines are reported and skipped. Every package is verified
    against the mirrors before it is queued.
    """
    logger.info(f'Parsing {list_path}')

    with list_path.open('r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue

            sections = line.split(':')

            if len(sections) < 2 or len(sections) > 3:
                logger.warning(f'Invalid package format "{line}". It should be "groupID:artifactID" or "groupID:artifactID:version"')
                continue

            package = Package(
                sections[0],
                sections[1],
                sections[2] if len(sections) == 3 else '',
            )

            await package.verify()
            await queue.put(package, timeout=10)
-
-
async def download(package: Package) -> None:
    """Verify and download ``package``, then queue the packages its POM lists.

    Both the POM's dependencies and its dependency-management entries are
    queued. Nothing is queued if verification or the download fails.
    """
    if not await package.verify():
        logger.warning(f"{package}: Can't verify package")
        return

    await package.fetch_pom()

    if not await package.download_all(output_dir):
        logger.warning(f'{package}: Download failed')
        return

    logger.info(f'{package}: Downloaded')

    pom = package.pom
    if pom is None:
        logger.warning(f'{package}: POM not found. Skipping')
        return

    direct = await pom.dependencies
    managed = await pom.dependency_management

    for dep in (*direct, *managed):
        await queue.put(dep)
-
-
async def file_parse_worker(id: str) -> None:
    """Worker loop that fetches and parses the POM of each queued package.

    The worker stops once the queue has delivered nothing for 5 seconds.
    A package whose POM needs another package first is put back in the queue
    behind that package and stays pending; every other outcome marks the
    package as done.
    """
    while True:
        try:
            package = await queue.get(timeout=5)
        except asyncio.TimeoutError:
            logger.error(f'Worker {id} timed out waiting for queue. Stopping worker')
            break

        # requeue() already balances the queue's task counter, and the
        # package must stay pending to be picked up again, so done() must
        # not run for a requeued package.
        requeued = False

        try:
            await asyncio.wait_for(package.fetch_pom(), timeout=60)
        except asyncio.TimeoutError as e:
            logger.error(f'{package}: Timeout out waiting for download', exc_info=e)
        except RequiresPackage as e:
            if e.package is None:
                # Nothing to wait for; retrying could never succeed.
                logger.error(f'{package}: Required package is unknown. Giving up')
            else:
                await queue.put(e.package)
                await queue.requeue(package)
                requeued = True
        except Exception as e:
            logger.error(f'{package}:', exc_info=e)
        finally:
            if not requeued:
                await queue.done(package)
-
-
async def download_worker(id: str) -> None:
    """Worker loop that downloads each queued package.

    The worker stops once the queue has delivered nothing for 5 seconds.
    A package that needs another package first is put back in the queue
    behind that package and stays pending; every other outcome marks the
    package as done.
    """
    while True:
        try:
            package = await queue.get(timeout=5)
        except asyncio.TimeoutError:
            if queue.empty():
                logger.error(f'Worker {id} timed out waiting for empty queue. Stopping worker')
            else:
                logger.error(f'Worker {id} timed out waiting for queue lock. Stopping worker')
            break

        # requeue() already balances the queue's task counter, and the
        # package must stay pending to be picked up again, so done() must
        # not run for a requeued package.
        requeued = False

        try:
            if package:
                await asyncio.wait_for(download(package), timeout=60)
        except asyncio.TimeoutError as e:
            logger.error(f'{package}: Timeout out waiting for download', exc_info=e)
        except RequiresPackage as e:
            if e.package is None:
                # Nothing to wait for; retrying could never succeed.
                logger.error(f'{package}: Required package is unknown. Giving up')
            else:
                await queue.put(e.package)
                await queue.requeue(package)
                requeued = True
        except Exception as e:
            logger.error(f'{package}:', exc_info=e)
        finally:
            if not requeued:
                await queue.done(package)
-
-
async def parse_list_tasks() -> None:
    """Queue every package listed in ``package-list.txt``.

    NOTE: only the list is loaded here. A pool of ``file_parse_worker``
    tasks is not started; that code path is currently disabled.
    """
    list_path = Path('package-list.txt')
    await load_package_list(list_path)
-
-
async def download_tasks() -> None:
    """Run ``num_workers`` download workers until the queue is drained.

    Once the queue reports that all work is accounted for (or this coroutine
    is cancelled), the workers are cancelled and awaited. Exceptions raised
    by workers are logged at debug level.
    """
    logger.debug(f'Starting {num_workers} download workers')
    tasks = [
        asyncio.create_task(download_worker(str(i)))
        for i in range(num_workers)
    ]

    try:
        await queue.join()
        logger.debug('Queue is empty. Cancelling workers')
    except asyncio.CancelledError:
        logger.info('Tasks cancelled')

    for task in tasks:
        task.cancel()

    results = await asyncio.gather(*tasks, return_exceptions=True)

    for result in results:
        # gather() also returns None for workers that finished normally and
        # CancelledError (a BaseException) for the ones cancelled above;
        # neither is a failure worth reporting.
        if isinstance(result, Exception):
            logger.debug(f'Worker exception {result}', exc_info=result)
-
-
async def main() -> None:
    """Create the shared queue and HTTP session, then load and download.

    The session is closed when the work finishes, whether or not it failed.
    """
    global queue
    global session

    queue = PackageQueue()

    async with ClientSession() as session:
        await parse_list_tasks()
        await download_tasks()
-
-
if __name__ == '__main__':
    # Command-line entry point: parse options, set up logging, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument('-w', '--workers', type=int, default=num_workers)
    parser.add_argument('-v', '--verbose', dest='verbosity', action='count', default=0)
    args = parser.parse_args()

    # No -v: warnings only; -v: info; -vv or more: debug.
    levels = ('WARNING', 'INFO', 'DEBUG')
    log_level = levels[min(args.verbosity, len(levels) - 1)]
    logging.basicConfig(level=log_level)

    num_workers = args.workers

    asyncio.run(main())
|