feature_store.rs
//! Defines the FeatureStore class which is used to define discrete features for each node
use std::sync::Arc;
use crate::NodeID;
use crate::vocab::Vocab;

/// Makes Arc<String> compatible with feature setting
struct ArcWrap(Arc<String>);

impl AsRef<str> for ArcWrap {
    fn as_ref(&self) -> &str {
        self.0.as_ref()
    }
}

/// Main FeatureStore struct. We use a vector of vectors to allow for dynamic numbers of features.
/// This can and should be updated to a more memory friendly approach since vectors have surprising
/// overhead and the number of discrete features tends to be relatively small.
#[derive(Debug)]
pub struct FeatureStore {
    /// Raw storage for features, indexed by node id
    features: Vec<Vec<usize>>,

    /// Maps a raw feature to a feature_id
    feature_vocab: Vocab,
}

impl FeatureStore {

    pub fn new(size: usize) -> Self {
        FeatureStore {
            features: vec![Vec::with_capacity(0); size],
            feature_vocab: Vocab::new(),
        }
    }

    pub fn set_features<A, B>(
        &mut self,
        node: NodeID,
        node_features: impl Iterator<Item=(A, B)>
    )
    where
        A: AsRef<str>,
        B: AsRef<str>
    {
        self.features[node] = node_features
            .map(|(ft, fname)| self.feature_vocab.get_or_insert(ft.as_ref(), fname.as_ref()))
            .collect()
    }

    pub fn set_features_raw(
        &mut self,
        node: NodeID,
        node_features: impl Iterator<Item=usize>
    ) {
        self.features[node].extend(node_features);
    }

    pub fn get_features(&self, node: NodeID) -> &[usize] {
        &self.features[node]
    }

    fn get_pretty_feature(&self, feat_id: usize) -> (String, String) {
        let (nt, name) = self.feature_vocab.get_name(feat_id).unwrap();
        (nt.to_string(), name.to_string())
    }

    pub fn get_pretty_features(&self, node: NodeID) -> Vec<(String, String)> {
        self.features[node].iter().map(|v_id| {
            self.get_pretty_feature(*v_id)
        }).collect()
    }

    pub fn num_features(&self) -> usize {
        self.feature_vocab.len()
    }

    pub fn num_nodes(&self) -> usize {
        self.features.len()
    }

    /// This method assigns a unique, anonymous feature to every node which lacks any features. This is
    /// necessary for all graph embedding algorithms which map {feature} -> Embedding.
    pub fn fill_missing_nodes(&mut self) {
        for i in 0..self.features.len() {
            if self.features[i].len() == 0 {
                self.set_features(i, [("node", i.to_string())].into_iter());
            }
        }
    }

    pub fn get_vocab(&self) -> &Vocab {
        &self.feature_vocab
    }

    pub fn clone_vocab(&self) -> Vocab {
        self.feature_vocab.clone()
    }

    pub fn iter(&self) -> impl Iterator<Item=&Vec<usize>> {
        self.features.iter()
    }

    /// Count the number of occurrences of each feature in the feature set. This is helpful when
    /// pruning to a minimum count.
    pub fn count_features(&self) -> Vec<usize> {
        let mut counts = vec![0usize; self.feature_vocab.len()];
        for feats in self.features.iter() {
            for f_i in feats.iter() {
                counts[*f_i] += 1;
            }
        }
        counts
    }

    /// Removes features which don't meet the provided `count`. This is helpful to prevent one-off
    /// occurrences of words acting as node biases and otherwise harming the quality of the
    /// embeddings.
    pub fn prune_min_count(&self, count: usize) -> FeatureStore {
        let counts = self.count_features();
        let mut new_fs = FeatureStore::new(self.features.len());

        // Filter out features that don't meet the min_count
        self.features.iter().enumerate().for_each(|(node_id, feats)| {
            let new_feats = feats.iter()
                .filter(|f_i| counts[**f_i] >= count)
                .map(|f_i| {
                    let (nt, nn) = self.feature_vocab.get_name(*f_i)
                        .expect("Should never be unavailable!");
                    (ArcWrap(nt), nn)
                });

            new_fs.set_features(node_id, new_feats);
        });
        new_fs
    }
}
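
// A minimal usage sketch, not part of the upstream file: it assumes the `Vocab`
// type behaves as it is used above (`get_or_insert`, `get_name`, `len`) and only
// exercises the public FeatureStore methods defined in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_feature_store_flow() {
        // Three nodes; only node 0 gets explicit features.
        let mut fs = FeatureStore::new(3);
        fs.set_features(0, [("word", "apple"), ("word", "banana")].into_iter());

        // Nodes 1 and 2 have no features yet, so each receives an anonymous one.
        fs.fill_missing_nodes();
        assert_eq!(fs.num_nodes(), 3);
        assert_eq!(fs.num_features(), 4);

        // Every feature occurs exactly once, so pruning with count = 2 drops
        // all of them from the new store's per-node feature lists.
        let pruned = fs.prune_min_count(2);
        assert!(pruned.get_features(0).is_empty());
    }
}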