Source code for lerot.retrieval_system.SamplerSystem

# This file is part of Lerot.
#
# Lerot is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Lerot is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Lerot.  If not, see <http://www.gnu.org/licenses/>.

from numpy import zeros
import argparse
import logging
from random import gauss
import yaml
import glob
import os

from .AbstractLearningSystem import AbstractLearningSystem
from ..utils import get_class


class SamplerSystem(AbstractLearningSystem):
    def __init__(self, feature_count, arg_str, run_count=""):
        logging.info("Initializing SamplerSystem")
        self.feature_count = feature_count
        # parse arguments
        parser = argparse.ArgumentParser()
        parser.add_argument("-w", "--init_weights", help="Initialization: "
            "colon-separated list of initial weight vectors; weight vectors"
            " are comma-separated", required=True)
        parser.add_argument("--sample_weights", default="sample_unit_sphere")
        parser.add_argument("--nr_rankers", type=int)
        parser.add_argument("-c", "--comparison", required=True)
        parser.add_argument("-f", "--comparison_args", nargs="*")
        parser.add_argument("-r", "--ranker", required=True)
        parser.add_argument("-s", "--ranker_args", nargs="*")
        parser.add_argument("-t", "--ranker_tie", default="random")
        parser.add_argument("--sampler", required=True)
        args = vars(parser.parse_known_args(arg_str.split())[0])

        # initialize weights, comparison method, and learner
        weights = []
        if args["init_weights"].startswith("random"):
            # draw nr_rankers weight vectors from a standard normal
            for _ in range(args["nr_rankers"]):
                v = zeros(self.feature_count)
                for i in range(0, self.feature_count):
                    v[i] = gauss(0, 1)
                weights.append(list(v))
        else:
            # load up to nr_rankers final weight vectors from previous runs
            for f in sorted(glob.glob(os.path.join(args["init_weights"],
                    "*.txt")))[:args["nr_rankers"]]:
                yamldata = yaml.load(open(f, 'r'))
                weight = yamldata["final_weights"]
                if len(weight) != feature_count:
                    raise Exception("List of initial weights does not have"
                        " the expected length (length is %d, expected %d)."
                        % (len(weight), feature_count))
                weights.append(weight)
                logging.info("Loaded weights from file %s." % f)
        logging.info("Loaded %d weights." % len(weights))

        self.ranker_class = get_class(args["ranker"])
        if "ranker_args" in args and args["ranker_args"] is not None:
            self.ranker_args = " ".join(args["ranker_args"])
            self.ranker_args = self.ranker_args.strip("\"")
        else:
            self.ranker_args = None
        self.ranker_tie = args["ranker_tie"]
        self.sample_weights = args["sample_weights"]
        self.rankers = [self.ranker_class(self.ranker_args,
            self.ranker_tie,
            feature_count,
            init=",".join([str(n) for n in w]),
            sample=self.sample_weights) for w in weights]

        self.comparison_class = get_class(args["comparison"])
        if "comparison_args" in args and args["comparison_args"] is not None:
            self.comparison_args = " ".join(args["comparison_args"])
            self.comparison_args = self.comparison_args.strip("\"")
        else:
            self.comparison_args = None
        self.comparison = self.comparison_class(self.comparison_args)

        self.r1 = 0  # One ranker to be compared in live evaluation.
        self.r2 = 0  # The other ranker to be compared against self.r1.

        sampler_class = get_class(args["sampler"])
        try:
            self.sampler = sampler_class(self.rankers, arg_str, run_count)
        except TypeError:
            self.sampler = sampler_class(self.rankers, arg_str)

        self.logging_frequency = 1000
        self.iteration = 1
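
    # A hedged sketch of the argument string the constructor above parses.
    # The flags come from the parser defined in __init__; the concrete class
    # paths (TeamDraft, ProbabilisticRankingFunction, the sampler class) are
    # illustrative assumptions, not values mandated by this class:
    #
    #   arg_str = ("--init_weights random --nr_rankers 5 "
    #              "-c comparison.TeamDraft "
    #              "-r ranker.ProbabilisticRankingFunction -s 3 "
    #              "--sampler sampler.SomeSampler")
    #
    # With "--init_weights random", nr_rankers weight vectors are drawn from
    # a standard normal; otherwise init_weights is read as a directory of
    # *.txt run files whose "final_weights" entries seed the rankers.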

    def get_ranked_list(self, query):
        self.r1, self.r2, i1, i2 = self.sampler.get_arms()
        i1s = [i1]
        i2s = [i2]
        #while self.r1 == self.r2:
        #    self.iteration += 1
        #    self.sampler.update_scores(self.r1, self.r2)
        #    self.r1, self.r2, i1, i2 = self.sampler.get_arms()
        #    i1s.append(i1)
        #    i2s.append(i2)
        (l, context) = self.comparison.interleave(self.r1, self.r2, query, 10)
        self.current_l = l
        self.current_context = context
        self.current_query = query
        return l, i1s, i2s

    def update_solution(self, clicks):
        outcome = self.comparison.infer_outcome(
            self.current_l, self.current_context, clicks, self.current_query)
        if outcome < 0:
            win = self.sampler.update_scores(self.r1, self.r2)  #, -outcome)
        elif outcome > 0:
            win = self.sampler.update_scores(self.r2, self.r1)  #, outcome)
        else:
            # tie: credit a randomly chosen ranker
            if gauss(0, 1) > 0:
                win = self.sampler.update_scores(self.r1, self.r2)  #, -outcome)
            else:
                win = self.sampler.update_scores(self.r2, self.r1)  #, outcome)
        self.iteration += 1
        return win

    def get_solution(self):
        if self.iteration % self.logging_frequency == 0:
            logging.info("Iteration %d" % self.iteration)
        return self.sampler.get_winner().w
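
For context, a minimal usage sketch in the style of the Lerot examples is shown below. It is not part of the original module: the query file path, feature count, click-model parameters, and the comparison, ranker, and sampler class paths passed via arg_str are illustrative assumptions, not defaults of this class.

import random

from lerot import environment, query
from lerot.retrieval_system.SamplerSystem import SamplerSystem

# Assumed input: a LETOR-style query file with 64 features per document.
training_queries = query.load_queries("path/to/train.txt", 64)

# Hypothetical configuration; the class paths are placeholders.
system = SamplerSystem(64,
    "--init_weights random --nr_rankers 5 "
    "-c comparison.TeamDraft "
    "-r ranker.ProbabilisticRankingFunction -s 3 "
    "--sampler sampler.SomeSampler")

# Simulated user that clicks on relevant documents and never stops early.
user_model = environment.CascadeUserModel(
    "--p_click 0:0.0,1:1.0 --p_stop 0:0.0,1:0.0")

for _ in range(1000):
    q = training_queries[random.choice(list(training_queries.keys()))]
    ranking, _, _ = system.get_ranked_list(q)   # interleaved list for the sampled pair
    clicks = user_model.get_clicks(ranking, q.get_labels())
    system.update_solution(clicks)              # credit the preferred ranker

best_weights = system.get_solution()            # weights of the current best ranker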