diff --git a/update-poms.py b/update-poms.py index dc8a3d5..96a6d20 100755 --- a/update-poms.py +++ b/update-poms.py @@ -3,11 +3,10 @@ import re import argparse import logging +import asyncio import subprocess -import json import copy -from urllib import request -from urllib.error import HTTPError +import aiohttp from pathlib import Path from xml.etree import ElementTree as ET @@ -15,6 +14,10 @@ from xml.etree import ElementTree as ET ET.register_namespace('', 'http://maven.apache.org/POM/4.0.0') baseurl = 'https://search.maven.org' +base_pom_path = Path('poms') +done: set[str] = set() +done_lock = asyncio.Lock() +num_workers = 50 class PackagePOM: @@ -128,7 +131,7 @@ class Package: return hash((self.groupId, self.artifactId, self.version)) @property - def pom(self) -> ET: + async def pom(self) -> ET: if self._pom is not None: return self._pom @@ -141,20 +144,14 @@ class Package: pom_url = f'{baseurl}/remotecontent?filepath={filepath}' logger.debug(f'{self}: Downloading pom from {pom_url}') - try: - response = request.urlopen(pom_url) - except HTTPError as e: - logger.warning(f'{self}: HTTP error downloading pom') - logger.debug(e) - return None - - status = response.status - if status == 200: - logger.debug(f'{self}: POM downloaded') - self._pom = PackagePOM(self, response.read()) - else: - logger.warning(f'{self}: HTTP error {status} downloading pom') + async with aiohttp.ClientSession() as session: + async with session.get(pom_url) as response: + if response.status == 200: + logger.debug(f'{self}: POM downloaded') + self._pom = PackagePOM(self, await response.text()) + else: + logger.warning(f'{self}: HTTP error {response.status} downloading pom') return self._pom @@ -167,38 +164,37 @@ class Package: return q - def _query_maven(self) -> None: + async def _query_maven(self) -> None: url = f'{baseurl}/solrsearch/select?q={self._urlquery}&rows=1&wt=json' logger.debug(f'{self}: Querying maven at url {url}') - response = request.urlopen(url) - status = response.status - - if status == 200: - message = json.loads(response.read()) - num = message['response']['numFound'] - - if num: - logger.debug(f'{self}: Query successful') - self._verified = True - if self.version is None: - version = message['response']['docs'][0]['latestVersion'] - logger.info(f'{self}: Using newest version {version}') - self.version = version - else: - logger.warning(f'{self}: No matching packages found') - self._verified = False - else: - self._verified = False - logger.warning(f'{self}: HTTP error {status} downloading pom') - def verify(self) -> bool: + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + message = await response.json() + num = message['response']['numFound'] + + if num: + logger.debug(f'{self}: Query successful') + self._verified = True + if self.version is None: + version = message['response']['docs'][0]['latestVersion'] + logger.debug(f'{self}: Using newest version {version}') + self.version = version + else: + logger.warning(f'{self}: No matching packages found') + self._verified = False + else: + self._verified = False + logger.warning(f'{self}: HTTP error {response.status} downloading pom') + + async def verify(self) -> bool: if not self._verified: - self._query_maven() + await self._query_maven() return self._verified -def load_package_list(list_path: Path) -> list[Package]: - packages = [] +def load_package_list(list_path: Path, queue: asyncio.Queue) -> None: logger.info(f'Parsing {list_path}') with list_path.open('r') as f: @@ -209,49 +205,71 @@ def load_package_list(list_path: Path) -> list[Package]: logger.warning(f'Invalid package format "{line}". It should be "groupID:artifactID" or "groupID:artifactID:version"') continue - query = Package( + package = Package( sections[0], sections[1], sections[2] if len(sections) == 3 else None, ) - packages.append(query) + queue.put_nowait(package) - return packages +async def download(package: Package, queue: asyncio.Queue) -> None: + async with done_lock: + skip = str(package) in done -def download(base_path: Path, package: Package, done: [str]) -> None: - if str(package) in done: + if skip: logger.info(f'{package}: Already downloaded. Skipping.') - elif package.verify(): - pom_dir = base_path / str(package) + elif await package.verify(): + async with done_lock: + done.add(str(package)) + + pom_dir = base_pom_path / str(package) pom_path = pom_dir / 'pom.xml' pom_dir.mkdir(exist_ok=True) - if not package.pom: + pom = await package.pom + + if not pom: return - package.pom.write(pom_path) - done.append(str(package)) + pom.write(pom_path) logger.info(f'{package}: Downloaded') - if not package.pom.is_bom: - for dep in package.pom.dependencyManagement: + if not pom.is_bom: + for dep in pom.dependencyManagement: logger.info(f'{package}: Handling transitive dependency {dep}') - download(base_path, dep, done) + await queue.put(dep) else: logger.warning(f'{package}: Package not found. Check package name and internet connection') -def main() -> None: - packages = load_package_list(Path('package-list.txt')) +async def worker(queue: asyncio.Queue) -> None: + while True: + package = await queue.get() + await download(package, queue) + queue.task_done() + + +async def main() -> None: + queue = asyncio.Queue() + tasks = [] + + load_package_list(Path('package-list.txt'), queue) + + for i in range(num_workers): + tasks.append( + asyncio.create_task( + worker(queue) + ) + ) - base_pom_path = Path('poms') - done = [] + await queue.join() + for task in tasks: + task.cancel() - for package in packages: - download(base_pom_path, package, done) + await asyncio.gather(*tasks, return_exceptions=True) subprocess.call(['sh', 'generate_master_pom.sh']) @@ -272,4 +290,4 @@ if __name__ == '__main__': logging.basicConfig(level=log_level) - main() + asyncio.run(main())