Skip to content

Commit 49d2441

Browse files
committed
bench: rewrite nbody for better vectorization
1 parent 3cee9e2 commit 49d2441

File tree

1 file changed

+132
-87
lines changed

1 file changed

+132
-87
lines changed

src/test/bench/shootout-nbody.rs

Lines changed: 132 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -39,134 +39,175 @@
3939
// OF THE POSSIBILITY OF SUCH DAMAGE.
4040

4141
use std::mem;
42+
use std::ops::{Add, Sub, Mul};
4243

4344
const PI: f64 = 3.141592653589793;
4445
const SOLAR_MASS: f64 = 4.0 * PI * PI;
4546
const YEAR: f64 = 365.24;
4647
const N_BODIES: usize = 5;
48+
const N_PAIRS: usize = N_BODIES * (N_BODIES - 1) / 2;
4749

48-
static BODIES: [Planet;N_BODIES] = [
50+
const BODIES: [Planet; N_BODIES] = [
4951
// Sun
5052
Planet {
51-
x: 0.0, y: 0.0, z: 0.0,
52-
vx: 0.0, vy: 0.0, vz: 0.0,
53+
pos: Vec3(0.0, 0.0, 0.0),
54+
vel: Vec3(0.0, 0.0, 0.0),
5355
mass: SOLAR_MASS,
5456
},
5557
// Jupiter
5658
Planet {
57-
x: 4.84143144246472090e+00,
58-
y: -1.16032004402742839e+00,
59-
z: -1.03622044471123109e-01,
60-
vx: 1.66007664274403694e-03 * YEAR,
61-
vy: 7.69901118419740425e-03 * YEAR,
62-
vz: -6.90460016972063023e-05 * YEAR,
59+
pos: Vec3(4.84143144246472090e+00,
60+
-1.16032004402742839e+00,
61+
-1.03622044471123109e-01),
62+
vel: Vec3(1.66007664274403694e-03 * YEAR,
63+
7.69901118419740425e-03 * YEAR,
64+
-6.90460016972063023e-05 * YEAR),
6365
mass: 9.54791938424326609e-04 * SOLAR_MASS,
6466
},
6567
// Saturn
6668
Planet {
67-
x: 8.34336671824457987e+00,
68-
y: 4.12479856412430479e+00,
69-
z: -4.03523417114321381e-01,
70-
vx: -2.76742510726862411e-03 * YEAR,
71-
vy: 4.99852801234917238e-03 * YEAR,
72-
vz: 2.30417297573763929e-05 * YEAR,
69+
pos: Vec3(8.34336671824457987e+00,
70+
4.12479856412430479e+00,
71+
-4.03523417114321381e-01),
72+
vel: Vec3(-2.76742510726862411e-03 * YEAR,
73+
4.99852801234917238e-03 * YEAR,
74+
2.30417297573763929e-05 * YEAR),
7375
mass: 2.85885980666130812e-04 * SOLAR_MASS,
7476
},
7577
// Uranus
7678
Planet {
77-
x: 1.28943695621391310e+01,
78-
y: -1.51111514016986312e+01,
79-
z: -2.23307578892655734e-01,
80-
vx: 2.96460137564761618e-03 * YEAR,
81-
vy: 2.37847173959480950e-03 * YEAR,
82-
vz: -2.96589568540237556e-05 * YEAR,
79+
pos: Vec3(1.28943695621391310e+01,
80+
-1.51111514016986312e+01,
81+
-2.23307578892655734e-01),
82+
vel: Vec3(2.96460137564761618e-03 * YEAR,
83+
2.37847173959480950e-03 * YEAR,
84+
-2.96589568540237556e-05 * YEAR),
8385
mass: 4.36624404335156298e-05 * SOLAR_MASS,
8486
},
8587
// Neptune
8688
Planet {
87-
x: 1.53796971148509165e+01,
88-
y: -2.59193146099879641e+01,
89-
z: 1.79258772950371181e-01,
90-
vx: 2.68067772490389322e-03 * YEAR,
91-
vy: 1.62824170038242295e-03 * YEAR,
92-
vz: -9.51592254519715870e-05 * YEAR,
89+
pos: Vec3(1.53796971148509165e+01,
90+
-2.59193146099879641e+01,
91+
1.79258772950371181e-01),
92+
vel: Vec3(2.68067772490389322e-03 * YEAR,
93+
1.62824170038242295e-03 * YEAR,
94+
-9.51592254519715870e-05 * YEAR),
9395
mass: 5.15138902046611451e-05 * SOLAR_MASS,
9496
},
9597
];
9698

97-
#[derive(Copy, Clone)]
99+
/// A 3d Vector type with overloaded operators to improve readability.
100+
#[derive(Clone, Copy)]
101+
struct Vec3(pub f64, pub f64, pub f64);
102+
103+
impl Vec3 {
104+
fn zero() -> Self { Vec3(0.0, 0.0, 0.0) }
105+
106+
fn norm(&self) -> f64 { self.squared_norm().sqrt() }
107+
108+
fn squared_norm(&self) -> f64 {
109+
self.0 * self.0 + self.1 * self.1 + self.2 * self.2
110+
}
111+
}
112+
113+
impl Add for Vec3 {
114+
type Output = Self;
115+
fn add(self, rhs: Self) -> Self {
116+
Vec3(self.0 + rhs.0, self.1 + rhs.1, self.2 + rhs.2)
117+
}
118+
}
119+
120+
impl Sub for Vec3 {
121+
type Output = Self;
122+
fn sub(self, rhs: Self) -> Self {
123+
Vec3(self.0 - rhs.0, self.1 - rhs.1, self.2 - rhs.2)
124+
}
125+
}
126+
127+
impl Mul<f64> for Vec3 {
128+
type Output = Self;
129+
fn mul(self, rhs: f64) -> Self {
130+
Vec3(self.0 * rhs, self.1 * rhs, self.2 * rhs)
131+
}
132+
}
133+
134+
#[derive(Clone, Copy)]
98135
struct Planet {
99-
x: f64, y: f64, z: f64,
100-
vx: f64, vy: f64, vz: f64,
136+
pos: Vec3,
137+
vel: Vec3,
101138
mass: f64,
102139
}
103140

104-
fn advance(bodies: &mut [Planet;N_BODIES], dt: f64, steps: isize) {
105-
for _ in 0..steps {
106-
let mut b_slice: &mut [_] = bodies;
107-
loop {
108-
let bi = match shift_mut_ref(&mut b_slice) {
109-
Some(bi) => bi,
110-
None => break
111-
};
112-
for bj in &mut *b_slice {
113-
let dx = bi.x - bj.x;
114-
let dy = bi.y - bj.y;
115-
let dz = bi.z - bj.z;
116-
117-
let d2 = dx * dx + dy * dy + dz * dz;
118-
let mag = dt / (d2 * d2.sqrt());
119-
120-
let massj_mag = bj.mass * mag;
121-
bi.vx -= dx * massj_mag;
122-
bi.vy -= dy * massj_mag;
123-
bi.vz -= dz * massj_mag;
124-
125-
let massi_mag = bi.mass * mag;
126-
bj.vx += dx * massi_mag;
127-
bj.vy += dy * massi_mag;
128-
bj.vz += dz * massi_mag;
129-
}
130-
bi.x += dt * bi.vx;
131-
bi.y += dt * bi.vy;
132-
bi.z += dt * bi.vz;
141+
/// Computes all pairwise position differences between the planets.
142+
fn pairwise_diffs(bodies: &[Planet; N_BODIES], diff: &mut [Vec3; N_PAIRS]) {
143+
let mut bodies = bodies.iter();
144+
let mut diff = diff.iter_mut();
145+
while let Some(bi) = bodies.next() {
146+
for bj in bodies.clone() {
147+
*diff.next().unwrap() = bi.pos - bj.pos;
148+
}
149+
}
150+
}
151+
152+
/// Computes the scaling factor `dt / d^3` for each pair of planets, where `d` is the distance between them.
153+
fn magnitudes(diff: &[Vec3; N_PAIRS], dt: f64, mag: &mut [f64; N_PAIRS]) {
154+
for (mag, diff) in mag.iter_mut().zip(diff.iter()) {
155+
let d2 = diff.squared_norm();
156+
*mag = dt / (d2 * d2.sqrt());
157+
}
158+
}
159+
160+
/// Updates the velocities of the planets by computing their gravitational
161+
/// accelerations and performing one step of Euler integration.
162+
fn update_velocities(bodies: &mut [Planet; N_BODIES], dt: f64,
163+
diff: &mut [Vec3; N_PAIRS], mag: &mut [f64; N_PAIRS]) {
164+
pairwise_diffs(bodies, diff);
165+
magnitudes(&diff, dt, mag);
166+
167+
let mut bodies = &mut bodies[..];
168+
let mut mag = mag.iter();
169+
let mut diff = diff.iter();
170+
while let Some(bi) = shift_mut_ref(&mut bodies) {
171+
for bj in bodies.iter_mut() {
172+
let diff = *diff.next().unwrap();
173+
let mag = *mag.next().unwrap();
174+
bi.vel = bi.vel - diff * (bj.mass * mag);
175+
bj.vel = bj.vel + diff * (bi.mass * mag);
133176
}
134177
}
135178
}
136179

137-
fn energy(bodies: &[Planet;N_BODIES]) -> f64 {
180+
/// Advances the solar system by one timestep by first updating the
181+
/// velocities and then integrating the positions using the updated velocities.
182+
///
183+
/// Note: the `diff` & `mag` arrays are effectively scratch space. They're
184+
/// provided as arguments to avoid re-zeroing them every time `advance` is
185+
/// called.
186+
fn advance(mut bodies: &mut [Planet; N_BODIES], dt: f64,
187+
diff: &mut [Vec3; N_PAIRS], mag: &mut [f64; N_PAIRS]) {
188+
update_velocities(bodies, dt, diff, mag);
189+
for body in bodies.iter_mut() {
190+
body.pos = body.pos + body.vel * dt;
191+
}
192+
}
193+
194+
/// Computes the total energy of the solar system.
195+
fn energy(bodies: &[Planet; N_BODIES]) -> f64 {
138196
let mut e = 0.0;
139197
let mut bodies = bodies.iter();
140-
loop {
141-
let bi = match bodies.next() {
142-
Some(bi) => bi,
143-
None => break
144-
};
145-
e += (bi.vx * bi.vx + bi.vy * bi.vy + bi.vz * bi.vz) * bi.mass / 2.0;
146-
for bj in bodies.clone() {
147-
let dx = bi.x - bj.x;
148-
let dy = bi.y - bj.y;
149-
let dz = bi.z - bj.z;
150-
let dist = (dx * dx + dy * dy + dz * dz).sqrt();
151-
e -= bi.mass * bj.mass / dist;
152-
}
198+
while let Some(bi) = bodies.next() {
199+
e += bi.vel.squared_norm() * bi.mass / 2.0
200+
- bi.mass * bodies.clone()
201+
.map(|bj| bj.mass / (bi.pos - bj.pos).norm())
202+
.fold(0.0, |a, b| a + b);
153203
}
154204
e
155205
}
156206

157-
fn offset_momentum(bodies: &mut [Planet;N_BODIES]) {
158-
let mut px = 0.0;
159-
let mut py = 0.0;
160-
let mut pz = 0.0;
161-
for bi in bodies.iter() {
162-
px += bi.vx * bi.mass;
163-
py += bi.vy * bi.mass;
164-
pz += bi.vz * bi.mass;
165-
}
166-
let sun = &mut bodies[0];
167-
sun.vx = - px / SOLAR_MASS;
168-
sun.vy = - py / SOLAR_MASS;
169-
sun.vz = - pz / SOLAR_MASS;
207+
/// Offsets the sun's velocity to make the overall momentum of the system zero.
208+
fn offset_momentum(bodies: &mut [Planet; N_BODIES]) {
209+
let p = bodies.iter().fold(Vec3::zero(), |v, b| v + b.vel * b.mass);
210+
bodies[0].vel = p * (-1.0 / bodies[0].mass);
170211
}
171212

172213
fn main() {
@@ -178,11 +219,15 @@ fn main() {
178219
.unwrap_or(1000)
179220
};
180221
let mut bodies = BODIES;
222+
let mut diff = [Vec3::zero(); N_PAIRS];
223+
let mut mag = [0.0f64; N_PAIRS];
181224

182225
offset_momentum(&mut bodies);
183226
println!("{:.9}", energy(&bodies));
184227

185-
advance(&mut bodies, 0.01, n);
228+
for _ in (0..n) {
229+
advance(&mut bodies, 0.01, &mut diff, &mut mag);
230+
}
186231

187232
println!("{:.9}", energy(&bodies));
188233
}

0 commit comments

Comments
 (0)