# This file is part of Lerot.
#
# Lerot is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Lerot is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Lerot. If not, see <http://www.gnu.org/licenses/>.
"""
Retrieval system implementation for use in learning experiments.
"""
import argparse
import numpy
from .AbstractLearningSystem import AbstractLearningSystem
from ..utils import get_class, split_arg_str, create_ranking_vector
[docs]class PerturbationLearningSystem(AbstractLearningSystem):
"""A retrieval system that learns online from pairwise comparisons. The
system keeps track of all necessary state variables (current query,
weights, etc.) so that comparison and learning classes can be stateless
(implement only static / class methods)."""
def __init__(self, feature_count, arg_str):
# parse arguments
parser = argparse.ArgumentParser(
description="Initialize retrieval "
"system with the specified feedback and learning mechanism.",
prog="PerturbationLearningSystem"
)
parser.add_argument("-w", "--init_weights", help="Initialization "
"method for weights (random, zero).",
required=True)
# Perturbation arguments
parser.add_argument("-p", "--perturbator", required=True)
parser.add_argument("-b", "--swap_prob", default=0.25, type=float)
# parser.add_argument("-f", "--perturbator_args", nargs="*")
parser.add_argument("-r", "--ranker", required=True)
parser.add_argument("-s", "--ranker_args", nargs="*", default=tuple())
parser.add_argument("-t", "--ranker_tie", default="random")
parser.add_argument("-l", "--max_results", default=float('Inf'))
args = vars(parser.parse_known_args(split_arg_str(arg_str))[0])
self.ranker_class = get_class(args["ranker"])
self.ranker_args = args["ranker_args"]
self.ranker_tie = args["ranker_tie"]
self.init_weights = args["init_weights"]
self.feature_count = feature_count
self.ranker = self.ranker_class(
self.ranker_args,
self.ranker_tie,
self.feature_count,
sample=None, # explicitly break sampling
init=self.init_weights
)
self.max_results = args["max_results"]
self.perturbator = get_class(args["perturbator"])(args["swap_prob"])
[docs] def get_ranked_list(self, query):
new_ranking, single_start = self.perturbator.perturb(
self.ranker, query, self.max_results
)
self.current_ranking = new_ranking
self.current_single_start = single_start
self.current_query = query
return new_ranking
[docs] def update_solution_once(self, clicks):
"""
Update the ranker weights without regard to multiple clicks
on a single link
"""
new_ranking = self._get_feedback(clicks)
# Calculate ranking vectors
current_vector = create_ranking_vector(
self.current_query,
self.current_ranking
)
new_vector = create_ranking_vector(
self.current_query,
new_ranking
)
self.perturbator.update(
new_vector,
current_vector,
self.current_query,
self.ranker
)
# Update the weights
self.ranker.update_weights(new_vector - current_vector, 1)
return self.get_solution()
[docs] def update_solution(self, clicks):
"""
Update the ranker weights
while keeping in mind that documents with a relevance of > 1 are
clicked more than once
"""
# Loop through clicks until no clicks are left
while numpy.count_nonzero(clicks) > 0:
# Update
self.update_solution_once(clicks)
# Remove one click per click
relevant_clicks = numpy.nonzero(clicks)
for click_index in relevant_clicks:
clicks[click_index] -= 1
return self.get_solution()
[docs] def get_solution(self):
return self.ranker
def _get_feedback(self, clicks):
"""
Get a new ranking of documents, swapped according to user clicks
"""
max_length = len(self.current_ranking)
# Check whether new ranking should start with a single start
if self.current_single_start:
new_ranking = [self.current_ranking[0]]
else:
new_ranking = []
# Loop for swapping pairs of documents according to clicks
for i in xrange(self.current_single_start, max_length-1, 2):
# Swap if there is a click on the lower item of a pair
if clicks[i+1] and not clicks[i]:
new_ranking.append(self.current_ranking[i+1])
new_ranking.append(self.current_ranking[i])
# Don't swap
else:
new_ranking.append(self.current_ranking[i])
new_ranking.append(self.current_ranking[i+1])
# Add last index if it hasn't been added yet
if len(new_ranking) < max_length:
new_ranking.append(self.current_ranking[max_length-1])
return new_ranking