use std::sync::Arc;

use crate::{Backtrace, BacktraceHash, CountAndSize};

// ----------------------------------------------------------------------------

/// A hash of a pointer address.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq)]
pub struct PtrHash(u64);

impl nohash_hasher::IsEnabled for PtrHash {}

impl PtrHash {
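    /// Hash a pointer address.
    ///
    /// Uses fixed seeds so the same address always produces the same hash,
    /// even across separate calls to `new` (both the sampling decision and the
    /// live-allocation map rely on this).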
    #[inline]
    pub fn new(ptr: *mut u8) -> Self {
        let hash = ahash::RandomState::with_seeds(1, 2, 3, 4).hash_one(ptr);
        Self(hash)
    }
}

// ----------------------------------------------------------------------------

/// Formatted backtrace.
///
/// Clones without allocating.
#[derive(Clone)]
pub struct ReadableBacktrace {
    /// Human-readable backtrace.
    readable: Arc<str>,
}

impl std::fmt::Display for ReadableBacktrace {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.readable.fmt(f)
    }
}

impl ReadableBacktrace {
    fn new(mut backtrace: Backtrace) -> Self {
        Self {
            readable: backtrace.format(),
        }
    }
}

// ----------------------------------------------------------------------------

/// Per-callstack statistics.
#[derive(Clone)]
pub struct CallstackStatistics {
    /// For when we print this statistic.
    pub readable_backtrace: ReadableBacktrace,

    /// If this was stochastically sampled - at what rate?
    ///
    /// A `stochastic_rate` of `10` means that we only sampled 1 in 10 allocations.
    ///
    /// (so this is actually an interval rather than a rate…).
    pub stochastic_rate: usize,

    /// Live allocations at this callstack.
    ///
    /// You should multiply this by [`Self::stochastic_rate`] to get an estimate
    /// of the real data.
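    ///
    /// For example, with a `stochastic_rate` of 256 and an `extant.size` of
    /// 1024 bytes, the estimated live size is `256 * 1024` = 262144 bytes.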
    pub extant: CountAndSize,
}

// ----------------------------------------------------------------------------

/// Track the callstacks of allocations.
pub struct AllocationTracker {
    /// Sample every N allocations. Must be a power of two.
    stochastic_rate: usize,

    /// De-duplicated readable backtraces.
    readable_backtraces: nohash_hasher::IntMap<BacktraceHash, ReadableBacktrace>,

    /// Current live allocations.
    live_allocs: ahash::HashMap<PtrHash, BacktraceHash>,

    /// How much memory is allocated by each callstack?
    callstack_stats: nohash_hasher::IntMap<BacktraceHash, CountAndSize>,
}

impl AllocationTracker {
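    /// Create a tracker that samples roughly one in every `stochastic_rate` allocations.
    ///
    /// The rate must be a non-zero power of two so that sampling can be done
    /// with a cheap bit-mask of the pointer hash.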
    pub fn with_stochastic_rate(stochastic_rate: usize) -> Self {
        assert!(stochastic_rate != 0);
        assert!(stochastic_rate.is_power_of_two());
        Self {
            stochastic_rate,
            readable_backtraces: Default::default(),
            live_allocs: Default::default(),
            callstack_stats: Default::default(),
        }
    }

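    /// Should we track the allocation behind this pointer?
    ///
    /// Keeps only pointers whose hash has all of its low bits zero, which
    /// selects roughly one in `stochastic_rate` pointers (the rate is a power of two).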
    fn should_sample(&self, ptr: PtrHash) -> bool {
        ptr.0 & (self.stochastic_rate as u64 - 1) == 0
    }

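    /// Record an allocation of `size` bytes at the (hashed) address `ptr`.
    ///
    /// If the pointer is sampled, a backtrace is captured and the size is
    /// attributed to that callstack.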
    pub fn on_alloc(&mut self, ptr: PtrHash, size: usize) {
        if !self.should_sample(ptr) {
            return;
        }

        let unresolved_backtrace = Backtrace::new_unresolved();
        let hash = BacktraceHash::new(&unresolved_backtrace);

        self.readable_backtraces
            .entry(hash)
            .or_insert_with(|| ReadableBacktrace::new(unresolved_backtrace));

        self.callstack_stats.entry(hash).or_default().add(size);

        self.live_allocs.insert(ptr, hash);
    }

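    /// Record that the allocation of `size` bytes at `ptr` was freed.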
    pub fn on_dealloc(&mut self, ptr: PtrHash, size: usize) {
        if !self.should_sample(ptr) {
            return;
        }

        if let Some(hash) = self.live_allocs.remove(&ptr) {
            if let std::collections::hash_map::Entry::Occupied(mut entry) =
                self.callstack_stats.entry(hash)
            {
                let stats = entry.get_mut();
                stats.sub(size);

                // Free up some memory:
                if stats.size == 0 {
                    entry.remove();
                }
            }
        }
    }

    /// Return the `n` callstacks that are currently using the most memory.
    pub fn top_callstacks(&self, n: usize) -> Vec<CallstackStatistics> {
        let mut vec: Vec<_> = self
            .callstack_stats
            .iter()
            .filter(|(_hash, c)| c.count > 0)
            .filter_map(|(hash, c)| {
                Some(CallstackStatistics {
                    readable_backtrace: self.readable_backtraces.get(hash)?.clone(),
                    stochastic_rate: self.stochastic_rate,
                    extant: *c,
                })
            })
            .collect();
        vec.sort_by_key(|stats| -(stats.extant.size as i64));
        vec.truncate(n);
        vec.shrink_to_fit();
        vec
    }
}
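
// ----------------------------------------------------------------------------

// A minimal usage sketch: with a stochastic rate of 1 every allocation is
// sampled, so the tracker's bookkeeping can be checked deterministically.
// The pointer address below is a dummy value and is never dereferenced.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn track_alloc_and_dealloc() {
        let mut tracker = AllocationTracker::with_stochastic_rate(1);

        // Dummy address: only its hash is used.
        let ptr = PtrHash::new(0x1234_usize as *mut u8);

        tracker.on_alloc(ptr, 128);
        let top = tracker.top_callstacks(10);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].extant.count, 1);
        assert_eq!(top[0].extant.size, 128);

        tracker.on_dealloc(ptr, 128);
        assert!(tracker.top_callstacks(10).is_empty());
    }
}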