import torch
from torchnlp.utils import torch_equals_ignore_index
is_scalar = lambda t: torch.is_tensor(t) and len(t.size()) == 0
[docs]def get_accuracy(targets, outputs, k=1, ignore_index=None):
""" Get the accuracy top-k accuracy between two tensors.
Args:
targets (1 - 2D :class:`torch.Tensor`): Target or true vector against which to measure
saccuracy
outputs (1 - 3D :class:`torch.Tensor`): Prediction or output vector
ignore_index (int, optional): Specifies a target index that is ignored
Returns:
:class:`tuple` consisting of accuracy (:class:`float`), number correct (:class:`int`) and
total (:class:`int`)
Example:
>>> import torch
>>> from torchnlp.metrics import get_accuracy
>>> targets = torch.LongTensor([1, 2, 3, 4, 5])
>>> outputs = torch.LongTensor([1, 2, 2, 3, 5])
>>> accuracy, n_correct, n_total = get_accuracy(targets, outputs, ignore_index=3)
>>> accuracy
0.8
>>> n_correct
4
>>> n_total
5
"""
n_correct = 0.0
for target, output in zip(targets, outputs):
if not torch.is_tensor(target) or is_scalar(target):
target = torch.LongTensor([target])
if not torch.is_tensor(output) or is_scalar(output):
output = torch.LongTensor([[output]])
predictions = output.topk(k=min(k, len(output)), dim=0)[0]
for prediction in predictions:
if torch_equals_ignore_index(
target.squeeze(), prediction.squeeze(), ignore_index=ignore_index):
n_correct += 1
break
return n_correct / len(targets), int(n_correct), len(targets)
[docs]def get_token_accuracy(targets, outputs, ignore_index=None):
""" Get the accuracy token accuracy between two tensors.
Args:
targets (1 - 2D :class:`torch.Tensor`): Target or true vector against which to measure
saccuracy
outputs (1 - 3D :class:`torch.Tensor`): Prediction or output vector
ignore_index (int, optional): Specifies a target index that is ignored
Returns:
:class:`tuple` consisting of accuracy (:class:`float`), number correct (:class:`int`) and
total (:class:`int`)
Example:
>>> import torch
>>> from torchnlp.metrics import get_token_accuracy
>>> targets = torch.LongTensor([[1, 1], [2, 2], [3, 3]])
>>> outputs = torch.LongTensor([[1, 1], [2, 3], [4, 4]])
>>> accuracy, n_correct, n_total = get_token_accuracy(targets, outputs, ignore_index=3)
>>> accuracy
0.75
>>> n_correct
3.0
>>> n_total
4.0
"""
n_correct = 0.0
n_total = 0.0
for target, output in zip(targets, outputs):
if not torch.is_tensor(target) or is_scalar(target):
target = torch.LongTensor([target])
if not torch.is_tensor(output) or is_scalar(output):
output = torch.LongTensor([[output]])
if len(target.size()) != len(output.size()):
prediction = output.max(dim=0)[0].view(-1)
else:
prediction = output
if ignore_index is not None:
mask = target.ne(ignore_index)
n_correct += prediction.eq(target).masked_select(mask).sum().item()
n_total += mask.sum().item()
else:
n_total += len(target)
n_correct += prediction.eq(target).sum().item()
return n_correct / n_total, n_correct, n_total