Improve performance of byte_pair_merge (#31)
The improvements to `byte_pair_merge` are:

- Changing the `parts` vector to avoid repetition of data. This vector used to store ranges for which the invariant `parts[i].end == parts[i + 1].start` holds, which made it twice as big as it needs to be. Keeping this vector small improves CPU-cache efficiency (a small sketch further below illustrates this).
- Using `usize::MAX` as a sentinel in lieu of `Option` for the computation of the minimum rank. This change removes branching from the loop that computes the minimum rank, generating assembly that uses conditional moves instead. Ideally, we could keep the `Option` and tell it about the sentinel, much like `Option<NonZeroUsize>`. As far as I could tell, specifying custom sentinels for `Option` is the subject of an old Rust [RFC](https://github.com/rust-lang/rfcs/pull/41) that has stalled, so we don't get to have nice things.
- Minimizing the number of lookups into `ranks` by looking up ranks once at the start and iteratively updating them after each merge. This reduces the number of rank lookups from `n * m` to `n + O(m)`.
parent 7830ed537b
commit c4b8770184

src/lib.rs: 103 changed lines
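To make the first point above concrete, here is a minimal standalone sketch (not part of the patch; the helpers `to_boundaries` and `range_at` are invented for illustration) of why the invariant `parts[i].end == parts[i + 1].start` lets the vector store only a boundary plus a rank per entry, roughly halving its size:

```rust
use std::ops::Range;

/// Collapse adjacent ranges into boundary entries. The rank slot is set to a
/// placeholder here; the real code fills it in from the `ranks` map.
fn to_boundaries(ranges: &[Range<usize>]) -> Vec<(usize, usize)> {
    let mut parts: Vec<(usize, usize)> =
        ranges.iter().map(|r| (r.start, usize::MAX)).collect();
    if let Some(last) = ranges.last() {
        // One trailing entry so that parts[i].0..parts[i + 1].0 recovers range i.
        parts.push((last.end, usize::MAX));
    }
    parts
}

/// Recover the i-th range from two adjacent boundary entries.
fn range_at(parts: &[(usize, usize)], i: usize) -> Range<usize> {
    parts[i].0..parts[i + 1].0
}

fn main() {
    // One range per byte, as in the old code's initial state.
    let ranges: Vec<Range<usize>> = (0..3).map(|i| i..i + 1).collect();
    let parts = to_boundaries(&ranges);
    assert_eq!(parts.len(), ranges.len() + 1);
    assert_eq!(range_at(&parts, 1), 1..2);
}
```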
```diff
@@ -11,11 +11,58 @@ use pyo3::types::{PyBytes, PyList, PyTuple};
 use pyo3::PyResult;
 use rustc_hash::FxHashMap as HashMap;
 
-fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::ops::Range<usize>> {
-    let mut parts: Vec<_> = (0..piece.len()).map(|i| i..i + 1).collect();
+fn _byte_pair_merge<T>(
+    piece: &[u8],
+    ranks: &HashMap<Vec<u8>, usize>,
+    f: impl Fn(std::ops::Range<usize>) -> T,
+) -> Vec<T> {
+    // This is a vector of (start, rank).
+    // The rank is of the byte pair starting at position start.
+    // The rank of the last item in the vector is not a valid value.
+    let mut parts: Vec<(usize, usize)> = (0..piece.len() + 1).map(|i| (i, usize::MAX)).collect();
 
-    // If you have n parts and m merges, this does O(mn) work
-    // We could do something with a heap and do O(m log n) work
+    // NOTE: using a macro here because a closure fails to get inlined
+    // according to optimization remarks.
+    // A closure also cannot capture a reference to `piece` without
+    // the borrow checker complaining about the mutable borrows during
+    // the assignments later in this code.
+    macro_rules! get_rank {
+        ($start_idx:expr, $skip:expr) => {{
+            let start_idx: usize = $start_idx;
+            let skip: usize = $skip;
+            if (start_idx + skip + 2) < parts.len() {
+                ranks
+                    .get(&piece[parts[start_idx].0..parts[start_idx + skip + 2].0])
+                    .map(|r| *r)
+            } else {
+                None
+            }
+        }};
+        ($idx:expr) => {{
+            get_rank!($idx, 0)
+        }};
+    }
+
+    // We look up the ranks once in the beginning and iteratively update
+    // them during each merge, which reduces the number of rank lookups.
+    for i in 0..parts.len() - 2 {
+        match get_rank!(i) {
+            Some(rank) => {
+                // usize::MAX is a sentinel value and cannot be a valid rank
+                debug_assert!(rank != usize::MAX);
+                parts[i].1 = rank;
+            }
+            None => {
+                continue;
+            }
+        };
+    }
+
+    // If you have n parts and m merges, this does O(mn) work.
+    // We could do something with a heap and do O(m log n) work.
+    // It is important to consider that n is often small (<100), and as such
+    // the cache-locality benefits outweigh the algorithmic complexity downsides
+    // of the `parts` vector data structure above.
 
     // Note that we hash bytes, not token pairs. As long as we train BPE the way we
     // currently do, this is equivalent. An easy way to break this would be to decouple
```
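As an aside on the sentinel change that appears in the next hunk: the sketch below (standalone, not from the patch; the function names are invented) contrasts an `Option`-based minimum scan with the `usize::MAX` sentinel scan. With the sentinel, the loop body is a single comparison and assignment that the optimizer can lower to conditional moves rather than branches.

```rust
// Option-based scan: the loop body branches on both the Option and the comparison.
fn min_rank_option(ranks: &[Option<usize>]) -> Option<(usize, usize)> {
    let mut min: Option<(usize, usize)> = None;
    for (i, r) in ranks.iter().enumerate() {
        if let Some(rank) = r {
            if min.is_none() || *rank < min.unwrap().0 {
                min = Some((*rank, i));
            }
        }
    }
    min
}

// Sentinel-based scan: usize::MAX marks "no merge possible here" and can never be
// a real rank, so a plain comparison suffices.
fn min_rank_sentinel(ranks: &[usize]) -> (usize, usize) {
    let mut min = (usize::MAX, 0);
    for (i, &rank) in ranks.iter().enumerate() {
        if rank < min.0 {
            min = (rank, i);
        }
    }
    min
}

fn main() {
    assert_eq!(min_rank_option(&[None, Some(7), Some(3)]), Some((3, 2)));
    assert_eq!(min_rank_sentinel(&[usize::MAX, 7, 3]), (3, 2));
}
```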
```diff
@@ -24,45 +71,53 @@ fn _byte_pair_merge(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<std::o
         if parts.len() == 1 {
             break;
         }
-        let mut min_rank: Option<(usize, usize)> = None;
-        for i in 0..parts.len() - 1 {
-            let rank = if let Some(r) = ranks.get(&piece[parts[i].start..parts[i + 1].end]) {
-                *r
-            } else {
-                continue;
-            };
-            if min_rank.is_none() || rank < min_rank.unwrap().0 {
-                min_rank = Some((rank, i));
+
+        // usize::MAX is a sentinel rank value allowing us to
+        // take the min more quickly
+        let mut min_rank: (usize, usize) = (usize::MAX, 0);
+        for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() {
+            if rank < min_rank.0 {
+                min_rank = (rank, i);
             }
         }
-        if let Some((_, i)) = min_rank {
-            parts[i] = parts[i].start..parts[i + 1].end;
+
+        if min_rank.0 != usize::MAX {
+            let i = min_rank.1;
+
+            // NOTE: We are about to remove parts[i + 1]. We do not do it
+            // yet because there are cache-locality benefits to updating
+            // parts[i] and parts[i-1] before removing, which could thrash
+            // the cache. Thus, we update the rank calculation by skipping over
+            // parts[i + 1], by invoking `get_rank!` with `skip = 1`.
+            parts[i].1 = get_rank!(i, 1).unwrap_or(usize::MAX);
+            if i > 0 {
+                parts[i - 1].1 = get_rank!(i - 1, 1).unwrap_or(usize::MAX);
+            }
+
             parts.remove(i + 1);
         } else {
             break;
         }
     }
-    parts
+    let mut out: Vec<T> = Vec::with_capacity(parts.len() - 1);
+    for i in 0..parts.len() - 1 {
+        out.push(f(parts[i].0..parts[i + 1].0));
+    }
+    out
 }
 
 pub fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<usize> {
     if piece.len() == 1 {
         return vec![ranks[piece]];
     }
-    _byte_pair_merge(piece, ranks)
-        .iter()
-        .map(|p| ranks[&piece[p.start..p.end]])
-        .collect()
+    _byte_pair_merge(piece, ranks, |p| ranks[&piece[p.start..p.end]])
 }
 
 pub fn byte_pair_split<'a>(piece: &'a [u8], ranks: &HashMap<Vec<u8>, usize>) -> Vec<&'a [u8]> {
     if piece.len() == 1 {
         return vec![piece];
     }
-    _byte_pair_merge(piece, ranks)
-        .iter()
-        .map(|p| &piece[p.start..p.end])
-        .collect()
+    _byte_pair_merge(piece, ranks, |p| &piece[p.start..p.end])
 }
 
 // Various performance notes:
```
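Finally, a minimal usage sketch of the reworked API, assuming `byte_pair_encode` and `byte_pair_split` from the diff above are in scope. The toy `ranks` vocabulary is hypothetical and only large enough for this example; real encodings assign a rank to every single byte.

```rust
use rustc_hash::FxHashMap as HashMap;

fn main() {
    // Hypothetical toy vocabulary: the single bytes of "abc" plus the pair "ab".
    let mut ranks: HashMap<Vec<u8>, usize> = HashMap::default();
    ranks.insert(b"a".to_vec(), 0);
    ranks.insert(b"b".to_vec(), 1);
    ranks.insert(b"c".to_vec(), 2);
    ranks.insert(b"ab".to_vec(), 3);

    // "ab" is the only adjacent pair with a rank, so it merges; "c" stays a single byte.
    assert_eq!(byte_pair_encode(b"abc", &ranks), vec![3, 2]);
    assert_eq!(
        byte_pair_split(b"abc", &ranks),
        vec![b"ab".as_slice(), b"c".as_slice()]
    );
}
```

Passing the projection closure `f` into `_byte_pair_merge` lets both wrappers share one merge loop without materialising an intermediate `Vec<Range<usize>>`.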