# Source code for optimal_transport_reweighting
# -*- coding: utf-8 -*-
"""
Created on Sun Dec 12 04:14:07 2021
@author: maout
"""
# optimal transport multidimensional reweighting
from pyemd import emd_with_flow
import numpy as np
from scipy.spatial.distance import pdist, squareform
__all__ = ["reweight_optimal_transport_multidim"]
def reweight_optimal_transport_multidim(samples, weights):
    r"""
    Compute a deterministic transport map for particle reweighting.

    Particle state is multidimensional.

    Reweights particles following the ensemble transform particle filter
    (ETPF) algorithm proposed by `Reich 2013`. Instead of resampling
    particles, ETPF employs optimal transport to compute a deterministic
    particle shift that minimises the expected distance between the
    particles before and after the transformation.

    Parameters
    ------------
    samples: array-like,
            Samples from the distribution, M x dim, with dim>=2.
            Converted to a float64 array internally.
    weights: array-like,
            Weight of each sample, M entries in total (any shape that
            flattens to length M). Converted to a C-contiguous float64
            vector internally, as `pyemd` requires.

    Returns
    --------
    T: numpy.ndarray,
            M x M transport map: the flow returned by
            `pyemd.emd_with_flow` between `weights` and the uniform
            distribution on the samples, multiplied by M.

    Raises
    ------
    ValueError
            If `samples` is not two-dimensional, or if the number of
            weights differs from the number of samples.

    Notes
    -----
    Reference (MATLAB-style) formulation quoted by the original author:

    :math:`CO = X^T \cdot X`

    :math:`CO = diag(CO)*ones(1,M) -2*CO + ones(M,1)*diag(CO)'`

    :math:`[dist,T] = emd(ww,ones(M,1)/M,CO,-1,3)`

    :math:`T = T \cdot M`

    NOTE(review): the reference formulation above builds *squared*
    Euclidean distances, whereas this implementation uses plain Euclidean
    distances as the transport cost. Behaviour is kept as-is; confirm which
    cost is intended.
    """
    samples = np.asarray(samples, dtype=np.float64)
    if samples.ndim != 2:
        raise ValueError(
            "samples must be a 2-D array of shape (M, dim); got "
            f"{samples.ndim} dimension(s)."
        )
    num_samples = samples.shape[0]  # number of particles M

    # pyemd only accepts C-contiguous float64 histograms, so coerce here
    # instead of failing inside the extension on int/float32/strided input.
    source_hist = np.ascontiguousarray(weights, dtype=np.float64).reshape(-1)
    if source_hist.size != num_samples:
        raise ValueError(
            f"weights has {source_hist.size} entries but samples has "
            f"{num_samples} rows."
        )

    # Pairwise Euclidean distances between particles act as transport cost.
    cost_matrix = squareform(pdist(samples, 'euclidean'))
    # Target histogram: uniform distribution on the samples.
    target_hist = np.full(num_samples, 1.0 / num_samples)

    # -1 is pyemd's default extra-mass penalty.
    _, flow = emd_with_flow(source_hist, target_hist, cost_matrix, -1)
    return np.array(flow) * num_samples