Differentiation is important, and doing differentiation with computers is even more important for building modern machine learning models at scale.

But while reverse-mode automatic differentiation (RMAD) has been covered a billion times, in this post I wish to cover something more elegant – dual numbers. I was planning to write a full-fledged blog post about the underlying maths; however, this post on the AMS website (linked here) seems more than exhaustive as a gentle introduction to the topic. What I will be doing instead is writing a basic implementation of some of these ideas in Rust:

use std::collections::BTreeMap;
use std::ops::{Add, Sub, Mul, Div, AddAssign, SubAssign};
use std::fmt;
use num_traits::{Float, FromPrimitive, ToPrimitive, Pow};

/// A first-order dual number with one named infinitesimal coefficient per variable.
///
/// `real` holds the value of the expression. `dual` maps a variable name to the
/// coefficient of that variable's infinitesimal part; the arithmetic impls update
/// these coefficients with the usual differentiation rules, so after evaluating an
/// expression each coefficient is the partial derivative with respect to that name.
///
/// The struct itself carries no trait bounds: the numeric requirements are stated
/// on the `impl` blocks that actually need them.
#[derive(Debug, Clone)]
pub struct HyperDual<T> {
    /// Value of the expression at the evaluation point.
    real: T,
    /// Coefficient of each named infinitesimal, keyed by variable name.
    dual: BTreeMap<String, T>,
}

impl<T: Float + FromPrimitive + ToPrimitive + AddAssign + SubAssign> HyperDual<T> {
    /// Builds a dual number from its real part and its named infinitesimal coefficients.
    pub fn new(real: T, dual: BTreeMap<String, T>) -> Self {
        Self { real, dual }
    }

    /// Returns the coefficient stored under `key`, or zero when `key` has no entry.
    pub fn get(&self, key: &str) -> T {
        self.dual.get(key).copied().unwrap_or_else(T::zero)
    }

    /// Stores `value` as the coefficient for `key`, replacing any previous entry.
    pub fn set(&mut self, key: &str, value: T) {
        self.dual.insert(key.to_owned(), value);
    }

    /// Sine of the dual number.
    ///
    /// The real part becomes `sin(real)` and every coefficient is scaled by
    /// `cos(real)` (chain rule).
    pub fn sin(self) -> Self {
        let slope = self.real.cos();
        let scaled = self
            .dual
            .iter()
            .map(|(name, &coeff)| (name.clone(), slope * coeff))
            .collect();

        Self::new(self.real.sin(), scaled)
    }

    /// Cosine of the dual number.
    ///
    /// The real part becomes `cos(real)` and every coefficient is scaled by
    /// `-sin(real)` (chain rule).
    pub fn cos(self) -> Self {
        let slope = -self.real.sin();
        let scaled = self
            .dual
            .iter()
            .map(|(name, &coeff)| (name.clone(), slope * coeff))
            .collect();

        Self::new(self.real.cos(), scaled)
    }
}

impl<T: Float + fmt::Display + AddAssign + SubAssign> fmt::Display for HyperDual<T> {
    /// Formats the number as `real + c1k1 - c2k2 ...`, one term per entry of `dual`
    /// in key order.
    ///
    /// A coefficient whose sign bit is set is written as ` - ` followed by its
    /// magnitude, so the output reads `- 0.5y` rather than `+ -0.5y`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.real)?;
        for (key, &value) in &self.dual {
            if value.is_sign_negative() {
                // Print the magnitude; the sign is carried by the separator.
                write!(f, " - {}{}", -value, key)?;
            } else {
                write!(f, " + {}{}", value, key)?;
            }
        }
        Ok(())
    }
}

impl<T: Float + FromPrimitive + AddAssign + SubAssign> Add for HyperDual<T> {
    type Output = Self;

    /// Adds two dual numbers component-wise.
    ///
    /// Real parts are summed; coefficients are summed per key, with a key missing
    /// on either side treated as zero. The result therefore carries the union of
    /// both operands' keys.
    fn add(mut self, rhs: Self) -> Self {
        self.real = self.real + rhs.real;

        // Merge directly into `self.dual` and return `self`: returning a copy taken
        // before this loop would silently drop every coefficient coming from `rhs`.
        for (key, value) in rhs.dual {
            *self.dual.entry(key).or_insert_with(T::zero) += value;
        }

        self
    }
}

impl<T: Float + FromPrimitive + AddAssign + SubAssign> Sub for HyperDual<T> {
    type Output = Self;

    /// Subtracts `rhs` from `self` component-wise.
    ///
    /// Real parts are subtracted; coefficients are subtracted per key, with a key
    /// missing on either side treated as zero. A key present only in `rhs` ends up
    /// with the negated coefficient.
    fn sub(mut self, rhs: Self) -> Self {
        self.real = self.real - rhs.real;

        // Merge directly into `self.dual` and return `self`: returning a copy taken
        // before this loop would silently drop every coefficient coming from `rhs`.
        for (key, value) in rhs.dual {
            *self.dual.entry(key).or_insert_with(T::zero) -= value;
        }

        self
    }
}

impl<T: Float + FromPrimitive + AddAssign + SubAssign> Mul for HyperDual<T> {
    type Output = Self;

    /// Multiplies two dual numbers using the product rule.
    ///
    /// With `a = self.real` and `b = rhs.real`, the real part is `a * b` and the
    /// coefficient for each key is `a * b' + b * a'`, where a coefficient missing
    /// from one operand counts as zero.
    fn mul(self, rhs: Self) -> Self {
        let (a, b) = (self.real, rhs.real);

        // Keys known to the left operand: full product rule.
        let mut coeffs: BTreeMap<String, T> = self
            .dual
            .iter()
            .map(|(name, &da)| (name.clone(), a * rhs.get(name) + b * da))
            .collect();

        // Keys that only the right operand carries: the `b * a'` term vanishes.
        for (name, &db) in &rhs.dual {
            coeffs.entry(name.clone()).or_insert(a * db);
        }

        HyperDual::new(a * b, coeffs)
    }
}

impl<T: Float + FromPrimitive + AddAssign + SubAssign> Div for HyperDual<T> {
    type Output = Self;

    /// Divides `self` by `rhs` using the quotient rule.
    ///
    /// With `a = self.real` and `b = rhs.real`, the real part is `a / b` and the
    /// coefficient for each key is `(b * a' - a * b') / b^2`, where a coefficient
    /// missing from one operand counts as zero. The result carries the union of
    /// both operands' keys.
    fn div(self, rhs: Self) -> Self {
        let (a, b) = (self.real, rhs.real);
        let denom = b * b;

        // Keys known to the numerator: full quotient rule.
        let mut dual: BTreeMap<String, T> = self
            .dual
            .iter()
            .map(|(key, &da)| (key.clone(), (b * da - a * rhs.get(key)) / denom))
            .collect();

        // Keys that only the divisor carries: `a'` is zero there, which leaves
        // `-a * b' / b^2`. Skipping them would lose the derivative with respect
        // to every variable that appears only in the denominator.
        for (key, &db) in &rhs.dual {
            dual.entry(key.clone()).or_insert(-(a * db) / denom);
        }

        HyperDual::new(a / b, dual)
    }
}

impl<T: Float + FromPrimitive + AddAssign + SubAssign> std::ops::Neg for HyperDual<T> {
    type Output = Self;

    fn neg(self) -> Self {
        HyperDual::new(-self.real, self.dual.iter().map(|(k, v)| (k.clone(), -(*v))).collect())
    }
}

impl<T: Float + FromPrimitive + AddAssign + SubAssign> Pow<T> for HyperDual<T> {
    type Output = Self;

    /// Raises the dual number to the constant real power `rhs`.
    ///
    /// The real part becomes `real.powf(rhs)` and every coefficient is scaled by
    /// `rhs * real.powf(rhs - 1)` (power rule). For `rhs == 0` the scale factor is
    /// exactly zero, since `x^0` is constant.
    fn pow(self, rhs: T) -> Self {
        let real = self.real.powf(rhs);

        // Without this guard a zero exponent at `real == 0` would evaluate
        // `0 * 0^(-1) = 0 * inf = NaN` instead of the correct derivative 0.
        let deriv_multiplier = if rhs == T::zero() {
            T::zero()
        } else {
            rhs * self.real.powf(rhs - T::one())
        };

        let dual = self
            .dual
            .iter()
            .map(|(k, &v)| (k.clone(), deriv_multiplier * v))
            .collect();

        HyperDual::new(real, dual)
    }
}

fn main() {
    let x = HyperDual::new(3.0, vec![("x".to_string(), 1.0)].into_iter().collect());
    let y = HyperDual::new(-1.0, vec![("y".to_string(), 1.0)].into_iter().collect());
    let z = HyperDual::new(2.0, vec![("z".to_string(), 1.0)].into_iter().collect());

    let w = x * y.clone() * (y * z).sin();

    println!("Result: {w}");
}

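This should yield 2.727892280477045 + 0.9092974268256817x - 0.23101126119419035y - 1.2484405096414273z, which corresponds to f(3,-1,2) + ∇f(3,-1,2) for f(x,y,z) = x·y·sin(y·z), i.e. -3sin(-2) + [-sin(-2), 3sin(-2) - 6cos(-2), 3cos(-2)]: the leading term is the value of f, and the coefficients of x, y and z are its partial derivatives with respect to each variable. This result is in line with the example given on the AMS website.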

Adieu <3.