forked from Refefer/cloverleaf
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathreweighter.rs
130 lines (106 loc) · 3.82 KB
/
reweighter.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
//! A simple node reweighter. Takes a mapping of nodes to weights, an embedding store, and
//! reweights the counts according to an alpha parameter and the distance to a context.
use hashbrown::HashMap;
use rayon::prelude::*;
use crate::graph::NodeID;
use crate::embeddings::{EmbeddingStore,Entity};
/// Reweights node scores by their embedding distance to a context vector.
///
/// The only parameter is `alpha`, the exponent applied to each node's closeness
/// probability when its weight is rescaled.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Reweighter {
    /// Exponent applied to the closeness probability; `0.0` leaves weights unchanged.
    pub alpha: f32
}
impl Reweighter {
    /// Creates a reweighter with the given `alpha` exponent.
    ///
    /// `alpha = 0.0` leaves weights unchanged; larger values penalize nodes that are
    /// far from the context more strongly.
    pub fn new(alpha: f32) -> Self {
        Reweighter { alpha }
    }

    /// Reweights `results` in place by each node's distance to `context_emb`.
    ///
    /// Distances are computed in parallel, and only for nodes that have an embedding in
    /// `embeddings`. Reweighting is applied only when more than two result nodes have
    /// embeddings; otherwise `results` is left untouched.
    pub fn reweight(
        &self,
        results: &mut HashMap<NodeID, f32>,
        embeddings: &EmbeddingStore,
        context_emb: &[f32]
    ) {
        // Compute the distance from the context to every embedded result node.
        let context_node = Entity::Embedding(context_emb);
        let distances: HashMap<_,_> = results.par_keys()
            .filter(|node| embeddings.is_set(**node))
            .map(|node| {
                let n = Entity::Node(*node);
                let distance = embeddings.compute_distance(&context_node, &n);
                (*node, distance)
            }).collect();

        // NOTE(review): the threshold is strictly "more than two"; confirm whether two
        // embedded nodes should already be enough to reweight.
        if distances.len() > 2 {
            Reweighter::reweight_by_distance(results, &distances, self.alpha);
        }
    }

    /// Logistic sigmoid `1 / (1 + e^-x)`.
    #[inline]
    fn sigmoid(x: f32) -> f32 {
        1. / (1. + (-x).exp())
    }

    /// Multiplies every weight in `results` by `p^alpha`.
    ///
    /// `p` is a closeness probability. Each distance is z-normalized with the mean and
    /// std of `distances`, then passed through a logistic sigmoid, treating the normalized
    /// value as logistic-distributed. Lower distance is better, so `p = sigmoid(-z)`.
    ///
    /// Nodes with no entry in `distances` (missing embedding) get `p = 0.5`. That is the
    /// value a node sitting exactly at the mean distance would receive.
    fn reweight_by_distance(
        results: &mut HashMap<NodeID, f32>,
        distances: &HashMap<NodeID,f32>,
        alpha: f32
    ) {
        let (mu, sigma) = Reweighter::compute_stats(distances);
        results.par_iter_mut().for_each(|(node, weight)| {
            let p = match distances.get(node) {
                Some(d) => {
                    let z = (d - mu) / sigma;
                    // `sigmoid(-z)` equals `1 - sigmoid(z)`, but it avoids the
                    // catastrophic cancellation that `1 - sigmoid(z)` suffers for large z.
                    Reweighter::sigmoid(-z)
                }
                None => 0.5,
            };
            *weight *= p.powf(alpha);
        });
    }

    /// Returns the mean and population standard deviation of the values in `distances`.
    ///
    /// A standard deviation that is not positive (e.g. all distances equal, or a single
    /// entry) is replaced by `1.0`, so callers can divide by it safely. An empty map
    /// yields `(0.0, 1.0)` instead of the NaN mean that `0 / 0` would produce.
    fn compute_stats(distances: &HashMap<NodeID,f32>) -> (f32, f32) {
        if distances.is_empty() {
            return (0., 1.);
        }
        let n = distances.len() as f32;
        let mu = distances.par_values().sum::<f32>() / n;
        let ss = distances.par_values()
            .map(|d| {
                let diff = *d - mu;
                diff * diff
            })
            .sum::<f32>();
        let sigma = (ss / n).sqrt();
        (mu, if sigma > 0. { sigma } else { 1. })
    }
}
#[cfg(test)]
mod reweighter_tests {
    use super::*;
    use float_ord::FloatOrd;

    /// Two-node fixture: node 0 -> 1.0, node 1 -> 2.0.
    ///
    /// Used as a weight map, and in `test_compute_stats` as a distance map.
    fn build_counts() -> HashMap<usize, f32> {
        let mut hm = HashMap::new();
        hm.insert(0, 1.);
        hm.insert(1, 2.);
        hm
    }

    #[test]
    fn test_compute_stats() {
        // Values {1, 2}: mean 1.5, population std sqrt((0.25 + 0.25) / 2) = 0.5.
        let hm = build_counts();
        let (mu, sigma) = Reweighter::compute_stats(&hm);
        assert_eq!(mu, 1.5);
        assert_eq!(sigma, (0.5 / 2f32).sqrt());
    }

    #[test]
    fn test_compute_stats_zero_spread() {
        // Identical distances have zero spread; sigma falls back to 1.
        let mut hm = HashMap::new();
        hm.insert(0, 3.);
        hm.insert(1, 3.);
        let (mu, sigma) = Reweighter::compute_stats(&hm);
        assert_eq!(mu, 3.);
        assert_eq!(sigma, 1.);
    }

    #[test]
    fn test_reweight() {
        let mut counts = build_counts();
        let mut distances = HashMap::new();
        distances.insert(0, 0.2);
        distances.insert(1, 0.5);

        // alpha = 0: p^0 == 1, so distance doesn't matter.
        Reweighter::reweight_by_distance(&mut counts, &distances, 0.);
        assert_eq!(counts[&0], 1.);
        assert_eq!(counts[&1], 2.);

        // alpha = 1: the closer node (0) must end up with the larger weight.
        Reweighter::reweight_by_distance(&mut counts, &distances, 1.);
        let mut counts = counts.into_iter().collect::<Vec<_>>();
        counts.sort_by_key(|(_, w)| FloatOrd(-*w));
        assert_eq!(counts[0].0, 0);
        assert_eq!(counts[1].0, 1);
    }

    #[test]
    fn test_reweight_missing_distance() {
        // A node without a distance is scaled by 0.5^alpha.
        let mut counts = build_counts();
        counts.insert(2, 4.);
        let mut distances = HashMap::new();
        distances.insert(0, 0.2);
        distances.insert(1, 0.5);
        Reweighter::reweight_by_distance(&mut counts, &distances, 1.);
        assert!((counts[&2] - 2.).abs() < 1e-6);
    }
}