from urllib.parse import urlparse
import logging
import os
import subprocess
import urllib.request
import zipfile
from torchnlp._third_party.lazy_loader import LazyLoader
from tqdm import tqdm
requests = LazyLoader('requests', globals(), 'requests')
logger = logging.getLogger(__name__)
def _reporthook(t):
""" ``reporthook`` to use with ``urllib.request`` that prints the process of the download.
Uses ``tqdm`` for progress bar.
**Reference:**
https://github.com/tqdm/tqdm
Args:
t (tqdm.tqdm) Progress bar.
Example:
>>> with tqdm(unit='B', unit_scale=True, miniters=1, desc=filename) as t: # doctest: +SKIP
... urllib.request.urlretrieve(file_url, filename=full_path, reporthook=reporthook(t))
"""
last_b = [0]
def inner(b=1, bsize=1, tsize=None):
"""
Args:
b (int, optional): Number of blocks just transferred [default: 1].
bsize (int, optional): Size of each block (in tqdm units) [default: 1].
tsize (int, optional): Total size (in tqdm units). If [default: None] remains unchanged.
"""
if tsize is not None:
t.total = tsize
t.update((b - last_b[0]) * bsize)
last_b[0] = b
return inner
def _download_file_from_drive(filename, url):  # pragma: no cover
    """ Download filename from google drive unless it's already in directory.

    Args:
        filename (str): Name of the file to download to (do nothing if it already exists).
        url (str): URL to download from.
    """
    # Since the file is big, drive will scan it for virus and take it to a
    # warning page. We find the confirm token on this page and append it to the
    # URL to start the download process.
    confirm_token = None
    session = requests.Session()
    response = session.get(url, stream=True)
    for k, v in response.cookies.items():
        if k.startswith("download_warning"):
            confirm_token = v
            break  # Drive sets at most one warning cookie.

    if confirm_token:
        url = url + "&confirm=" + confirm_token

    # Lazy %-args: the message is only formatted if INFO logging is enabled.
    logger.info("Downloading %s to %s", url, filename)

    # Now begin the download, streaming the body in fixed-size chunks so the
    # whole (potentially large) file is never held in memory.
    response = session.get(url, stream=True)
    chunk_size = 16 * 1024
    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size):
            if chunk:  # filter out keep-alive (empty) chunks
                f.write(chunk)

    statinfo = os.stat(filename)
    logger.info("Successfully downloaded %s, %s bytes.", filename, statinfo.st_size)
def _maybe_extract(compressed_filename, directory, extension=None):
    """ Extract a compressed file to ``directory``.

    Args:
        compressed_filename (str): Compressed file.
        directory (str): Extract to directory.
        extension (str, optional): Extension of the file; Otherwise, attempts to extract extension
            from the filename.

    Raises:
        ValueError: If ``extension`` is not given and cannot be inferred because the
            filename contains no ``.``.
    """
    logger.info('Extracting {}'.format(compressed_filename))

    if extension is None:
        basename = os.path.basename(compressed_filename)
        # A dotless filename previously raised a bare `IndexError` from
        # `split('.', 1)[1]`; fail with an actionable message instead.
        if '.' not in basename:
            raise ValueError(
                'Unable to infer the extension of {}; pass `extension` explicitly.'.format(
                    compressed_filename))
        extension = basename.split('.', 1)[1]

    if 'zip' in extension:
        with zipfile.ZipFile(compressed_filename, "r") as zip_:
            zip_.extractall(directory)
    elif 'tar.gz' in extension or 'tgz' in extension:
        # `tar` is much faster than python's `tarfile` implementation
        subprocess.call(['tar', '-C', directory, '-zxvf', compressed_filename])
    elif 'tar' in extension:
        subprocess.call(['tar', '-C', directory, '-xvf', compressed_filename])
    else:
        # Unknown formats used to be skipped silently while still logging
        # "Extracted"; surface the no-op so callers can diagnose it.
        logger.warning('Unsupported extension %s; nothing was extracted.', extension)

    logger.info('Extracted {}'.format(compressed_filename))
def _get_filename_from_url(url):
""" Return a filename from a URL
Args:
url (str): URL to extract filename from
Returns:
(str): Filename in URL
"""
parse = urlparse(url)
return os.path.basename(parse.path)
def _check_download(*filepaths):
""" Check if the downloaded files are found.
Args:
filepaths (list of str): Check if these filepaths exist
Returns:
(bool): Returns True if all filepaths exist
"""
return all([os.path.isfile(filepath) for filepath in filepaths])