Source code for torchnlp.datasets.penn_treebank

import os
import io

from torchnlp.download import download_files_maybe_extract
from torchnlp.encoders.text import DEFAULT_EOS_TOKEN
from torchnlp.encoders.text import DEFAULT_UNKNOWN_TOKEN
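# Note (editorial): these are torchnlp's reserved-token defaults; ``DEFAULT_UNKNOWN_TOKEN``
# is ``'<unk>'`` and ``DEFAULT_EOS_TOKEN`` is ``'</s>'``.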


def penn_treebank_dataset(directory='data/penn-treebank',
                          train=False,
                          dev=False,
                          test=False,
                          train_filename='ptb.train.txt',
                          dev_filename='ptb.valid.txt',
                          test_filename='ptb.test.txt',
                          check_files=['ptb.train.txt'],
                          urls=[
                              'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.train.txt',
                              'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.valid.txt',
                              'https://raw.githubusercontent.com/wojzaremba/lstm/master/data/ptb.test.txt'
                          ],
                          unknown_token=DEFAULT_UNKNOWN_TOKEN,
                          eos_token=DEFAULT_EOS_TOKEN):
    """ Load the Penn Treebank dataset.

    This is the Penn Treebank Project: Release 2 CDROM, featuring a million words of 1989
    Wall Street Journal material.

    **Reference:** https://catalog.ldc.upenn.edu/LDC99T42

    **Citation:** Marcus, Mitchell P., Marcinkiewicz, Mary Ann & Santorini, Beatrice (1993).
    Building a Large Annotated Corpus of English: The Penn Treebank

    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): Whether to load the training split of the dataset.
        dev (bool, optional): Whether to load the development split of the dataset.
        test (bool, optional): Whether to load the test split of the dataset.
        train_filename (str, optional): The filename of the training split.
        dev_filename (str, optional): The filename of the development split.
        test_filename (str, optional): The filename of the test split.
        check_files (list of str, optional): The download is considered successful if these
            files exist; if they already exist, the download is skipped.
        urls (list of str, optional): URLs to download.
        unknown_token (str, optional): Token to use for unknown words.
        eos_token (str, optional): Token to use at the end of sentences.

    Returns:
        :class:`tuple` of :class:`iterable` or :class:`iterable`: Returns between one and all
        dataset splits (train, dev and test) depending on if their respective boolean argument
        is ``True``.

    Example:
        >>> from torchnlp.datasets import penn_treebank_dataset  # doctest: +SKIP
        >>> train = penn_treebank_dataset(train=True)  # doctest: +SKIP
        >>> train[:10]  # doctest: +SKIP
        ['aer', 'banknote', 'berlitz', 'calloway', 'centrust', 'cluett', 'fromstein', 'gitano',
         'guterman', 'hydro-quebec']
    """
    # Download the corpus files if ``check_files`` are not already present in ``directory``.
    download_files_maybe_extract(urls=urls, directory=directory, check_files=check_files)

    ret = []
    # Keep only the splits whose boolean flag was requested.
    splits = [(train, train_filename), (dev, dev_filename), (test, test_filename)]
    splits = [f for (requested, f) in splits if requested]
    for filename in splits:
        full_path = os.path.join(directory, filename)
        text = []
        with io.open(full_path, encoding='utf-8') as f:
            for line in f:
                # Replace the corpus's literal ``<unk>`` marker with ``unknown_token`` and
                # append ``eos_token`` at the end of every line.
                text.extend(line.replace('<unk>', unknown_token).split())
                text.append(eos_token)
        ret.append(text)

    # Return a bare split when only one was requested, otherwise a tuple of splits.
    if len(ret) == 1:
        return ret[0]
    else:
        return tuple(ret)
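

# A minimal usage sketch (editorial addition, not part of the library source): load two
# splits and inspect the returned token lists. The frequency counting below is illustrative
# only; real pipelines would typically build an encoder from ``torchnlp.encoders.text``.
if __name__ == '__main__':
    import collections

    train, dev = penn_treebank_dataset(train=True, dev=True)

    # Each split is a flat list of tokens, with ``eos_token`` appended after every line.
    print(train[:10])

    # Sketch of a vocabulary: count token frequencies across the training split.
    counts = collections.Counter(train)
    print(counts.most_common(5))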