SphinxBase 5prealpha
lm_trie.c
1/* -*- c-basic-offset: 4; indent-tabs-mode: nil -*- */
2/* ====================================================================
3 * Copyright (c) 2015 Carnegie Mellon University. All rights
4 * reserved.
5 *
6 * Redistribution and use in source and binary forms, with or without
7 * modification, are permitted provided that the following conditions
8 * are met:
9 *
10 * 1. Redistributions of source code must retain the above copyright
11 * notice, this list of conditions and the following disclaimer.
12 *
13 * 2. Redistributions in binary form must reproduce the above copyright
14 * notice, this list of conditions and the following disclaimer in
15 * the documentation and/or other materials provided with the
16 * distribution.
17 *
18 * This work was supported in part by funding from the Defense Advanced
19 * Research Projects Agency and the National Science Foundation of the
20 * United States of America, and the CMU Sphinx Speech Consortium.
21 *
22 * THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
23 * ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
24 * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
25 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
26 * NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33 *
34 * ====================================================================
35 *
36 */
37
38#include <string.h>
39#include <stdio.h>
40#include <assert.h>
41
44#include <sphinxbase/err.h>
45#include <sphinxbase/priority_queue.h>
46
47#include "lm_trie.h"
48#include "lm_trie_quant.h"
49
50static void lm_trie_alloc_ngram(lm_trie_t * trie, uint32 * counts, int order);
51
52static uint32
53base_size(uint32 entries, uint32 max_vocab, uint8 remaining_bits)
54{
55 uint8 total_bits = bitarr_required_bits(max_vocab) + remaining_bits;
56 /* Extra entry for next pointer at the end.
57 * +7 then / 8 to round up bits and convert to bytes
58 * +sizeof(uint64) so that ReadInt57 etc don't go segfault.
59 * Note that this waste is O(order), not O(number of ngrams).*/
60 return ((1 + entries) * total_bits + 7) / 8 + sizeof(uint64);
61}
62
63uint32
64middle_size(uint8 quant_bits, uint32 entries, uint32 max_vocab,
65 uint32 max_ptr)
66{
67 return base_size(entries, max_vocab,
68 quant_bits + bitarr_required_bits(max_ptr));
69}
70
71uint32
72longest_size(uint8 quant_bits, uint32 entries, uint32 max_vocab)
73{
74 return base_size(entries, max_vocab, quant_bits);
75}
76
77static void
78base_init(base_t * base, void *base_mem, uint32 max_vocab,
79 uint8 remaining_bits)
80{
81 base->word_bits = bitarr_required_bits(max_vocab);
82 base->word_mask = (1U << base->word_bits) - 1U;
83 if (base->word_bits > 25)
85 ("Sorry, word indices more than %d are not implemented. Edit util/bit_packing.hh and fix the bit packing functions\n",
86 (1U << 25));
87 base->total_bits = base->word_bits + remaining_bits;
88
89 base->base = (uint8 *) base_mem;
90 base->insert_index = 0;
91 base->max_vocab = max_vocab;
92}
93
94void
95middle_init(middle_t * middle, void *base_mem, uint8 quant_bits,
96 uint32 entries, uint32 max_vocab, uint32 max_next,
97 void *next_source)
98{
99 middle->quant_bits = quant_bits;
100 bitarr_mask_from_max(&middle->next_mask, max_next);
101 middle->next_source = next_source;
102 if (entries + 1 >= (1U << 25) || (max_next >= (1U << 25)))
103 E_ERROR
104 ("Sorry, this does not support more than %d n-grams of a particular order. Edit util/bit_packing.hh and fix the bit packing functions\n",
105 (1U << 25));
106 base_init(&middle->base, base_mem, max_vocab,
107 quant_bits + middle->next_mask.bits);
108}
109
110void
111longest_init(longest_t * longest, void *base_mem, uint8 quant_bits,
112 uint32 max_vocab)
113{
114 base_init(&longest->base, base_mem, max_vocab, quant_bits);
115}
116
117static bitarr_address_t
118middle_insert(middle_t * middle, uint32 word, int order, int max_order)
119{
120 uint32 at_pointer;
121 uint32 next;
122 bitarr_address_t address;
123 assert(word <= middle->base.word_mask);
124 address.base = middle->base.base;
125 address.offset = middle->base.insert_index * middle->base.total_bits;
126 bitarr_write_int25(address, middle->base.word_bits, word);
127 address.offset += middle->base.word_bits;
128 at_pointer = address.offset;
129 address.offset += middle->quant_bits;
130 if (order == max_order - 1) {
131 next = ((longest_t *) middle->next_source)->base.insert_index;
132 }
133 else {
134 next = ((middle_t *) middle->next_source)->base.insert_index;
135 }
136
137 bitarr_write_int25(address, middle->next_mask.bits, next);
138 middle->base.insert_index++;
139 address.offset = at_pointer;
140 return address;
141}
142
143static bitarr_address_t
144longest_insert(longest_t * longest, uint32 index)
145{
146 bitarr_address_t address;
147 assert(index <= longest->base.word_mask);
148 address.base = longest->base.base;
149 address.offset = longest->base.insert_index * longest->base.total_bits;
150 bitarr_write_int25(address, longest->base.word_bits, index);
151 address.offset += longest->base.word_bits;
152 longest->base.insert_index++;
153 return address;
154}
155
156static void
157middle_finish_loading(middle_t * middle, uint32 next_end)
158{
159 bitarr_address_t address;
160 address.base = middle->base.base;
161 address.offset =
162 (middle->base.insert_index + 1) * middle->base.total_bits -
163 middle->next_mask.bits;
164 bitarr_write_int25(address, middle->next_mask.bits, next_end);
165}
166
167static uint32
168unigram_next(lm_trie_t * trie, int order)
169{
170 return order ==
171 2 ? trie->longest->base.insert_index : trie->middle_begin->base.
172 insert_index;
173}
174
175void
176lm_trie_fix_counts(ngram_raw_t ** raw_ngrams, uint32 * counts,
177 uint32 * fixed_counts, int order)
178{
179 priority_queue_t *ngrams =
180 priority_queue_create(order - 1, &ngram_ord_comparator);
181 uint32 raw_ngram_ptrs[NGRAM_MAX_ORDER - 1];
182 uint32 words[NGRAM_MAX_ORDER];
183 int i;
184
185 memset(words, -1, sizeof(words));
186 memcpy(fixed_counts, counts, order * sizeof(*fixed_counts));
187 for (i = 2; i <= order; i++) {
188 ngram_raw_t *tmp_ngram;
189
190 if (counts[i - 1] <= 0)
191 continue;
192
193 raw_ngram_ptrs[i - 2] = 0;
194
195 tmp_ngram =
196 (ngram_raw_t *) ckd_calloc(1, sizeof(*tmp_ngram));
197 *tmp_ngram = raw_ngrams[i - 2][0];
198 tmp_ngram->order = i;
199 priority_queue_add(ngrams, tmp_ngram);
200 }
201
202 for (;;) {
203 int32 to_increment = TRUE;
204 ngram_raw_t *top;
205 if (priority_queue_size(ngrams) == 0) {
206 break;
207 }
208 top = (ngram_raw_t *) priority_queue_poll(ngrams);
209 if (top->order == 2) {
210 memcpy(words, top->words, 2 * sizeof(*words));
211 }
212 else {
213 for (i = 0; i < top->order - 1; i++) {
214 if (words[i] != top->words[i]) {
215 int num;
216 num = (i == 0) ? 1 : i;
217 memcpy(words, top->words,
218 (num + 1) * sizeof(*words));
219 fixed_counts[num]++;
220 to_increment = FALSE;
221 break;
222 }
223 }
224 words[top->order - 1] = top->words[top->order - 1];
225 }
226 if (to_increment) {
227 raw_ngram_ptrs[top->order - 2]++;
228 }
229 if (raw_ngram_ptrs[top->order - 2] < counts[top->order - 1]) {
230 *top = raw_ngrams[top->order - 2][raw_ngram_ptrs[top->order - 2]];
231 priority_queue_add(ngrams, top);
232 }
233 else {
234 ckd_free(top);
235 }
236 }
237
238 assert(priority_queue_size(ngrams) == 0);
239 priority_queue_free(ngrams, NULL);
240}
241
242
243static void
244recursive_insert(lm_trie_t * trie, ngram_raw_t ** raw_ngrams,
245 uint32 * counts, int order)
246{
247 uint32 unigram_idx = 0;
248 uint32 *words;
249 float *probs;
250 const uint32 unigram_count = (uint32) counts[0];
251 priority_queue_t *ngrams =
252 priority_queue_create(order, &ngram_ord_comparator);
253 ngram_raw_t *ngram;
254 uint32 *raw_ngrams_ptr;
255 int i;
256
257 words = (uint32 *) ckd_calloc(order, sizeof(*words));
258 probs = (float *) ckd_calloc(order - 1, sizeof(*probs));
259 ngram = (ngram_raw_t *) ckd_calloc(1, sizeof(*ngram));
260 ngram->order = 1;
261 ngram->words = &unigram_idx;
262 priority_queue_add(ngrams, ngram);
263 raw_ngrams_ptr =
264 (uint32 *) ckd_calloc(order - 1, sizeof(*raw_ngrams_ptr));
265 for (i = 2; i <= order; ++i) {
266 ngram_raw_t *tmp_ngram;
267
268 if (counts[i - 1] <= 0)
269 continue;
270
271 raw_ngrams_ptr[i - 2] = 0;
272 tmp_ngram =
273 (ngram_raw_t *) ckd_calloc(1, sizeof(*tmp_ngram));
274 *tmp_ngram = raw_ngrams[i - 2][0];
275 tmp_ngram->order = i;
276
277 priority_queue_add(ngrams, tmp_ngram);
278 }
279
280 for (;;) {
281 ngram_raw_t *top =
282 (ngram_raw_t *) priority_queue_poll(ngrams);
283
284 if (top->order == 1) {
285 trie->unigrams[unigram_idx].next = unigram_next(trie, order);
286 words[0] = unigram_idx;
287 probs[0] = trie->unigrams[unigram_idx].prob;
288 if (++unigram_idx == unigram_count + 1) {
289 ckd_free(top);
290 break;
291 }
292 priority_queue_add(ngrams, top);
293 }
294 else {
295 for (i = 0; i < top->order - 1; i++) {
296 if (words[i] != top->words[i]) {
297 /* need to insert dummy suffixes to make ngram of higher order reachable */
298 int j;
299 assert(i > 0); /* unigrams are not pruned without removing ngrams that contains them */
300 for (j = i; j < top->order - 1; j++) {
301 middle_t *middle = &trie->middle_begin[j - 1];
302 bitarr_address_t address =
303 middle_insert(middle, top->words[j],
304 j + 1, order);
305 /* calculate prob for blank */
306 float calc_prob =
307 probs[j - 1] +
308 trie->unigrams[top->words[j]].bo;
309 probs[j] = calc_prob;
310 lm_trie_quant_mwrite(trie->quant, address, j - 1,
311 calc_prob, 0.0f);
312 }
313 }
314 }
315 memcpy(words, top->words,
316 top->order * sizeof(*words));
317 if (top->order == order) {
318 bitarr_address_t address =
319 longest_insert(trie->longest,
320 top->words[top->order - 1]);
321 lm_trie_quant_lwrite(trie->quant, address, top->prob);
322 }
323 else {
324 middle_t *middle = &trie->middle_begin[top->order - 2];
325 bitarr_address_t address =
326 middle_insert(middle,
327 top->words[top->order - 1],
328 top->order, order);
329 /* write prob and backoff */
330 probs[top->order - 1] = top->prob;
331 lm_trie_quant_mwrite(trie->quant, address, top->order - 2,
332 top->prob, top->backoff);
333 }
334 raw_ngrams_ptr[top->order - 2]++;
335 if (raw_ngrams_ptr[top->order - 2] < counts[top->order - 1]) {
336 *top = raw_ngrams[top->order -
337 2][raw_ngrams_ptr[top->order - 2]];
338
339 priority_queue_add(ngrams, top);
340 }
341 else {
342 ckd_free(top);
343 }
344 }
345 }
346 assert(priority_queue_size(ngrams) == 0);
347 priority_queue_free(ngrams, NULL);
348 ckd_free(raw_ngrams_ptr);
349 ckd_free(words);
350 ckd_free(probs);
351}
352
353static lm_trie_t *
354lm_trie_init(uint32 unigram_count)
355{
356 lm_trie_t *trie;
357
358 trie = (lm_trie_t *) ckd_calloc(1, sizeof(*trie));
359 memset(trie->hist_cache, -1, sizeof(trie->hist_cache)); /* prepare request history */
360 memset(trie->backoff_cache, 0, sizeof(trie->backoff_cache));
361 trie->unigrams =
362 (unigram_t *) ckd_calloc((unigram_count + 1),
363 sizeof(*trie->unigrams));
364 trie->ngram_mem = NULL;
365 return trie;
366}
367
368lm_trie_t *
369lm_trie_create(uint32 unigram_count, int order)
370{
371 lm_trie_t *trie = lm_trie_init(unigram_count);
372 trie->quant =
373 (order > 1) ? lm_trie_quant_create(order) : 0;
374 return trie;
375}
376
377lm_trie_t *
378lm_trie_read_bin(uint32 * counts, int order, FILE * fp)
379{
380 lm_trie_t *trie = lm_trie_init(counts[0]);
381 trie->quant = (order > 1) ? lm_trie_quant_read_bin(fp, order) : NULL;
382 fread(trie->unigrams, sizeof(*trie->unigrams), (counts[0] + 1), fp);
383 if (order > 1) {
384 lm_trie_alloc_ngram(trie, counts, order);
385 fread(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
386 }
387 return trie;
388}
389
390void
391lm_trie_write_bin(lm_trie_t * trie, uint32 unigram_count, FILE * fp)
392{
393
394 if (trie->quant)
395 lm_trie_quant_write_bin(trie->quant, fp);
396 fwrite(trie->unigrams, sizeof(*trie->unigrams), (unigram_count + 1),
397 fp);
398 if (trie->ngram_mem)
399 fwrite(trie->ngram_mem, 1, trie->ngram_mem_size, fp);
400}
401
402void
403lm_trie_free(lm_trie_t * trie)
404{
405 if (trie->ngram_mem) {
406 ckd_free(trie->ngram_mem);
407 ckd_free(trie->middle_begin);
408 ckd_free(trie->longest);
409 }
410 if (trie->quant)
411 lm_trie_quant_free(trie->quant);
412 ckd_free(trie->unigrams);
413 ckd_free(trie);
414}
415
416static void
417lm_trie_alloc_ngram(lm_trie_t * trie, uint32 * counts, int order)
418{
419 int i;
420 uint8 *mem_ptr;
421 uint8 **middle_starts;
422
423 trie->ngram_mem_size = 0;
424 for (i = 1; i < order - 1; i++) {
425 trie->ngram_mem_size +=
426 middle_size(lm_trie_quant_msize(trie->quant), counts[i],
427 counts[0], counts[i + 1]);
428 }
429 trie->ngram_mem_size +=
430 longest_size(lm_trie_quant_lsize(trie->quant), counts[order - 1],
431 counts[0]);
432 trie->ngram_mem =
433 (uint8 *) ckd_calloc(trie->ngram_mem_size,
434 sizeof(*trie->ngram_mem));
435 mem_ptr = trie->ngram_mem;
436 trie->middle_begin =
437 (middle_t *) ckd_calloc(order - 2, sizeof(*trie->middle_begin));
438 trie->middle_end = trie->middle_begin + (order - 2);
439 middle_starts =
440 (uint8 **) ckd_calloc(order - 2, sizeof(*middle_starts));
441 for (i = 2; i < order; i++) {
442 middle_starts[i - 2] = mem_ptr;
443 mem_ptr +=
444 middle_size(lm_trie_quant_msize(trie->quant), counts[i - 1],
445 counts[0], counts[i]);
446 }
447 trie->longest = (longest_t *) ckd_calloc(1, sizeof(*trie->longest));
448 /* Crazy backwards thing so we initialize using pointers to ones that have already been initialized */
449 for (i = order - 1; i >= 2; --i) {
450 middle_t *middle_ptr = &trie->middle_begin[i - 2];
451 middle_init(middle_ptr, middle_starts[i - 2],
452 lm_trie_quant_msize(trie->quant), counts[i - 1],
453 counts[0], counts[i],
454 (i ==
455 order -
456 1) ? (void *) trie->longest : (void *) &trie->
457 middle_begin[i - 1]);
458 }
459 ckd_free(middle_starts);
460 longest_init(trie->longest, mem_ptr, lm_trie_quant_lsize(trie->quant),
461 counts[0]);
462}
463
464void
465lm_trie_build(lm_trie_t * trie, ngram_raw_t ** raw_ngrams, uint32 * counts, uint32 *out_counts,
466 int order)
467{
468 int i;
469
470 lm_trie_fix_counts(raw_ngrams, counts, out_counts, order);
471 lm_trie_alloc_ngram(trie, out_counts, order);
472
473 if (order > 1)
474 E_INFO("Training quantizer\n");
475 for (i = 2; i < order; i++) {
476 lm_trie_quant_train(trie->quant, i, counts[i - 1],
477 raw_ngrams[i - 2]);
478 }
479 lm_trie_quant_train_prob(trie->quant, order, counts[order - 1],
480 raw_ngrams[order - 2]);
481
482 E_INFO("Building LM trie\n");
483 recursive_insert(trie, raw_ngrams, counts, order);
484 /* Set ending offsets so the last entry will be sized properly */
485 /* Last entry for unigrams was already set. */
486 if (trie->middle_begin != trie->middle_end) {
487 middle_t *middle_ptr;
488 for (middle_ptr = trie->middle_begin;
489 middle_ptr != trie->middle_end - 1; ++middle_ptr) {
490 middle_t *next_middle_ptr = middle_ptr + 1;
491 middle_finish_loading(middle_ptr,
492 next_middle_ptr->base.insert_index);
493 }
494 middle_ptr = trie->middle_end - 1;
495 middle_finish_loading(middle_ptr,
496 trie->longest->base.insert_index);
497 }
498}
499
500unigram_t *
501unigram_find(unigram_t * u, uint32 word, node_range_t * next)
502{
503 unigram_t *ptr = &u[word];
504 next->begin = ptr->next;
505 next->end = (ptr + 1)->next;
506 return ptr;
507}
508
509static size_t
510calc_pivot(uint32 off, uint32 range, uint32 width)
511{
512 return (size_t) ((off * width) / (range + 1));
513}
514
515static uint8
516uniform_find(void *base, uint8 total_bits, uint8 key_bits, uint32 key_mask,
517 uint32 before_it, uint32 before_v,
518 uint32 after_it, uint32 after_v, uint32 key, uint32 * out)
519{
520 bitarr_address_t address;
521 address.base = base;
522 while (after_it - before_it > 1) {
523 uint32 mid;
524 uint32 pivot =
525 before_it + (1 +
526 calc_pivot(key - before_v, after_v - before_v,
527 after_it - before_it - 1));
528 /* access by pivot */
529 address.offset = pivot * (uint32) total_bits;
530 mid = bitarr_read_int25(address, key_bits, key_mask);
531 if (mid < key) {
532 before_it = pivot;
533 before_v = mid;
534 }
535 else if (mid > key) {
536 after_it = pivot;
537 after_v = mid;
538 }
539 else {
540 *out = pivot;
541 return TRUE;
542 }
543 }
544 return FALSE;
545}
546
547static bitarr_address_t
548middle_find(middle_t * middle, uint32 word, node_range_t * range)
549{
550 uint32 at_pointer;
551 bitarr_address_t address;
552
553 /* finding BitPacked with uniform find */
554 if (!uniform_find
555 ((void *) middle->base.base, middle->base.total_bits,
556 middle->base.word_bits, middle->base.word_mask, range->begin - 1,
557 0, range->end, middle->base.max_vocab, word, &at_pointer)) {
558 address.base = NULL;
559 address.offset = 0;
560 return address;
561 }
562
563 address.base = middle->base.base;
564 at_pointer *= middle->base.total_bits;
565 at_pointer += middle->base.word_bits;
566 address.offset = at_pointer + middle->quant_bits;
567 range->begin =
568 bitarr_read_int25(address, middle->next_mask.bits,
569 middle->next_mask.mask);
570 address.offset += middle->base.total_bits;
571 range->end =
572 bitarr_read_int25(address, middle->next_mask.bits,
573 middle->next_mask.mask);
574 address.offset = at_pointer;
575
576 return address;
577}
578
579static bitarr_address_t
580longest_find(longest_t * longest, uint32 word, node_range_t * range)
581{
582 uint32 at_pointer;
583 bitarr_address_t address;
584
585 /* finding BitPacked with uniform find */
586 if (!uniform_find
587 ((void *) longest->base.base, longest->base.total_bits,
588 longest->base.word_bits, longest->base.word_mask,
589 range->begin - 1, 0, range->end, longest->base.max_vocab, word,
590 &at_pointer)) {
591 address.base = NULL;
592 address.offset = 0;
593 return address;
594 }
595 address.base = longest->base.base;
596 address.offset =
597 at_pointer * longest->base.total_bits + longest->base.word_bits;
598 return address;
599}
600
601static float
602get_available_prob(lm_trie_t * trie, int32 wid, int32 * hist,
603 int max_order, int32 n_hist, int32 * n_used)
604{
605 float prob;
606 node_range_t node;
607 bitarr_address_t address;
608 int order_minus_2;
609 uint8 independent_left;
610 int32 *hist_iter, *hist_end;
611
612 *n_used = 1;
613 prob = unigram_find(trie->unigrams, wid, &node)->prob;
614 if (n_hist == 0) {
615 return prob;
616 }
617
618 /* find ngrams of higher order if any */
619 order_minus_2 = 0;
620 independent_left = (node.begin == node.end);
621 hist_iter = hist;
622 hist_end = hist + n_hist;
623 for (;; order_minus_2++, hist_iter++) {
624 if (hist_iter == hist_end)
625 return prob;
626 if (independent_left)
627 return prob;
628 if (order_minus_2 == max_order - 2)
629 break;
630
631 address =
632 middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
633 &node);
634 independent_left = (address.base == NULL)
635 || (node.begin == node.end);
636
637 /* didn't find entry */
638 if (address.base == NULL)
639 return prob;
640 prob = lm_trie_quant_mpread(trie->quant, address, order_minus_2);
641 *n_used = order_minus_2 + 2;
642 }
643
644 address = longest_find(trie->longest, *hist_iter, &node);
645 if (address.base != NULL) {
646 prob = lm_trie_quant_lpread(trie->quant, address);
647 *n_used = max_order;
648 }
649 return prob;
650}
651
652static float
653get_available_backoff(lm_trie_t * trie, int32 start, int32 * hist,
654 int32 n_hist)
655{
656 float backoff = 0.0f;
657 int order_minus_2;
658 int32 *hist_iter;
659 node_range_t node;
660 unigram_t *first_hist = unigram_find(trie->unigrams, hist[0], &node);
661 if (start <= 1) {
662 backoff += first_hist->bo;
663 start = 2;
664 }
665 order_minus_2 = start - 2;
666 for (hist_iter = hist + start - 1; hist_iter < hist + n_hist;
667 hist_iter++, order_minus_2++) {
668 bitarr_address_t address =
669 middle_find(&trie->middle_begin[order_minus_2], *hist_iter,
670 &node);
671 if (address.base == NULL)
672 break;
673 backoff +=
674 lm_trie_quant_mboread(trie->quant, address, order_minus_2);
675 }
676 return backoff;
677}
678
679static float
680lm_trie_nobo_score(lm_trie_t * trie, int32 wid, int32 * hist,
681 int max_order, int32 n_hist, int32 * n_used)
682{
683 float prob =
684 get_available_prob(trie, wid, hist, max_order, n_hist, n_used);
685 if (n_hist < *n_used)
686 return prob;
687 return prob + get_available_backoff(trie, *n_used, hist, n_hist);
688}
689
690static float
691lm_trie_hist_score(lm_trie_t * trie, int32 wid, int32 * hist, int32 n_hist,
692 int32 * n_used)
693{
694 float prob;
695 int i, j;
696 node_range_t node;
697 bitarr_address_t address;
698
699 *n_used = 1;
700 prob = unigram_find(trie->unigrams, wid, &node)->prob;
701 if (n_hist == 0)
702 return prob;
703 for (i = 0; i < n_hist - 1; i++) {
704 address = middle_find(&trie->middle_begin[i], hist[i], &node);
705 if (address.base == NULL) {
706 for (j = i; j < n_hist; j++) {
707 prob += trie->backoff_cache[j];
708 }
709 return prob;
710 }
711 else {
712 (*n_used)++;
713 prob = lm_trie_quant_mpread(trie->quant, address, i);
714 }
715 }
716 address = longest_find(trie->longest, hist[n_hist - 1], &node);
717 if (address.base == NULL) {
718 return prob + trie->backoff_cache[n_hist - 1];
719 }
720 else {
721 (*n_used)++;
722 return lm_trie_quant_lpread(trie->quant, address);
723 }
724}
725
726static uint8
727history_matches(int32 * hist, int32 * prev_hist, int32 n_hist)
728{
729 int i;
730 for (i = 0; i < n_hist; i++) {
731 if (hist[i] != prev_hist[i]) {
732 return FALSE;
733 }
734 }
735 return TRUE;
736}
737
738static void
739update_backoff(lm_trie_t * trie, int32 * hist, int32 n_hist)
740{
741 int i;
742 node_range_t node;
743 bitarr_address_t address;
744
745 memset(trie->backoff_cache, 0, sizeof(trie->backoff_cache));
746 trie->backoff_cache[0] = unigram_find(trie->unigrams, hist[0], &node)->bo;
747 for (i = 1; i < n_hist; i++) {
748 address = middle_find(&trie->middle_begin[i - 1], hist[i], &node);
749 if (address.base == NULL) {
750 break;
751 }
752 trie->backoff_cache[i] =
753 lm_trie_quant_mboread(trie->quant, address, i - 1);
754 }
755 memcpy(trie->hist_cache, hist, n_hist * sizeof(*hist));
756}
757
758float
759lm_trie_score(lm_trie_t * trie, int order, int32 wid, int32 * hist,
760 int32 n_hist, int32 * n_used)
761{
762 if (n_hist < order - 1) {
763 return lm_trie_nobo_score(trie, wid, hist, order, n_hist, n_used);
764 }
765 else {
766 assert(n_hist == order - 1);
767 if (!history_matches(hist, (int32 *) trie->hist_cache, n_hist)) {
768 update_backoff(trie, hist, n_hist);
769 }
770 return lm_trie_hist_score(trie, wid, hist, n_hist, n_used);
771 }
772}
773
774void
775lm_trie_fill_raw_ngram(lm_trie_t * trie,
776 ngram_raw_t * raw_ngrams, uint32 * raw_ngram_idx,
777 uint32 * counts, node_range_t range, uint32 * hist,
778 int n_hist, int order, int max_order)
779{
780 if (n_hist > 0 && range.begin == range.end) {
781 return;
782 }
783 if (n_hist == 0) {
784 uint32 i;
785 for (i = 0; i < counts[0]; i++) {
786 node_range_t node;
787 unigram_find(trie->unigrams, i, &node);
788 hist[0] = i;
789 lm_trie_fill_raw_ngram(trie, raw_ngrams, raw_ngram_idx, counts,
790 node, hist, 1, order, max_order);
791 }
792 }
793 else if (n_hist < order - 1) {
794 uint32 ptr;
795 node_range_t node;
796 bitarr_address_t address;
797 uint32 new_word;
798 middle_t *middle = &trie->middle_begin[n_hist - 1];
799 for (ptr = range.begin; ptr < range.end; ptr++) {
800 address.base = middle->base.base;
801 address.offset = ptr * middle->base.total_bits;
802 new_word =
803 bitarr_read_int25(address, middle->base.word_bits,
804 middle->base.word_mask);
805 hist[n_hist] = new_word;
806 address.offset += middle->base.word_bits + middle->quant_bits;
807 node.begin =
808 bitarr_read_int25(address, middle->next_mask.bits,
809 middle->next_mask.mask);
810 address.offset =
811 (ptr + 1) * middle->base.total_bits +
812 middle->base.word_bits + middle->quant_bits;
813 node.end =
814 bitarr_read_int25(address, middle->next_mask.bits,
815 middle->next_mask.mask);
816 lm_trie_fill_raw_ngram(trie, raw_ngrams, raw_ngram_idx, counts,
817 node, hist, n_hist + 1, order, max_order);
818 }
819 }
820 else {
821 bitarr_address_t address;
822 uint32 ptr;
823 float prob, backoff;
824 int i;
825 assert(n_hist == order - 1);
826 for (ptr = range.begin; ptr < range.end; ptr++) {
827 ngram_raw_t *raw_ngram = &raw_ngrams[*raw_ngram_idx];
828 if (order == max_order) {
829 longest_t *longest = trie->longest;
830 address.base = longest->base.base;
831 address.offset = ptr * longest->base.total_bits;
832 hist[n_hist] =
833 bitarr_read_int25(address, longest->base.word_bits,
834 longest->base.word_mask);
835 address.offset += longest->base.word_bits;
836 prob = lm_trie_quant_lpread(trie->quant, address);
837 }
838 else {
839 middle_t *middle = &trie->middle_begin[n_hist - 1];
840 address.base = middle->base.base;
841 address.offset = ptr * middle->base.total_bits;
842 hist[n_hist] =
843 bitarr_read_int25(address, middle->base.word_bits,
844 middle->base.word_mask);
845 address.offset += middle->base.word_bits;
846 prob =
847 lm_trie_quant_mpread(trie->quant, address, n_hist - 1);
848 backoff =
849 lm_trie_quant_mboread(trie->quant, address,
850 n_hist - 1);
851 raw_ngram->backoff = backoff;
852 }
853 raw_ngram->prob = prob;
854 raw_ngram->words =
855 (uint32 *) ckd_calloc(order, sizeof(*raw_ngram->words));
856 for (i = 0; i <= n_hist; i++) {
857 raw_ngram->words[i] = hist[n_hist - i];
858 }
859 (*raw_ngram_idx)++;
860 }
861 }
862}
SPHINXBASE_EXPORT uint8 bitarr_required_bits(uint32 max_value)
Computes the number of bits required to store integers up to the provided value.
Definition: bitarr.c:131
SPHINXBASE_EXPORT uint32 bitarr_read_int25(bitarr_address_t address, uint8 length, uint32 mask)
Read uint32 value from bit array.
Definition: bitarr.c:100
SPHINXBASE_EXPORT void bitarr_write_int25(bitarr_address_t address, uint8 length, uint32 value)
Write specified value into bit array.
Definition: bitarr.c:112
SPHINXBASE_EXPORT void bitarr_mask_from_max(bitarr_mask_t *bit_mask, uint32 max_value)
Fills mask for certain int range according to provided max value.
Definition: bitarr.c:125
Sphinx's memory allocation/deallocation routines.
SPHINXBASE_EXPORT void ckd_free(void *ptr)
Test and free a 1-D array.
Definition: ckd_alloc.c:244
#define ckd_calloc(n, sz)
Macros to simplify the use of above functions.
Definition: ckd_alloc.h:248
Implementation of logging routines.
#define E_ERROR(...)
Print error message to error log.
Definition: err.h:104
#define E_INFO(...)
Print logging information to standard error stream.
Definition: err.h:114
Basic type definitions used in Sphinx.
Definition: lm_trie.h:58
Structure that stores address of certain value in bit array.
Definition: bitarr.h:73