from __future__ import annotations
import numpy as np
from ._array_backend import xp
from . import data_stream
[docs]
class RealQuantizer:
    """Quantize real-valued voltages to integer levels.

    Input statistics (mean and standard deviation) are obtained from
    ``data_stream.estimate_stats`` and cached on the instance, so they are
    only re-estimated on calls where the internal call counter is zero.
    """

    def __init__(self,
                 target_mean: float = 0,
                 target_fwhm: float = 32,
                 num_bits: int = 8,
                 stats_calc_period: int = 1,
                 stats_calc_num_samples: int = 10000) -> None:
        """Initialize a real-valued quantizer.

        Args:
            target_mean: Target mean of the quantized values.
            target_fwhm: Target full width at half maximum of the quantized
                values; converted to ``target_std`` with the Gaussian
                relation ``fwhm = 2 * sqrt(2 * ln 2) * std``.
            num_bits: Number of quantization bits.
            stats_calc_period: Number of `quantize()` calls between
                re-estimations of the input statistics. With a value below
                1 the call counter never wraps, so statistics are estimated
                on the first call and then reused until `_reset_cache()`.
            stats_calc_num_samples: Maximum samples used for statistic
                estimates (forwarded to ``data_stream.estimate_stats``).
        """
        self.target_mean = target_mean
        self.target_fwhm = target_fwhm
        self.target_std = self.target_fwhm / (2 * np.sqrt(2 * np.log(2)))
        self.num_bits = num_bits

        # [mean, std] of the most recently analysed input; None until the
        # first call to quantize().
        self.stats_cache = [None, None]
        # Number of quantize() calls since the statistics were last
        # estimated; zero means "estimate on the next call".
        self.stats_calc_indices = 0
        self.stats_calc_period = stats_calc_period
        self.stats_calc_num_samples = stats_calc_num_samples

    def _reset_cache(self) -> None:
        """Clear cached statistics and rewind the call counter."""
        self.stats_calc_indices = 0
        self.stats_cache = [None, None]

    def _set_target_stats(self, target_mean: float, target_std: float) -> None:
        """Set target mean and standard deviation for the quantizer.

        ``target_fwhm`` is updated as well so it stays consistent with the
        new standard deviation.

        Args:
            target_mean: Target mean.
            target_std: Target standard deviation.
        """
        self.target_mean = target_mean
        self.target_std = target_std
        self.target_fwhm = target_std * (2 * np.sqrt(2 * np.log(2)))

    def quantize(self, voltages: xp.ndarray, custom_std: float | None = None) -> xp.ndarray:
        """Quantize real input voltages.

        Args:
            voltages: Array of real voltages.
            custom_std: Optional standard deviation to use for scaling
                instead of the cached estimate. The cached mean is still
                used for centering.

        Returns:
            Array of quantized voltages.
        """
        if self.stats_calc_indices == 0:
            self.stats_cache = data_stream.estimate_stats(voltages, self.stats_calc_num_samples)

        data_std = self.stats_cache[1] if custom_std is None else custom_std

        q_voltages = quantize_real(voltages,
                                   target_mean=self.target_mean,
                                   target_std=self.target_std,
                                   num_bits=self.num_bits,
                                   data_mean=self.stats_cache[0],
                                   data_std=data_std,
                                   stats_calc_num_samples=self.stats_calc_num_samples)

        self.stats_calc_indices += 1
        # Wrap once the counter reaches the period. A range check is used
        # instead of strict equality so that lowering `stats_calc_period`
        # after the counter has already passed it cannot freeze the cache
        # forever. A non-positive period still never wraps (estimate once).
        if 0 < self.stats_calc_period <= self.stats_calc_indices:
            self.stats_calc_indices = 0
        return q_voltages

    def digitize(self, voltages: xp.ndarray, custom_std: float | None = None) -> xp.ndarray:
        """Alias for `quantize()`.

        Args:
            voltages: Array of real voltages.
            custom_std: Optional standard deviation to use for scaling.

        Returns:
            Array of quantized voltages.
        """
        return self.quantize(voltages, custom_std=custom_std)
