# Source code for lerot.retrieval_system.PrudentListwiseLearningSystem

# This file is part of Lerot.
#
# Lerot is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Lerot is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Lerot.  If not, see <http://www.gnu.org/licenses/>.

# KH, 2012/06/14
"""
Retrieval system implementation for use in learning experiments.
"""

import argparse
from numpy import array

from .AbstractLearningSystem import AbstractLearningSystem
from ..utils import get_class, split_arg_str


class PrudentListwiseLearningSystem(AbstractLearningSystem):
    """A retrieval system that learns online from listwise comparisons.

    The system keeps track of all necessary state variables (current query,
    weights, etc.) so that comparison and learning classes can be stateless
    (implement only static / class methods).

    Evaluating and updating are separate steps here. ``get_outcome`` only
    reports the comparison outcome for the last interleaved list.
    ``update_solution`` always applies the weight update. Presumably the
    caller inspects the outcome and decides whether to call
    ``update_solution``; confirm against callers.
    """

    def __init__(self, feature_count, arg_str):
        """Parse ``arg_str`` and build the current ranker and comparison.

        Args:
            feature_count: number of features; forwarded to the ranker class.
            arg_str: command-line style argument string. It is split with
                ``split_arg_str``, and unrecognised arguments are ignored
                (``parse_known_args``).
        """
        parser = argparse.ArgumentParser(description="Initialize retrieval "
            "system with the specified feedback and learning mechanism.",
            prog="ListwiseLearningSystem")
        parser.add_argument("-w", "--init_weights", help="Initialization "
            "method for weights (random, zero).", required=True)
        parser.add_argument("--sample_weights", default="sample_unit_sphere")
        parser.add_argument("-c", "--comparison", required=True)
        parser.add_argument("-f", "--comparison_args", nargs="*")
        parser.add_argument("-r", "--ranker", required=True)
        parser.add_argument("-s", "--ranker_args", nargs="*")
        parser.add_argument("-t", "--ranker_tie", default="random")
        parser.add_argument("-d", "--delta", required=True, type=str)
        parser.add_argument("-a", "--alpha", required=True, type=str)
        parser.add_argument("--anneal", type=int, default=0)
        parser.add_argument("--normalize", default="False")
        args = vars(parser.parse_known_args(split_arg_str(arg_str))[0])

        self.ranker_class = get_class(args["ranker"])
        self.ranker_args = args["ranker_args"]
        self.ranker_tie = args["ranker_tie"]
        self.sample_weights = args["sample_weights"]
        self.init_weights = args["init_weights"]
        self.feature_count = feature_count
        self.ranker = self._make_ranker()

        # delta is passed to get_candidate_weight; alpha is passed to
        # update_weights. Each is a scalar float or, when given as a
        # comma-separated list, a numpy array of floats.
        self.delta = self._parse_float_or_array(args["delta"])
        self.alpha = self._parse_float_or_array(args["alpha"])
        # NOTE(review): anneal is stored but not used by this class, and
        # --normalize is parsed but ignored; confirm whether subclasses or
        # callers rely on them.
        self.anneal = args["anneal"]

        self.comparison_class = get_class(args["comparison"])
        # argparse always sets the dest, so only None vs. a (possibly empty)
        # list needs distinguishing; an empty list yields "".
        if args["comparison_args"] is not None:
            self.comparison_args = " ".join(args["comparison_args"]).strip("\"")
        else:
            self.comparison_args = None
        self.comparison = self.comparison_class(self.comparison_args)

        # No candidate exists until get_ranked_list first generates one.
        self.candidate_ranker = None
        self.current_u = None

    @staticmethod
    def _parse_float_or_array(value):
        """Parse ``value`` as a float, or as a numpy array if it has commas."""
        if "," in value:
            return array([float(x) for x in value.split(",")])
        return float(value)

    def _make_ranker(self):
        """Construct a new ranker with this system's ranker configuration."""
        return self.ranker_class(self.ranker_args, self.ranker_tie,
            self.feature_count, sample=self.sample_weights,
            init=self.init_weights)

    def _get_new_candidate(self):
        """Build a candidate ranker from the current ranker's weights.

        Returns:
            A ``(candidate_ranker, u)`` pair. ``candidate_ranker`` is a fresh
            ranker whose weights are set to ``w``. Both ``w`` and ``u`` come
            from ``self.ranker.get_candidate_weight(self.delta)``. ``u`` is
            later passed to ``self.ranker.update_weights`` by
            ``_update_weights``.
        """
        w, u = self.ranker.get_candidate_weight(self.delta)
        candidate_ranker = self._make_ranker()
        candidate_ranker.update_weights(w)
        return candidate_ranker, u

    def _get_candidate(self):
        """Return a candidate ``(ranker, u)`` pair; hook for subclasses."""
        return self._get_new_candidate()

    def get_ranked_list(self, query, getNewCandidate=True, length=10):
        """Interleave the current and candidate rankers' results for ``query``.

        Args:
            query: the query to rank documents for.
            getNewCandidate: if true, generate a fresh candidate ranker first.
                If false, reuse the previous candidate. A candidate is still
                generated when none exists yet.
            length: length argument passed to ``comparison.interleave``
                (defaults to 10, the previously hard-coded value).

        Returns:
            The interleaved list produced by ``comparison.interleave``. The
            list, its context, and the query are stored for ``get_outcome``.
        """
        if getNewCandidate or self.candidate_ranker is None:
            self.candidate_ranker, self.current_u = self._get_candidate()
        l, context = self.comparison.interleave(self.ranker,
            self.candidate_ranker, query, length)
        self.current_l = l
        self.current_context = context
        self.current_query = query
        return l

    def _update_weights(self):
        """Move the current ranker's weights along ``current_u`` by alpha."""
        self.ranker.update_weights(self.current_u, self.alpha)

    def _update_solution(self, outcome, _):
        """Update the weights only when ``outcome`` is positive."""
        if outcome > 0:
            return self._update_weights()

    def get_outcome(self, clicks):
        """Infer the comparison outcome for the last interleaved list.

        Args:
            clicks: click feedback on the list returned by
                ``get_ranked_list``.

        Returns:
            The result of ``comparison.infer_outcome``. No weights are
            changed here.
        """
        return self.comparison.infer_outcome(self.current_l,
            self.current_context, clicks, self.current_query)

    def update_solution(self):
        """Unconditionally apply the weight update and return the ranker."""
        # An outcome of 1 always satisfies the > 0 test in _update_solution.
        self._update_solution(1, [])
        return self.get_solution()

    def get_solution(self):
        """Return the current (learned) ranker."""
        return self.ranker