Optimizing Repeated Correlations
post by SatvikBeri · 2024-08-01T17:33:23.823Z · LW · GW · 1 comment
At my work, we run experiments – we specify some set of input parameters, run some code, and get various metrics as output. Since we run so many of these, it's important for them to be fast and cheap.
Recently I was working on an experiment type that took about 1 hour per run, where the slow part was calculating correlations. A simplified version looks like this:
using Statistics       # cor, mean, std
using BenchmarkTools   # @btime

a_length = 1_000_000
a = rand(a_length)
b = rand(a_length)
c = rand(a_length)
xs = [rand(a_length) for i in 1:1000]

# Baseline: call cor from scratch for every (x, y) pair.
function get_correlations1(xs, a, b, c)
    return [[cor(x, y) for y in [a, b, c]] for x in xs]
end

@btime correlations = get_correlations1($xs, $a, $b, $c)
> 4.563 s (2001 allocations: 164.19 KiB)
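I wondered if we could use the fact that a, b, and c were constant throughout the loops to our advantage, and looked up various ways of calculating correlations. Searching online, I found several formulas for sample correlation, and this was the most useful:

$$r_{xy} = \frac{1}{n-1} \sum_{i=1}^{n} \left(\frac{x_i - \bar{x}}{s_x}\right) \left(\frac{y_i - \bar{y}}{s_y}\right)$$

Here $\bar{x}$ and $\bar{y}$ are the sample means, and $s_x$ and $s_y$ are the sample standard deviations.

The benefit of this version is that if we are repeatedly using a Y, we can cache its standardized values $(y_i - \bar{y}) / s_y$ instead of recalculating them in every loop. Translated to code, this looks something like: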
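# Standardize x: subtract the mean, divide by the sample standard deviation.
function zscores(x)
    return (x .- mean(x)) / std(x)
end

# In-place variant: writes the z-scores of x into a preallocated buffer.
function zscores!(x, buffer)
    μ = mean(x)
    σ = std(x; mean=μ)   # reuse the mean we already computed
    buffer .= (x .- μ)./σ
    return buffer
end

function get_correlations2(xs, a, b, c)
    la = length(a) - 1   # the n - 1 factor from the formula
    # Computed once and reused for every x: z-scores of a, b, c, pre-divided by n - 1.
    za, zb, zc = zscores.([a, b, c]) ./ la
    output = Vector{Float64}[]
    buffer = zero(za)
    for x in xs
        zx = zscores!(x, buffer)
        # Each correlation is now a single dot product.
        push!(output, [zx' * y for y in [za, zb, zc]])
    end
    return output
end

@btime correlations2 = get_correlations2($xs, $a, $b, $c);
> 3.197 s (11028 allocations: 76.62 MiB)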
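One detail in the code above is worth spelling out: the cached vectors are divided by `la` (that is, n - 1) up front, so that each correlation in the loop reduces to a plain dot product. A short derivation, writing $z_{x,i} = (x_i - \bar{x})/s_x$ for the z-scores (my notation, not the post's) and taking $s_x$, $s_y$ to be sample standard deviations, which is what `std` computes by default:

$$r_{xy} = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{(n-1)\, s_x s_y} = \sum_{i=1}^{n} z_{x,i} \cdot \frac{z_{y,i}}{n-1} = z_x^\top \left(\frac{z_y}{n-1}\right)$$

The right-hand factor depends only on y, so it can be computed once per constant vector and reused for every x.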
And a sanity check to make sure the calculations match:
all(isapprox.(get_correlations2(xs, a, b, c), get_correlations1(xs, a, b, c)))
> true
This cuts about 30% of the runtime in this benchmark (4.563 s down to 3.197 s), and the results seem to be better for larger datasets: in production, I'm saving closer to 60%.
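The benchmark above needs roughly 8 GB for `xs`, so here is a minimal sketch for trying the idea on small inputs. The helper name `cor_cached`, the input sizes, and the seed are all my own illustrative choices, not from the original code. It compares the cached z-score dot product against `cor` directly.

```julia
using Statistics, Random

# Standardize x: subtract the mean, divide by the sample standard deviation.
zscores(x) = (x .- mean(x)) ./ std(x)

# Hypothetical helper: correlation of x against a pre-scaled z-score vector.
# zy_scaled is expected to be zscores(y) ./ (length(y) - 1), computed once.
cor_cached(x, zy_scaled) = sum(zscores(x) .* zy_scaled)

Random.seed!(1)
n = 1_000
y = rand(n)
zy_scaled = zscores(y) ./ (n - 1)   # cached once, reused for every x

xs_small = [rand(n) for _ in 1:5]
results  = [cor_cached(x, zy_scaled) for x in xs_small]
expected = [cor(x, y) for x in xs_small]

# Compare with an absolute tolerance, since correlations of random data sit near zero.
println(all(isapprox.(results, expected; atol=1e-10)))
```

Unlike `zscores!`, this version allocates a fresh vector per call, so it only illustrates correctness and says nothing about speed.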
1 comment
comment by SatvikBeri · 2024-08-03T16:25:36.130Z · LW(p) · GW(p)
There was a serious bug in this post that invalidated the results, so I took it down for a while. The bug has now been fixed and the posted results should be correct.