Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyrecest/filters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .abstract_axial_filter import AbstractAxialFilter
from .axial_kalman_filter import AxialKalmanFilter
from .abstract_dummy_filter import AbstractDummyFilter
from .abstract_extended_object_tracker import AbstractExtendedObjectTracker
from .abstract_filter import AbstractFilter
Expand Down Expand Up @@ -51,6 +52,7 @@
__all__ = [
"AbstractDummyFilter",
"AbstractAxialFilter",
"AxialKalmanFilter",
"AbstractExtendedObjectTracker",
"AbstractFilter",
"BinghamFilter",
Expand Down
108 changes: 108 additions & 0 deletions pyrecest/filters/axial_kalman_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# pylint: disable=no-name-in-module,no-member
import copy

# pylint: disable=redefined-builtin
from pyrecest.backend import abs, concatenate, dot, eye, linalg
from pyrecest.distributions import GaussianDistribution

from .abstract_axial_filter import AbstractAxialFilter


class AxialKalmanFilter(AbstractAxialFilter):
"""Kalman Filter for directional estimation with antipodal symmetry.

Works for antipodally symmetric complex numbers (2D unit vectors) and
quaternions (4D unit vectors).

References:
- Gerhard Kurz, Igor Gilitschenski, Simon Julier, Uwe D. Hanebeck,
Recursive Bingham Filter for Directional Estimation Involving 180
Degree Symmetry, Journal of Advances in Information Fusion,
9(2):90-105, December 2014.
"""

def __init__(self):
from pyrecest.backend import array

initial_state = GaussianDistribution(
array([1.0, 0.0, 0.0, 0.0]),
eye(4),
)
AbstractAxialFilter.__init__(self, initial_state)
self._set_composition_operator()

@property
def dim(self):
"""Manifold dimension (1 for complex/circle, 3 for quaternions)."""
return self._filter_state.dim - 1

@property
def filter_state(self):
return self._filter_state

@filter_state.setter
def filter_state(self, new_state):
assert isinstance(
new_state, GaussianDistribution
), "filter_state must be a GaussianDistribution"
assert new_state.mu.shape[0] in (2, 4), "Only 2D and 4D states are supported"
assert (
abs(linalg.norm(new_state.mu) - 1) < 1e-5
), "mean must be a unit vector"
self._filter_state = copy.deepcopy(new_state)
self._set_composition_operator()

def predict_identity(self, gauss_w):
"""Predict assuming identity system model with noise gauss_w.

Computes x(k+1) = x(k) ⊕ w(k), where ⊕ is complex or quaternion
multiplication.

Parameters:
gauss_w (GaussianDistribution): system noise with unit vector mean
"""
assert isinstance(gauss_w, GaussianDistribution)
assert (
abs(linalg.norm(gauss_w.mu) - 1) < 1e-5
), "noise mean must be a unit vector"
mu_new = self.composition_operator(self._filter_state.mu, gauss_w.mu)
C_new = self._filter_state.C + gauss_w.C
self._filter_state = GaussianDistribution(mu_new, C_new, check_validity=False)

def update_identity(self, gauss_v, z):
"""Update assuming identity measurement model with noise gauss_v.

Computes z(k) = x(k) ⊕ v(k), where ⊕ is complex or quaternion
multiplication.

Parameters:
gauss_v (GaussianDistribution): measurement noise with unit vector mean
z (array): measurement as a unit vector of shape (2,) or (4,)
"""
assert isinstance(gauss_v, GaussianDistribution)
assert (
abs(linalg.norm(gauss_v.mu) - 1) < 1e-5
), "noise mean must be a unit vector"
assert gauss_v.mu.shape[0] == self._filter_state.mu.shape[0]
assert z.shape == self._filter_state.mu.shape
assert abs(linalg.norm(z) - 1) < 1e-5, "measurement must be a unit vector"

# Conjugate of noise mean: negate all but the first component
mu_v_conj = concatenate([gauss_v.mu[:1], -gauss_v.mu[1:]])
z = self.composition_operator(mu_v_conj, z)

if dot(z, self._filter_state.mu) < 0:
z = -z

d = self._filter_state.dim # embedding dimension (2 or 4)
IS = self._filter_state.C + gauss_v.C # innovation covariance (H = I)
K = linalg.solve(IS, self._filter_state.C).T # Kalman gain: C @ inv(IS)
mu_new = self._filter_state.mu + K @ (z - self._filter_state.mu)
C_new = (eye(d) - K) @ self._filter_state.C

mu_new = mu_new / linalg.norm(mu_new) # enforce unit vector
self._filter_state = GaussianDistribution(mu_new, C_new, check_validity=False)

def get_point_estimate(self):
"""Return the mean of the current filter state."""
return self._filter_state.mu
225 changes: 225 additions & 0 deletions pyrecest/tests/filters/test_axial_kalman_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
import unittest

import numpy.testing as npt

# pylint: disable=no-name-in-module,no-member
import pyrecest.backend
from pyrecest.backend import array, eye, linalg
from pyrecest.distributions import GaussianDistribution
from pyrecest.filters.abstract_axial_filter import (
_complex_multiplication,
_quaternion_multiplication,
)
from pyrecest.filters.axial_kalman_filter import AxialKalmanFilter


class TestAxialKalmanFilter4D(unittest.TestCase):
def setUp(self):
mu = array([1.0, 2.0, 3.0, 4.0])
mu = mu / linalg.norm(mu)
C = 0.3 * eye(4)
self.mu = mu
self.C = C
self.filter = AxialKalmanFilter()

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_set_state_and_get_estimate(self):
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
est = self.filter.get_point_estimate()
npt.assert_array_equal(self.mu, est)
npt.assert_array_equal(self.C, self.filter.filter_state.C)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_predict_identity_zero_mean(self):
"""Predicting with identity-rotation noise should not change the mean."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
noise_mu = array([1.0, 0.0, 0.0, 0.0])
self.filter.predict_identity(
GaussianDistribution(noise_mu, 0.1 * eye(4))
)
est = self.filter.get_point_estimate()
npt.assert_allclose(self.mu, est, atol=1e-10)
# Covariance should increase
self.assertTrue(
(self.filter.filter_state.C >= self.C).all()
if hasattr((self.filter.filter_state.C >= self.C), "all")
else all(
self.filter.filter_state.C.flatten() >= self.C.flatten()
)
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_predict_identity_nonzero_mean(self):
"""Predicting with non-identity rotation noise updates the mean correctly."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
self.filter.predict_identity(
GaussianDistribution(self.mu, 0.1 * eye(4))
)
est = self.filter.get_point_estimate()
expected = _quaternion_multiplication(self.mu, self.mu)
npt.assert_allclose(est, expected, atol=1e-10)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_update_identity_at_mode(self):
"""Updating with z=mu and identity noise should keep the mean and reduce C."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
z = self.mu
self.filter.update_identity(
GaussianDistribution(array([1.0, 0.0, 0.0, 0.0]), self.C), z
)
est = self.filter.get_point_estimate()
npt.assert_allclose(self.mu, est, atol=1e-6)
# Covariance should decrease
self.assertTrue(
(self.filter.filter_state.C <= self.C).all()
if hasattr((self.filter.filter_state.C <= self.C), "all")
else all(
self.filter.filter_state.C.flatten() <= self.C.flatten()
)
)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_update_identity_antipodal_symmetry(self):
"""z and -z (antipodal) should produce the same result."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
noise = GaussianDistribution(array([1.0, 0.0, 0.0, 0.0]), self.C)
self.filter.update_identity(noise, self.mu)
mu4 = self.filter.get_point_estimate()
C4 = self.filter.filter_state.C

self.filter.filter_state = GaussianDistribution(self.mu, self.C)
self.filter.update_identity(noise, -self.mu)
mu5 = self.filter.get_point_estimate()
C5 = self.filter.filter_state.C

npt.assert_allclose(mu4, mu5, atol=1e-10)
npt.assert_allclose(C4, C5, atol=1e-10)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_update_identity_unit_norm(self):
"""After updating with a non-mode measurement the mean is a unit vector."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
z = array([0.0, 0.0, 0.0, 1.0])
self.filter.update_identity(
GaussianDistribution(self.mu, self.C), z
)
est = self.filter.get_point_estimate()
npt.assert_allclose(linalg.norm(est), 1.0, atol=1e-6)


class TestAxialKalmanFilter2D(unittest.TestCase):
def setUp(self):
mu = array([1.0, 2.0])
mu = mu / linalg.norm(mu)
C = 0.3 * eye(2)
self.mu = mu
self.C = C
self.filter = AxialKalmanFilter()

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_set_state_and_get_estimate(self):
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
est = self.filter.get_point_estimate()
npt.assert_array_equal(self.mu, est)
npt.assert_array_equal(self.C, self.filter.filter_state.C)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_predict_identity_zero_mean(self):
"""Predicting with identity-rotation noise should not change the mean."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
noise_mu = array([1.0, 0.0])
self.filter.predict_identity(
GaussianDistribution(noise_mu, 0.1 * eye(2))
)
est = self.filter.get_point_estimate()
npt.assert_allclose(self.mu, est, atol=1e-10)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_predict_identity_nonzero_mean(self):
"""Predicting with non-identity rotation noise updates the mean correctly."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
self.filter.predict_identity(
GaussianDistribution(self.mu, 0.1 * eye(2))
)
est = self.filter.get_point_estimate()
expected = _complex_multiplication(self.mu, self.mu)
npt.assert_allclose(est, expected, atol=1e-10)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_update_identity_at_mode(self):
"""Updating with z=mu and identity noise should keep the mean and reduce C."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
z = self.mu
self.filter.update_identity(
GaussianDistribution(array([1.0, 0.0]), self.C), z
)
est = self.filter.get_point_estimate()
npt.assert_allclose(self.mu, est, atol=1e-6)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_update_identity_antipodal_symmetry(self):
"""z and -z (antipodal) should produce the same result."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
noise = GaussianDistribution(array([1.0, 0.0]), self.C)
self.filter.update_identity(noise, self.mu)
mu4 = self.filter.get_point_estimate()
C4 = self.filter.filter_state.C

self.filter.filter_state = GaussianDistribution(self.mu, self.C)
self.filter.update_identity(noise, -self.mu)
mu5 = self.filter.get_point_estimate()
C5 = self.filter.filter_state.C

npt.assert_allclose(mu4, mu5, atol=1e-10)
npt.assert_allclose(C4, C5, atol=1e-10)

@unittest.skipIf(
pyrecest.backend.__backend_name__ == "pytorch",
reason="Not supported on this backend", # pylint: disable=no-member
)
def test_update_identity_unit_norm(self):
"""After updating with a non-mode measurement the mean is a unit vector."""
self.filter.filter_state = GaussianDistribution(self.mu, self.C)
z = array([0.0, 1.0])
self.filter.update_identity(
GaussianDistribution(self.mu, self.C), z
)
est = self.filter.get_point_estimate()
npt.assert_allclose(linalg.norm(est), 1.0, atol=1e-10)


if __name__ == "__main__":
unittest.main()
Loading