presage 0.9.2~beta
smoothedNgramPredictor.cpp
Go to the documentation of this file.
1
2/******************************************************
3 * Presage, an extensible predictive text entry system
4 * ---------------------------------------------------
5 *
6 * Copyright (C) 2008 Matteo Vescovi <matteo.vescovi@yahoo.co.uk>
7
8 This program is free software; you can redistribute it and/or modify
9 it under the terms of the GNU General Public License as published by
10 the Free Software Foundation; either version 2 of the License, or
11 (at your option) any later version.
12
13 This program is distributed in the hope that it will be useful,
14 but WITHOUT ANY WARRANTY; without even the implied warranty of
15 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
16 GNU General Public License for more details.
17
18 You should have received a copy of the GNU General Public License along
19 with this program; if not, write to the Free Software Foundation, Inc.,
20 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
21 *
22 **********(*)*/
23
24
26
27#include <sstream>
28#include <algorithm>
29
30
// SmoothedNgramPredictor constructor (fragment).
// NOTE(review): this is a doxygen extraction — the constructor signature and
// the base-class initializer (original lines 31-32) are missing, as are the
// dispatcher map registrations (original lines 51-56). Do not edit this
// fragment without consulting the full source.
// Member initializers: db starts null, thresholds/cardinality zeroed,
// learn_mode_set false until the LEARN config variable is dispatched.
33 ct,
34 name,
35 "SmoothedNgramPredictor, a linear interpolating n-gram predictor",
36 "SmoothedNgramPredictor, long description." ),
37 db (0),
38 count_threshold (0),
39 cardinality (0),
40 learn_mode_set (false),
41 dispatcher (this)
42{
// Build the fully qualified configuration-variable names this predictor
// listens to (e.g. "<PREDICTORS><name>.DELTAS").
43 LOGGER = PREDICTORS + name + ".LOGGER";
44 DBFILENAME = PREDICTORS + name + ".DBFILENAME";
45 DELTAS = PREDICTORS + name + ".DELTAS";
46 COUNT_THRESHOLD = PREDICTORS + name + ".COUNT_THRESHOLD";
47 LEARN = PREDICTORS + name + ".LEARN";
48 DATABASE_LOGGER = PREDICTORS + name + ".DatabaseConnector.LOGGER";
49
50 // build notification dispatch map
// NOTE(review): original lines 51-56 (the dispatcher.map(...) calls wiring
// the variables above to the set_* handlers) are missing from this listing.
57}
58
59
60
// Destructor (signature on original line 61, missing from this listing).
// Releases the database connector owned through the raw pointer 'db';
// 'delete' on a null pointer is a safe no-op if it was never created.
62{
63 delete db;
64}
65
66
// Config handler: store the n-gram database filename.
67void SmoothedNgramPredictor::set_dbfilename (const std::string& filename)
{
69 dbfilename = filename;
70 logger << INFO << "DBFILENAME: " << dbfilename << endl;
71
// NOTE(review): original line 72 is missing from this extraction; it
// presumably triggers (re)initialization of the database connector, which
// requires dbfilename to be known — confirm against the full source.
73}
74
75
// Config handler (signature on original line 76, missing from this listing;
// the class declares it as set_database_logger_level). Stores the database
// connector's log level; it is consumed when the connector is (re)opened.
77{
78 dbloglevel = value;
79}
80
81
// Config handler: parse DELTAS, a whitespace-separated list of linear
// interpolation weights (delta_1 ... delta_n). The number of deltas
// implicitly defines the predictor's n-gram cardinality.
82void SmoothedNgramPredictor::set_deltas (const std::string& value)
83{
84 std::stringstream ss_deltas(value);
85 cardinality = 0;
86 std::string delta;
87 while (ss_deltas >> delta) {
88 logger << DEBUG << "Pushing delta: " << delta << endl;
89 deltas.push_back (Utility::toDouble (delta));
// NOTE(review): original line 90 is missing from this extraction; since
// cardinality is reset to 0 above and reported below, it presumably
// increments cardinality once per delta — confirm against the full source.
91 }
92 logger << INFO << "DELTAS: " << value << endl;
93 logger << INFO << "CARDINALITY: " << cardinality << endl;
94
// NOTE(review): original line 95 is missing; presumably re-initializes the
// database connector now that cardinality is known.
96}
97
98
// Config handler: set the minimum occurrence count an n-gram must have to be
// considered (passed through to the database query layer).
99void SmoothedNgramPredictor::set_count_threshold (const std::string& value)
100{
// NOTE(review): original line 101 is missing from this extraction; it
// presumably parses 'value' into the integer member count_threshold
// (e.g. via Utility::toInt) — confirm against the full source.
102 logger << INFO << "COUNT_THRESHOLD: " << count_threshold << endl;
103}
104
105
// Config handler: enable/disable learning from a yes/no config string.
// learn_mode_set records that the value has been dispatched at least once —
// the connector cannot be opened until the read-only/read-write mode is known.
106void SmoothedNgramPredictor::set_learn (const std::string& value)
107{
108 learn_mode = Utility::isYes (value);
109 logger << INFO << "LEARN: " << value << endl;
110
111 learn_mode_set = true;
112
// NOTE(review): original line 113 is missing from this extraction;
// presumably re-initializes the database connector with the new mode.
114}
115
116
// Lazily (re)open the SQLite database connector (function signature on
// original line 117, missing from this listing). Called by the set_*
// handlers whenever a prerequisite configuration value changes.
118{
119 // we can only init the sqlite database connector once we know the
120 // following:
121 // - what database file we need to open
122 // - what cardinality we expect the database file to be
123 // - whether we need to open the database in read only or
124 // read/write mode (learning requires read/write access)
125 //
126 if (! dbfilename.empty()
127 && cardinality > 0
128 && learn_mode_set ) {
129
// Drop any previously opened connector before re-opening.
130 delete db;
131
132 if (dbloglevel.empty ()) {
133 // open database connector
// NOTE(review): original lines 134-135 are missing from this extraction;
// they presumably construct the connector (new SqliteDatabaseConnector with
// dbfilename, cardinality, ...) — line 136 below is its trailing argument.
136 learn_mode);
137 } else {
138 // open database connector with logger level
// NOTE(review): original lines 139-141 are missing; same construction with
// the extra dbloglevel argument — line 142 below is its trailing argument.
142 dbloglevel);
143 }
144 }
145}
146
147
148// convenience function to convert ngram to string
149//
150static std::string ngram_to_string(const Ngram& ngram)
151{
152 const char separator[] = "|";
153 std::string result = separator;
154
155 for (Ngram::const_iterator it = ngram.begin();
156 it != ngram.end();
157 it++)
158 {
159 result += *it + separator;
160 }
161
162 return result;
163}
164
165
181unsigned int SmoothedNgramPredictor::count(const std::vector<std::string>& tokens, int offset, int ngram_size) const
182{
183 unsigned int result = 0;
184
185 assert(offset <= 0); // TODO: handle this better
186 assert(ngram_size >= 0);
187
188 if (ngram_size > 0) {
189 Ngram ngram(ngram_size);
190 copy(tokens.end() - ngram_size + offset , tokens.end() + offset, ngram.begin());
191 result = db->getNgramCount(ngram);
192 logger << DEBUG << "count ngram: " << ngram_to_string (ngram) << " : " << result << endl;
193 } else {
194 result = db->getUnigramCountsSum();
195 logger << DEBUG << "unigram counts sum: " << result << endl;
196 }
197
198 return result;
199}
200
// Generate up to max_partial_prediction_size suggestions for the current
// context. Candidates are gathered from the highest-order n-gram table down,
// then scored by linear interpolation of n-gram relative frequencies using
// the configured 'deltas' weights.
// NOTE(review): doxygen extraction dropped original lines 246, 252, 255, 299
// and 335 from this listing; see the inline notes below.
201Prediction SmoothedNgramPredictor::predict(const size_t max_partial_prediction_size, const char** filter) const
202{
203 logger << DEBUG << "predict()" << endl;
204
205 // Result prediction
206 Prediction prediction;
207
208 // Cache all the needed tokens.
209 // tokens[k] corresponds to w_{i-k} in the generalized smoothed
210 // n-gram probability formula
211 //
212 std::vector<std::string> tokens(cardinality);
213 for (int i = 0; i < cardinality; i++) {
214 tokens[cardinality - 1 - i] = contextTracker->getToken(i);
215 logger << DEBUG << "Cached tokens[" << cardinality - 1 - i << "] = " << tokens[cardinality - 1 - i] << endl;
216 }
217
218 // Generate list of prefix completition candidates.
219 //
220 // The prefix completion candidates used to be obtained from the
221 // _1_gram table because in a well-constructed ngram database the
222 // _1_gram table (which contains all known tokens). However, this
223 // introduced a skew, since the unigram counts will take
224 // precedence over the higher-order counts.
225 //
226 // The current solution retrieves candidates from the highest
227 // n-gram table, falling back on lower order n-gram tables if
228 // initial completion set is smaller than required.
229 //
230 std::vector<std::string> prefixCompletionCandidates;
231 for (size_t k = cardinality; (k > 0 && prefixCompletionCandidates.size() < max_partial_prediction_size); k--) {
232 logger << DEBUG << "Building partial prefix completion table of cardinality: " << k << endl;
233 // create n-gram used to retrieve initial prefix completion table
234 Ngram prefix_ngram(k);
235 copy(tokens.end() - k, tokens.end(), prefix_ngram.begin());
236
237 if (logger.shouldLog()) {
238 logger << DEBUG << "prefix_ngram: ";
239 for (size_t r = 0; r < prefix_ngram.size(); r++) {
240 logger << DEBUG << prefix_ngram[r] << ' ';
241 }
242 logger << DEBUG << endl;
243 }
244
245 // obtain initial prefix completion candidates
247
248 NgramTable partial;
249
// NOTE(review): original line 252 is missing from this extraction — given
// getNgramLikeTable's declared parameters (ngram, filter, count_threshold,
// limit), it presumably passes count_threshold between the two arguments
// visible here; confirm against the full source.
250 partial = db->getNgramLikeTable(prefix_ngram,
251 filter,
253 max_partial_prediction_size - prefixCompletionCandidates.size());
254
256
257 if (logger.shouldLog()) {
258 logger << DEBUG << "partial prefixCompletionCandidates" << endl
259 << DEBUG << "----------------------------------" << endl;
// NOTE(review): the inner loop variable 'k' shadows the outer cardinality
// loop variable 'k' — harmless here (logging only), but fragile.
260 for (size_t j = 0; j < partial.size(); j++) {
261 for (size_t k = 0; k < partial[j].size(); k++) {
262 logger << DEBUG << partial[j][k] << " ";
263 }
264 logger << endl;
265 }
266 }
267
268 logger << DEBUG << "Partial prefix completion table contains " << partial.size() << " potential completions." << endl;
269
270 // append newly discovered potential completions to prefix
271 // completion candidates array to fill it up to
272 // max_partial_prediction_size
273 //
274 std::vector<Ngram>::const_iterator it = partial.begin();
275 while (it != partial.end() && prefixCompletionCandidates.size() < max_partial_prediction_size) {
276 // only add new candidates, iterator it points to Ngram,
277 // it->end() - 2 points to the token candidate
278 //
279 std::string candidate = *(it->end() - 2);
280 if (find(prefixCompletionCandidates.begin(),
281 prefixCompletionCandidates.end(),
282 candidate) == prefixCompletionCandidates.end()) {
283 prefixCompletionCandidates.push_back(candidate);
284 }
285 it++;
286 }
287 }
288
289 if (logger.shouldLog()) {
290 logger << DEBUG << "prefixCompletionCandidates" << endl
291 << DEBUG << "--------------------------" << endl;
292 for (size_t j = 0; j < prefixCompletionCandidates.size(); j++) {
293 logger << DEBUG << prefixCompletionCandidates[j] << endl;
294 }
295 }
296
297 // compute smoothed probabilities for all candidates
298 //
300 // getUnigramCountsSum is an expensive SQL query
301 // caching it here saves much time later inside the loop
302 int unigrams_counts_sum = db->getUnigramCountsSum();
303 for (size_t j = 0; (j < prefixCompletionCandidates.size() && j < max_partial_prediction_size); j++) {
304 // store w_i candidate at end of tokens
305 tokens[cardinality - 1] = prefixCompletionCandidates[j];
306
307 logger << DEBUG << "------------------" << endl;
308 logger << DEBUG << "w_i: " << tokens[cardinality - 1] << endl;
309
// probability(w_i) = sum over k of deltas[k] * P_k, where P_k is the
// relative frequency of the (k+1)-gram ending in the candidate.
310 double probability = 0;
311 for (int k = 0; k < cardinality; k++) {
312 double numerator = count(tokens, 0, k+1);
313 // reuse cached unigrams_counts_sum to speed things up
314 double denominator = (k == 0 ? unigrams_counts_sum : count(tokens, -1, k));
315 double frequency = ((denominator > 0) ? (numerator / denominator) : 0);
316 probability += deltas[k] * frequency;
317
318 logger << DEBUG << "numerator: " << numerator << endl;
319 logger << DEBUG << "denominator: " << denominator << endl;
320 logger << DEBUG << "frequency: " << frequency << endl;
321 logger << DEBUG << "delta: " << deltas[k] << endl;
322
323 // for some sanity checks
324 assert(numerator <= denominator);
325 assert(frequency <= 1);
326 }
327
328 logger << DEBUG << "____________" << endl;
329 logger << DEBUG << "probability: " << probability << endl;
330
331 if (probability > 0) {
332 prediction.addSuggestion(Suggestion(tokens[cardinality - 1], probability));
333 }
334 }
336
337 logger << DEBUG << "Prediction:" << endl;
338 logger << DEBUG << "-----------" << endl;
339 logger << DEBUG << prediction << endl;
340
341 return prediction;
342}
343
// Learn the tokens in 'change': accumulate counts for every n-gram of order
// 1..cardinality found in the change (plus boundary n-grams reaching back
// into the past stream), then write the updated counts to the database.
// NOTE(review): doxygen extraction dropped original lines 436, 450, 459, 462
// and 464 from this listing; see the inline notes below.
344void SmoothedNgramPredictor::learn(const std::vector<std::string>& change)
345{
346 logger << INFO << "learn(\"" << ngram_to_string(change) << "\")" << endl;
347
348 if (learn_mode) {
349 // learning is turned on
350
// In-memory accumulator: maps each observed n-gram (as a token list) to
// the number of times it occurs in this change.
351 std::map<std::list<std::string>, int> ngramMap;
352
353 // build up ngram map for all cardinalities
354 // i.e. learn all ngrams and counts in memory
// NOTE(review): 'curr_cardinality' is size_t while 'cardinality' is a
// signed member — the comparison mixes signedness; harmless while
// cardinality >= 0, but worth noting.
355 for (size_t curr_cardinality = 1;
356 curr_cardinality < cardinality + 1;
357 curr_cardinality++)
358 {
359 int change_idx = 0;
360 int change_size = change.size();
361
362 std::list<std::string> ngram_list;
363
364 // take care of first N-1 tokens
365 for (int i = 0;
366 (i < curr_cardinality - 1 && change_idx < change_size);
367 i++)
368 {
369 ngram_list.push_back(change[change_idx]);
370 change_idx++;
371 }
372
// Slide an N-token window over the rest of the change, bumping the count
// for each window position.
373 while (change_idx < change_size)
374 {
375 ngram_list.push_back(change[change_idx++]);
376 ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
377 ngram_list.pop_front();
378 }
379 }
380
381 // use (past stream - change) to learn token at the boundary
382 // change, i.e.
383 //
384
385 // if change is "bar foobar", then "bar" will only occur in a
386 // 1-gram, since there are no token before it. By dipping in
387 // the past stream, we additional context to learn a 2-gram by
388 // getting extra tokens (assuming past stream ends with token
389 // "foo":
390 //
391 // <"foo", "bar"> will be learnt
392 //
393 // We do this till we build up to n equal to cardinality.
394 //
395 // First check that change is not empty (nothing to learn) and
396 // that change and past stream match by sampling first and
397 // last token in change and comparing them with corresponding
398 // tokens from past stream
399 //
400 if (change.size() > 0 &&
401 change.back() == contextTracker->getToken(1) &&
402 change.front() == contextTracker->getToken(change.size()))
403 {
404 // create ngram list with first (oldest) token from change
405 std::list<std::string> ngram_list(change.begin(), change.begin() + 1);
406
407 // prepend token to ngram list by grabbing extra tokens
408 // from past stream (if there are any) till we have built
409 // up to n==cardinality ngrams, and commit them to
410 // ngramMap
411 //
412 for (int tk_idx = 1;
413 ngram_list.size() < cardinality;
414 tk_idx++)
415 {
416 // getExtraTokenToLearn returns tokens from
417 // past stream that come before and are not in
418 // change vector
419 //
420 std::string extra_token = contextTracker->getExtraTokenToLearn(tk_idx, change);
421 logger << DEBUG << "Adding extra token: " << extra_token << endl;
422
423 if (extra_token.empty())
424 {
425 break;
426 }
427 ngram_list.push_front(extra_token);
428
429 ngramMap[ngram_list] = ngramMap[ngram_list] + 1;
430 }
431 }
432
433 // then write out to language model database
434 try
435 {
// NOTE(review): original line 436 is missing from this extraction; given the
// "Committed" log below and the declared beginTransaction/endTransaction
// API, it presumably opens a database transaction here — confirm.
437
438 std::map<std::list<std::string>, int>::const_iterator it;
439 for (it = ngramMap.begin(); it != ngramMap.end(); it++)
440 {
441 // convert ngram from list to vector based Ngram
442 Ngram ngram((it->first).begin(), (it->first).end());
443
444 // update the counts
445 int count = db->getNgramCount(ngram);
446 if (count > 0)
447 {
448 // ngram already in database, update count
449 db->updateNgram(ngram, count + it->second);
// NOTE(review): original line 450 is missing; check_learn_consistency's own
// comment says it runs inside learn()'s transaction, so line 450 presumably
// invokes check_learn_consistency(ngram) after the update — confirm.
451 }
452 else
453 {
454 // ngram not in database, insert it
455 db->insertNgram(ngram, it->second);
456 }
457 }
458
// NOTE(review): original line 459 is missing; presumably commits the
// transaction (endTransaction) before the log message below.
460 logger << INFO << "Committed learning update to database" << endl;
461 }
// NOTE(review): original line 462 (the catch clause binding 'ex') and line
// 464 (presumably rollbackTransaction) are missing from this extraction.
463 {
465 logger << ERROR << "Rolling back learning update : " << ex.what() << endl;
466 throw;
467 }
468 }
469
470 logger << DEBUG << "end learn()" << endl;
471}
472
// check_learn_consistency (signature on original line 473, missing from this
// listing; declared as taking a const Ngram& and const). Ensures the n-gram
// count invariant count(w_1..w_n) <= count(w_1..w_{n-1}) holds after a
// learning update by incrementing any lower-order prefix count that fell
// behind its higher-order extension.
474{
475 // no need to begin a new transaction, as we'll be called from
476 // within an existing transaction from learn()
477
478 // BEWARE: if the previous sentence is not true, then performance
479 // WILL suffer!
480
481 size_t size = ngram.size();
// NOTE(review): '-i' negates an unsigned size_t before it reaches count()'s
// int parameter — correct for the small values used here, but fragile.
482 for (size_t i = 0; i < size; i++) {
483 if (count(ngram, -i, size - i) > count(ngram, -(i + 1), size - (i + 1))) {
484 logger << INFO << "consistency adjustment needed!" << endl;
485
486 int offset = -(i + 1);
487 int sub_ngram_size = size - (i + 1);
488
489 logger << DEBUG << "i: " << i << " | offset: " << offset << " | sub_ngram_size: " << sub_ngram_size << endl;
490
491 Ngram sub_ngram(sub_ngram_size); // need to init to right size for sub_ngram
492 copy(ngram.end() - sub_ngram_size + offset, ngram.end() + offset, sub_ngram.begin());
493
494 if (logger.shouldLog()) {
495 logger << "ngram to be count adjusted is: ";
// NOTE(review): this inner 'i' shadows the outer loop variable 'i' —
// harmless (logging only) but easy to misread.
496 for (size_t i = 0; i < sub_ngram.size(); i++) {
497 logger << sub_ngram[i] << ' ';
498 }
499 logger << endl;
500 }
501
502 db->incrementNgramCount(sub_ngram);
503 logger << DEBUG << "consistency adjusted" << endl;
504 }
505 }
506}
507
// Observer callback (signature on original line 508, missing from this
// listing; declared as update(const Observable*)). Routes a changed
// configuration variable to the matching set_* handler via the dispatcher.
509{
510 logger << DEBUG << "About to invoke dispatcher: " << var->get_name () << " - " << var->get_value() << endl;
511 dispatcher.dispatch (var);
512}
Tracks user interaction and context.
std::string getExtraTokenToLearn(const int index, const std::vector< std::string > &change) const
std::string getToken(const int) const
virtual void endTransaction() const
virtual void beginTransaction() const
virtual void rollbackTransaction() const
int incrementNgramCount(const Ngram ngram) const
void insertNgram(const Ngram ngram, const int count) const
int getUnigramCountsSum() const
int getNgramCount(const Ngram ngram) const
void updateNgram(const Ngram ngram, const int count) const
NgramTable getNgramLikeTable(const Ngram ngram, const char **filter, const int count_threshold, int limit=-1) const
Definition: ngram.h:33
virtual std::string get_name() const =0
virtual std::string get_value() const =0
void addSuggestion(Suggestion)
Definition: prediction.cpp:90
ContextTracker * contextTracker
Definition: predictor.h:83
const std::string PREDICTORS
Definition: predictor.h:81
virtual void set_logger(const std::string &level)
Definition: predictor.cpp:88
Logger< char > logger
Definition: predictor.h:87
const std::string name
Definition: predictor.h:77
virtual const char * what() const
void check_learn_consistency(const Ngram &name) const
Dispatcher< SmoothedNgramPredictor > dispatcher
std::vector< double > deltas
void set_database_logger_level(const std::string &level)
virtual void learn(const std::vector< std::string > &change)
unsigned int count(const std::vector< std::string > &tokens, int offset, int ngram_size) const
Builds the required n-gram and returns its count.
virtual void update(const Observable *variable)
void set_dbfilename(const std::string &filename)
void set_learn(const std::string &learn_mode)
SmoothedNgramPredictor(Configuration *, ContextTracker *, const char *)
virtual Prediction predict(const size_t size, const char **filter) const
Generate prediction.
void set_deltas(const std::string &deltas)
void set_count_threshold(const std::string &value)
static double toDouble(const std::string)
Definition: utility.cpp:258
static bool isYes(const char *)
Definition: utility.cpp:185
static int toInt(const std::string)
Definition: utility.cpp:266
std::vector< Ngram > NgramTable
const Logger< _charT, _Traits > & endl(const Logger< _charT, _Traits > &lgr)
Definition: logger.h:278
std::string config
Definition: presageDemo.cpp:70
static std::string ngram_to_string(const Ngram &ngram)