# Source code for torchnlp.datasets.wikitext_2

import io
import os

from torchnlp.download import download_file_maybe_extract
from torchnlp.encoders.text import DEFAULT_EOS_TOKEN
from torchnlp.encoders.text import DEFAULT_UNKNOWN_TOKEN

def wikitext_2_dataset(
        directory='data/',
        train=False,
        dev=False,
        test=False,
        train_filename='wiki.train.tokens',
        dev_filename='wiki.valid.tokens',
        test_filename='wiki.test.tokens',
        extracted_name='wikitext-2',
        check_files=None,
        url='',
        unknown_token=DEFAULT_UNKNOWN_TOKEN,
        eos_token=DEFAULT_EOS_TOKEN):
    """
    Load the WikiText-2 dataset.

    The WikiText language modeling dataset is a collection of over 100 million tokens extracted
    from the set of verified Good and Featured articles on Wikipedia. The dataset is available
    under the Creative Commons Attribution-ShareAlike License.

    **Reference:** Stephen Merity, Caiming Xiong, James Bradbury and Richard Socher,
    "Pointer Sentinel Mixture Models" (2016), arXiv:1609.07843.

    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): If to load the training split of the dataset.
        dev (bool, optional): If to load the development split of the dataset.
        test (bool, optional): If to load the test split of the dataset.
        train_filename (str, optional): The filename of the training split.
        dev_filename (str, optional): The filename of the development split.
        test_filename (str, optional): The filename of the test split.
        extracted_name (str, optional): Name of the extracted dataset directory; split files
            are read from ``os.path.join(directory, extracted_name, <split filename>)``.
        check_files (list of str, optional): Check if these files exist, then this download
            was successful. When ``None`` (the default), ``['wikitext-2/wiki.train.tokens']``
            is used.
        url (str, optional): URL of the dataset archive. NOTE(review): the default is an empty
            string; presumably the download is skipped when ``check_files`` already exist
            under ``directory`` -- confirm, otherwise pass the archive URL explicitly.
        unknown_token (str, optional): Token to use for unknown words; every ``<unk>`` in the
            raw files is replaced with it before tokenizing.
        eos_token (str, optional): Token appended after every line of a split file, including
            blank lines.

    Returns:
        :class:`tuple` of :class:`iterable` or :class:`iterable`:
        Returns between one and all dataset splits (train, dev and test) depending on if their
        respective boolean argument is ``True``. Each split is a flat ``list`` of
        whitespace-separated tokens. A single requested split is returned directly (not wrapped
        in a tuple); if no split is requested, an empty tuple is returned.

    Example:
        >>> from torchnlp.datasets import wikitext_2_dataset  # doctest: +SKIP
        >>> train = wikitext_2_dataset(train=True)  # doctest: +SKIP
        >>> train[:10]  # doctest: +SKIP
        ['</s>', '=', 'Valkyria', 'Chronicles', 'III', '=', '</s>', '</s>', 'Senjō', 'no']
    """
    # Resolve the default here rather than in the signature so the list is not shared
    # (and potentially mutated) across calls.
    if check_files is None:
        check_files = ['wikitext-2/wiki.train.tokens']

    download_file_maybe_extract(url=url, directory=directory, check_files=check_files)

    # Keep only the filenames of the splits the caller asked for, in train/dev/test order.
    requested_files = [
        filename
        for requested, filename in ((train, train_filename), (dev, dev_filename),
                                    (test, test_filename))
        if requested
    ]

    ret = []
    for filename in requested_files:
        full_path = os.path.join(directory, extracted_name, filename)
        text = []
        with io.open(full_path, encoding='utf-8') as f:
            for line in f:
                # Substitute the raw '<unk>' marker before splitting on whitespace.
                text.extend(line.replace('<unk>', unknown_token).split())
                text.append(eos_token)
        ret.append(text)

    # A single split is returned bare; zero or several are returned as a tuple.
    if len(ret) == 1:
        return ret[0]
    return tuple(ret)