Skip to content

Instrument models

ConstantGain

Bases: GainModel

A constant gain model

Source code in src/jaxspec/model/instrument.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
class ConstantGain(GainModel):
    """
    Gain model whose gain is a single, energy-independent value drawn from a prior.
    """

    def __init__(self, prior_distribution: Distribution):
        """
        Parameters:
            prior_distribution: the prior from which the gain value is sampled.
        """
        self.prior_distribution = prior_distribution

    def numpyro_model(self, observation_name: str):
        """
        Sample the gain of one observation and wrap it in a callable.

        Intended to be called inside a numpyro model. It registers the sample site
        ``ins/~/gain_<observation_name>``.

        Parameters:
            observation_name: the observation name, used to build the sample site name.

        Returns:
            A function of ``energy`` that returns the sampled gain value.
        """
        site_name = f"ins/~/gain_{observation_name}"
        sampled_gain = numpyro.sample(site_name, self.prior_distribution)

        def constant_gain(energy):
            # `energy` is unused: the gain is the same at every energy.
            return sampled_gain

        return constant_gain

__init__(prior_distribution)

Parameters:

Name Type Description Default
prior_distribution Distribution

the prior distribution for the gain value.

required
Source code in src/jaxspec/model/instrument.py
25
26
27
28
29
30
31
def __init__(self, prior_distribution: Distribution):
    """
    Store the prior used for the constant gain.

    Parameters:
        prior_distribution: the prior from which the gain value is sampled.
    """
    self.prior_distribution = prior_distribution

ConstantShift

Bases: ShiftModel

A constant shift model

Source code in src/jaxspec/model/instrument.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
class ConstantShift(ShiftModel):
    """
    Shift model that adds a single, energy-independent offset drawn from a prior.
    """

    def __init__(self, prior_distribution: Distribution):
        """
        Parameters:
            prior_distribution: the prior from which the shift offset is sampled.
        """
        self.prior_distribution = prior_distribution

    def numpyro_model(self, observation_name: str):
        """
        Sample the shift offset of one observation and wrap it in a callable.

        Intended to be called inside a numpyro model. It registers the sample site
        ``ins/~/shift_<observation_name>``.

        Parameters:
            observation_name: the observation name, used to build the sample site name.

        Returns:
            A function mapping ``energy`` to ``energy + offset``.
        """
        site_name = f"ins/~/shift_{observation_name}"
        sampled_offset = numpyro.sample(site_name, self.prior_distribution)

        def constant_shift(energy):
            return energy + sampled_offset

        return constant_shift

__init__(prior_distribution)

Parameters:

Name Type Description Default
prior_distribution Distribution

the prior distribution for the shift value.

required
Source code in src/jaxspec/model/instrument.py
57
58
59
60
61
62
def __init__(self, prior_distribution: Distribution):
    """
    Store the prior used for the constant shift.

    Parameters:
        prior_distribution: the prior from which the shift offset is sampled.
    """
    self.prior_distribution = prior_distribution

GainModel

Bases: ABC, Module

Generic class for a gain model

Source code in src/jaxspec/model/instrument.py
10
11
12
13
14
15
16
17
class GainModel(ABC, nnx.Module):
    """
    Abstract base class for gain models.

    Subclasses must implement :meth:`numpyro_model`.
    """

    @abstractmethod
    def numpyro_model(self, observation_name: str):
        """
        Build the gain for the given observation.

        Concrete implementations (e.g. ``ConstantGain``) sample their parameters with
        numpyro and return a callable that takes an energy argument.

        Parameters:
            observation_name: the observation the gain applies to.
        """

InstrumentModel

Bases: Module

Source code in src/jaxspec/model/instrument.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
class InstrumentModel(nnx.Module):
    """
    Instrument model combining an optional gain model and an optional shift model.

    The reference observation receives no gain or shift. Every other observation gets
    whichever of the two models are configured.
    """

    def __init__(
        self,
        reference_observation_name: str,
        gain_model: GainModel | None = None,
        shift_model: ShiftModel | None = None,
    ):
        """
        Encapsulate an instrument model, built as a combination of a shift model and a gain model.

        Parameters:
            reference_observation_name : The observation to use as a reference
            gain_model : The gain model (``None`` disables the gain)
            shift_model : The shift model (``None`` disables the shift)
        """
        self.reference, self.gain_model, self.shift_model = (
            reference_observation_name,
            gain_model,
            shift_model,
        )

    def get_gain_and_shift_model(
        self, observation_name: str
    ) -> tuple[Callable | None, Callable | None]:
        """
        Return the ``(gain, shift)`` callables for the given observation.

        It should be called within a numpyro model, because the underlying models sample
        their parameters when invoked.

        Parameters:
            observation_name: the observation to build the gain and shift for.

        Returns:
            ``(None, None)`` for the reference observation. Otherwise, each element is the
            callable returned by the corresponding model, or ``None`` if that model is not set.
        """
        if observation_name == self.reference:
            # The reference observation is left uncorrected.
            return None, None

        # Gain is built before shift, preserving the order of numpyro sample sites.
        gain = None
        if self.gain_model is not None:
            gain = self.gain_model.numpyro_model(observation_name)

        shift = None
        if self.shift_model is not None:
            shift = self.shift_model.numpyro_model(observation_name)

        return gain, shift

__init__(reference_observation_name, gain_model=None, shift_model=None)

Encapsulate an instrument model, built as a combination of a shift model and a gain model.

Parameters:

Name Type Description Default
reference_observation_name

The observation to use as a reference

required
gain_model

The gain model

None
shift_model

The shift model

None
Source code in src/jaxspec/model/instrument.py
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
def __init__(
    self,
    reference_observation_name: str,
    gain_model: GainModel | None = None,
    shift_model: ShiftModel | None = None,
):
    """
    Encapsulate an instrument model, built as a combination of a shift model and a gain model.

    Parameters:
        reference_observation_name : The observation to use as a reference
        gain_model : The gain model (``None`` disables the gain)
        shift_model : The shift model (``None`` disables the shift)
    """
    self.reference, self.gain_model, self.shift_model = (
        reference_observation_name,
        gain_model,
        shift_model,
    )

get_gain_and_shift_model(observation_name)

Return the gain and shift models for the given observation. It should be called within a numpyro model.

Source code in src/jaxspec/model/instrument.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
def get_gain_and_shift_model(
    self, observation_name: str
) -> tuple[Callable | None, Callable | None]:
    """
    Return the gain and shift models for the given observation. It should be called within a numpyro model.
    """

    if observation_name == self.reference:
        return None, None

    else:
        gain = (
            self.gain_model.numpyro_model(observation_name)
            if self.gain_model is not None
            else None
        )
        shift = (
            self.shift_model.numpyro_model(observation_name)
            if self.shift_model is not None
            else None
        )

        return gain, shift

ShiftModel

Bases: ABC, Module

Generic class for a shift model

Source code in src/jaxspec/model/instrument.py
42
43
44
45
46
47
48
49
class ShiftModel(ABC, nnx.Module):
    """
    Abstract base class for shift models.

    Subclasses must implement :meth:`numpyro_model`.
    """

    @abstractmethod
    def numpyro_model(self, observation_name: str):
        """
        Build the shift for the given observation.

        Concrete implementations (e.g. ``ConstantShift``) sample their parameters with
        numpyro and return a callable that takes an energy argument.

        Parameters:
            observation_name: the observation the shift applies to.
        """