Source code for airflow_fs.hooks.s3_hook

"""File system hook for the S3 file system."""

from builtins import super
import posixpath

try:
    import s3fs
except ImportError:
    s3fs = None

from . import FsHook


class S3Hook(FsHook):
    """Hook for interacting with files in S3.

    The underlying ``s3fs.S3FileSystem`` is created lazily by
    :meth:`get_conn` and cached on the hook until :meth:`disconnect`
    is called.

    :param conn_id: Optional ID of the connection to read credentials from
        (looked up with ``get_connection``). When ``None``, the file system
        is created without explicit credentials.
    """

    def __init__(self, conn_id=None):
        super().__init__()
        self._conn_id = conn_id
        self._conn = None

    def get_conn(self):
        """Return the cached ``s3fs.S3FileSystem``, creating it if needed.

        When a connection ID was given, the connection's login and password
        are passed as ``key`` and ``secret``, and an ``encryption`` entry in
        the connection extras is forwarded as ``ServerSideEncryption``.

        :raises ImportError: If ``s3fs`` is not installed.
        """
        if s3fs is None:
            raise ImportError("s3fs must be installed to use the S3Hook")

        if self._conn is None:
            if self._conn_id is None:
                self._conn = s3fs.S3FileSystem()
            else:
                config = self.get_connection(self._conn_id)

                extra_kwargs = {}
                if "encryption" in config.extra_dejson:
                    extra_kwargs["ServerSideEncryption"] = config.extra_dejson[
                        "encryption"
                    ]

                self._conn = s3fs.S3FileSystem(
                    key=config.login,
                    secret=config.password,
                    s3_additional_kwargs=extra_kwargs,
                )

        return self._conn

    def disconnect(self):
        """Drop the cached file system so the next use creates a new one."""
        self._conn = None

    def open(self, file_path, mode="rb"):
        """Open ``file_path`` on S3 with the given mode and return the file."""
        return self.get_conn().open(file_path, mode=mode)

    def exists(self, file_path):
        """Return whether ``file_path`` exists according to the file system."""
        return self.get_conn().exists(file_path)

    def isdir(self, path):
        """Return whether ``path`` refers to a directory.

        A path without any ``/`` is treated as a bucket name and is always
        reported as a directory. Otherwise the parent directory is listed
        and ``path`` counts as a directory only if an entry with that exact
        key has storage class ``DIRECTORY``.
        """
        if "/" not in path:
            # Path looks like a bucket name.
            return True

        parent_dir = posixpath.dirname(path)

        for child in self.get_conn().ls(parent_dir, detail=True):
            if child["Key"] == path and child["StorageClass"] == "DIRECTORY":
                return True

        return False

    def mkdir(self, dir_path, mode=0o755, exist_ok=True):
        """Create ``dir_path``; delegates to :meth:`makedirs`."""
        self.makedirs(dir_path, mode=mode, exist_ok=exist_ok)

    def listdir(self, dir_path):
        """Return the entries of ``dir_path`` as paths relative to it."""
        # The keyword is ``detail`` (singular), matching the call in isdir();
        # with detail=False the listing yields plain path strings.
        return [
            posixpath.relpath(fp, start=dir_path)
            for fp in self.get_conn().ls(dir_path, detail=False)
        ]

    def rm(self, file_path):
        """Remove ``file_path`` (non-recursive)."""
        self.get_conn().rm(file_path, recursive=False)

    def rmtree(self, dir_path):
        """Remove ``dir_path`` recursively."""
        self.get_conn().rm(dir_path, recursive=True)

    # Overridden default implementations.

    def makedirs(self, dir_path, mode=0o755, exist_ok=True):
        """Create ``dir_path`` unless it already exists.

        If the path exists and ``exist_ok`` is false, the error is raised
        through ``_raise_dir_exists``. ``mode`` is accepted for interface
        compatibility but is not used by this implementation.
        """
        if self.exists(dir_path):
            if not exist_ok:
                self._raise_dir_exists(dir_path)
        else:
            self.get_conn().mkdir(dir_path)

    def walk(self, root):
        """Walk ``root`` using the base implementation.

        A single trailing slash is stripped from ``root`` before walking.
        """
        root = _remove_trailing_slash(root)
        for entry in super().walk(root):
            yield entry
def _remove_trailing_slash(path): if path.endswith("/"): return path[:-1] return path