import os
from torchnlp.download import download_files_maybe_extract
def trec_dataset(directory='data/trec/',
                 train=False,
                 test=False,
                 train_filename='train_5500.label',
                 test_filename='TREC_10.label',
                 check_files=['train_5500.label'],
                 urls=[
                     'http://cogcomp.org/Data/QA/QC/train_5500.label',
                     'http://cogcomp.org/Data/QA/QC/TREC_10.label'
                 ],
                 fine_grained=False):
    """
    Load the Text REtrieval Conference (TREC) Question Classification dataset.

    TREC dataset contains 5500 labeled questions in training set and another 500 for test set. The
    dataset has 6 labels, 50 level-2 labels. Average length of each sentence is 10, vocabulary size
    of 8700.

    Each line of a split file is parsed as ``<coarse>:<fine> <question text>``: the label is
    everything before the first space, and it is split into its coarse and fine parts at the
    first ``:``.

    References:
        * https://nlp.stanford.edu/courses/cs224n/2004/may-steinberg-project.pdf
        * http://cogcomp.org/Data/QA/QC/
        * http://www.aclweb.org/anthology/C02-1150

    **Citation:**
    Xin Li, Dan Roth, Learning Question Classifiers. COLING'02, Aug., 2002.

    Args:
        directory (str, optional): Directory to cache the dataset.
        train (bool, optional): If to load the training split of the dataset.
        test (bool, optional): If to load the test split of the dataset.
        train_filename (str, optional): The filename of the training split.
        test_filename (str, optional): The filename of the test split.
        check_files (list of str, optional): Check if these files exist, then this download was
            successful.
        urls (list of str, optional): URLs to download.
        fine_grained (bool, optional): If ``True``, each example's ``'label'`` is the part of the
            label after the first ``:``; otherwise it is the part before it.

    Returns:
        :class:`tuple` of :class:`list` or :class:`list`:
        Returns the requested dataset splits (train and test, in that order) depending on if their
        respective boolean argument is ``True``. When exactly one split is requested the list of
        examples is returned directly instead of a one-element tuple; when no split is requested
        an empty tuple is returned. Every example is a ``dict`` with the keys ``'label'`` and
        ``'text'``.

    Example:
        >>> from torchnlp.datasets import trec_dataset  # doctest: +SKIP
        >>> train = trec_dataset(train=True)  # doctest: +SKIP
        >>> train[:2]  # doctest: +SKIP
        [{
          'label': 'DESC',
          'text': 'How did serfdom develop in and then leave Russia ?'
        }, {
          'label': 'ENTY',
          'text': 'What films featured the character Popeye Doyle ?'
        }]
    """
    download_files_maybe_extract(urls=urls, directory=directory, check_files=check_files)

    ret = []
    splits = [(train, train_filename), (test, test_filename)]
    splits = [f for (requested, f) in splits if requested]
    for filename in splits:
        full_path = os.path.join(directory, filename)
        examples = []
        # Read as bytes so the stray byte can be patched before decoding; the ``with`` block
        # guarantees the handle is closed even if decoding a line raises.
        with open(full_path, 'rb') as label_file:
            for line in label_file:
                # there is one non-ASCII byte: sisterBADBYTEcity; replaced with space
                label, _, text = line.replace(b'\xf0', b' ').strip().decode().partition(' ')
                label, _, label_fine = label.partition(':')
                if fine_grained:
                    examples.append({'label': label_fine, 'text': text})
                else:
                    examples.append({'label': label, 'text': text})
        ret.append(examples)

    # A single requested split is unwrapped for convenience.
    if len(ret) == 1:
        return ret[0]
    else:
        return tuple(ret)