This message was deleted.
# general
s
This message was deleted.
🙌 5
🔥 3
e
Copy code
def s3() -> boto3.resource:
    """Build and return a new boto3 S3 service resource.

    A fresh resource is created on every call rather than cached, so each
    caller gets its own instance.
    """
    service_name = "s3"
    resource = boto3.resource(service_name)
    return resource


@dataclasses.dataclass
class ToDownload:
    """A single S3 object to fetch: which bucket it lives in and its key."""

    # Object key within the bucket; also used as the relative local path.
    key: str
    # Name of the S3 bucket holding the object.
    bucket: str


def ensured_save_dir(save_dir: str) -> str:
    """Make sure the local save directory exists, then return it unchanged.

    Any missing parent directories are created as well.

    :param save_dir: Path of the directory to create if it does not exist.
    :return: ``save_dir``, unchanged.
    """
    if not os.path.exists(save_dir):
        # parents=True: a nested path whose parents don't exist yet must not fail.
        # exist_ok=True: tolerate another process/thread creating the directory
        # between the existence check above and this call.
        Path(save_dir).mkdir(parents=True, exist_ok=True)
    return save_dir


def downloadable(
        s3: boto3.resource,
        bucket: str,
        path_in_bucket: str,
        slice: int = None
) -> Parallelizable[ToDownload]:
    """Lists downloadables from the s3 bucket, yielding one ``ToDownload`` per object.

    :param s3: boto3 S3 service resource used to list the bucket.
    :param bucket: Name of the bucket to list.
    :param path_in_bucket: Key prefix to filter the listed objects by.
    :param slice: If not None, only ``objs[:slice]`` of the listing is kept
        (ordinary Python slice semantics, so a negative value drops from the end).
    :return: Generator of ``ToDownload`` items, one per kept object.
    """
    bucket_obj = s3.Bucket(bucket)
    # Materialize the whole listing so it can be sliced and its size logged
    # before anything is yielded.
    objs = list(bucket_obj.objects.filter(Prefix=path_in_bucket).all())
    if slice is not None:
        objs = objs[:slice]
    logger.info(f"Found {len(objs)} objects in {bucket}/{path_in_bucket}")
    for obj in objs:
        yield ToDownload(key=obj.key, bucket=bucket)


def _already_downloaded(path: str) -> bool:
    """Checks if the data is already downloaded"""
    if os.path.exists(path):
        return True
    return False


def downloaded_data(
        downloadable: ToDownload,
        ensured_save_dir: str,
) -> str:
    """Downloads data, short-circuiting if the data already exists locally.

    The object is written to ``<ensured_save_dir>/<key>``; intermediate
    directories are created as needed.

    :param downloadable: Bucket and key of the S3 object to download.
    :param ensured_save_dir: Local root directory to download into.
    :return: Local path of the downloaded (or already present) file.
    :raises ValueError: If the object's key would resolve to a path outside
        ``ensured_save_dir`` (e.g. a key containing ``..`` segments).
    """
    # os.path.join discards everything before an absolute component, so a key
    # like "/foo" would otherwise be written to "/foo" rather than under the
    # save directory. Strip leading slashes to keep it relative.
    relative_key = downloadable.key.lstrip("/")
    download_location = os.path.join(ensured_save_dir, relative_key)

    # S3 keys come from outside this process; refuse any that escape the root.
    root = os.path.abspath(ensured_save_dir)
    if os.path.commonpath([root, os.path.abspath(download_location)]) != root:
        raise ValueError(
            f"Refusing to download key {downloadable.key!r} outside of {ensured_save_dir!r}"
        )

    if _already_downloaded(download_location):
        logger.info(f"Already downloaded {download_location}")
        return download_location

    # exist_ok=True makes this safe when parallel downloads share a parent dir.
    parent_path = os.path.dirname(download_location)
    os.makedirs(parent_path, exist_ok=True)

    # Build a fresh resource per call instead of sharing one, to ensure thread
    # safety while this runs in parallel. A pool could be used instead, but a
    # new resource per download is the simple option.
    s3_resource = s3()
    bucket = s3_resource.Bucket(downloadable.bucket)
    bucket.download_file(downloadable.key, download_location)
    logger.info(f"Downloaded {download_location}")
    return download_location


def all_downloaded_data(downloaded_data: Collect[str]) -> List[str]:
    """Returns a list of all downloaded locations.

    :param downloaded_data: Iterable of local paths produced by ``downloaded_data``.
    :return: A new list containing those paths in iteration order.
    """
    return list(downloaded_data)