[docs]
class ComplexQuantizer:
    """Quantize complex voltages using paired real-valued quantizers.

    The real and imaginary components are handled by two independent
    `RealQuantizer` instances built with identical settings.
    """

    def __init__(self,
                 target_mean: float = 0,
                 target_fwhm: float = 32,
                 num_bits: int = 8,
                 stats_calc_period: int = 1,
                 stats_calc_num_samples: int = 10000) -> None:
        """Initialize a complex-valued quantizer.

        Args:
            target_mean: Target mean of the quantized values.
            target_fwhm: Target full width at half maximum; converted to
                ``target_std`` with the Gaussian relation
                ``fwhm = 2 * sqrt(2 * ln 2) * std``.
            num_bits: Number of quantization bits.
            stats_calc_period: Period for recomputing input statistics
                (forwarded to both component quantizers).
            stats_calc_num_samples: Maximum samples used for statistic estimates.
        """
        self.target_mean = target_mean
        self.target_fwhm = target_fwhm
        self.target_std = self.target_fwhm / (2 * np.sqrt(2 * np.log(2)))
        self.num_bits = num_bits

        # Mirrors of the component quantizers' caches, refreshed after
        # every call to quantize().
        self.stats_cache_r = [None, None]
        self.stats_cache_i = [None, None]
        self.stats_calc_period = stats_calc_period
        self.stats_calc_num_samples = stats_calc_num_samples

        self.quantizer_r = RealQuantizer(target_mean=target_mean,
                                         target_fwhm=target_fwhm,
                                         num_bits=num_bits,
                                         stats_calc_period=stats_calc_period,
                                         stats_calc_num_samples=stats_calc_num_samples)
        self.quantizer_i = RealQuantizer(target_mean=target_mean,
                                         target_fwhm=target_fwhm,
                                         num_bits=num_bits,
                                         stats_calc_period=stats_calc_period,
                                         stats_calc_num_samples=stats_calc_num_samples)

    def _reset_cache(self) -> None:
        """Clear cached statistics here and in both component quantizers."""
        self.stats_cache_r = [None, None]
        self.stats_cache_i = [None, None]
        self.quantizer_r._reset_cache()
        self.quantizer_i._reset_cache()

    def quantize(self,
                 voltages: xp.ndarray,
                 custom_stds: float | list[float | None] | tuple[float | None, float | None] | None = None) -> xp.ndarray:
        """Quantize complex input voltages.

        Args:
            voltages: Array of complex voltages.
            custom_stds: Optional scalar (including a zero-dimensional
                array) applied to both components, or a length-two sequence
                of standard deviations for the real and imaginary parts.

        Returns:
            Array of complex quantized voltages.

        Raises:
            ValueError: If `custom_stds` is not scalar or length two.
        """
        # `np.ndim(...) == 0` is used rather than `np.isscalar`, because the
        # latter is False for 0-d arrays, which would then fail in `len()`.
        if custom_stds is None or np.ndim(custom_stds) == 0:
            std_r = std_i = custom_stds
        else:
            if len(custom_stds) != 2:
                raise ValueError("custom_stds must be a scalar or a length-2 sequence.")
            std_r, std_i = custom_stds[0], custom_stds[1]

        q_r = self.quantizer_r.quantize(xp.real(voltages), custom_std=std_r)
        q_i = self.quantizer_i.quantize(xp.imag(voltages), custom_std=std_i)

        # Expose the statistics the component quantizers just used.
        self.stats_cache_r = self.quantizer_r.stats_cache
        self.stats_cache_i = self.quantizer_i.stats_cache

        return q_r + q_i * 1j
[docs]
def quantize_real(x: xp.ndarray,
                  target_mean: float = 0,
                  target_std: float = 32 / (2 * np.sqrt(2 * np.log(2))),
                  num_bits: int = 8,
                  data_mean: float | None = None,
                  data_std: float | None = None,
                  stats_calc_num_samples: int = 10000) -> xp.ndarray:
    """Quantize real voltages to integer levels.

    The input is shifted by ``data_mean``, scaled by
    ``target_std / data_std``, offset by ``target_mean``, rounded, clipped
    to the signed ``num_bits`` range ``[-2**(num_bits - 1),
    2**(num_bits - 1) - 1]`` and cast to ``int``.

    Args:
        x: Array of voltages.
        target_mean: Target mean for quantized voltages.
        target_std: Target standard deviation for quantized voltages.
        num_bits: Number of quantization bits.
        data_mean: Optional precomputed input mean. Estimated from `x` when
            omitted.
        data_std: Optional precomputed input standard deviation. Estimated
            from `x` when omitted.
        stats_calc_num_samples: Maximum samples used for statistic estimates.

    Returns:
        Array of quantized voltages.
    """
    # Fill in whichever statistic the caller did not supply. Checking both
    # (not just `data_std`) avoids `x - None` when only `data_std` is given,
    # and keeps a caller-supplied mean from being silently overwritten.
    if data_mean is None or data_std is None:
        est_mean, est_std = data_stream.estimate_stats(x, stats_calc_num_samples)
        if data_mean is None:
            data_mean = est_mean
        if data_std is None:
            data_std = est_std

    # A zero-spread input cannot be rescaled; collapse it onto target_mean.
    factor = 0 if data_std == 0 else target_std / data_std

    q_voltages = xp.around(factor * (x - data_mean) + target_mean)
    lowest, highest = -2**(num_bits - 1), 2**(num_bits - 1) - 1
    q_voltages = xp.clip(q_voltages, lowest, highest)
    return q_voltages.astype(int)
[docs]
def quantize_complex(x: xp.ndarray,
                     target_mean: float = 0,
                     target_std: float = 32 / (2 * np.sqrt(2 * np.log(2))),
                     num_bits: int = 8,
                     stats_calc_num_samples: int = 10000) -> xp.ndarray:
    """Quantize complex voltages to integer levels.

    The real and imaginary components are each passed through
    `quantize_real` with the same settings, then recombined into a single
    complex array.

    Args:
        x: Array of complex voltages.
        target_mean: Target mean for quantized voltages.
        target_std: Target standard deviation for quantized voltages.
        num_bits: Number of quantization bits.
        stats_calc_num_samples: Maximum samples used for statistic estimates.

    Returns:
        Array of complex quantized voltages.
    """
    # Settings shared by both component quantizations.
    settings = dict(target_mean=target_mean,
                    target_std=target_std,
                    num_bits=num_bits,
                    stats_calc_num_samples=stats_calc_num_samples)

    components = (xp.real(x), xp.imag(x))
    quantized_real, quantized_imag = (
        quantize_real(part, **settings) for part in components
    )
    return quantized_real + quantized_imag * 1j