"""
Module defining transfer functions
"""
from typing import List, Optional, Dict, Any, Union
from pydantic import validator, constr
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from resistics.common import Metadata
[docs]class Component(Metadata):
"""
Data class for a single component in a Transfer function
Example
-------
>>> from resistics.transfunc import Component
>>> component = Component(real=[1, 2, 3, 4, 5], imag=[-5, -4, -3, -2 , -1])
>>> component.get_value(0)
(1-5j)
>>> component.to_numpy()
array([1.-5.j, 2.-4.j, 3.-3.j, 4.-2.j, 5.-1.j])
"""
real: List[float]
"""The real part of the component"""
imag: List[float]
"""The complex part of the component"""
[docs] def get_value(self, eval_idx: int) -> complex:
"""Get the value for an evaluation frequency"""
return self.real[eval_idx] + 1j * self.imag[eval_idx]
[docs] def to_numpy(self) -> np.ndarray:
"""Get the component as a numpy complex array"""
return np.array(self.real) + 1j * np.array(self.imag)
[docs]def get_component_key(out_chan: str, in_chan: str) -> str:
"""
Get key for out channel and in channel combination in the solution
Parameters
----------
out_chan : str
The output channel
in_chan : str
The input channel
Returns
-------
str
The component key
Examples
--------
>>> from resistics.regression import get_component_key
>>> get_component_key("Ex", "Hy")
'ExHy'
"""
return f"{out_chan}{in_chan}"
[docs]class TransferFunction(Metadata):
"""
Define a generic transfer function
This class is a describes generic transfer function, including:
- The output channels for the transfer function
- The input channels for the transfer function
- The cross channels for the transfer function
The cross channels are the channels that will be used to calculate out the
cross powers for the regression.
This generic parent class has no implemented plotting function. However,
child classes may have a plotting function as different transfer functions
may need different types of plots.
.. note::
Users interested in writing a custom transfer function should inherit
from this generic Transfer function
See Also
--------
ImpandanceTensor : Transfer function for the MT impedance tensor
Tipper : Transfer function for the MT tipper
Examples
--------
A generic example
>>> tf = TransferFunction(variation="example", out_chans=["bye", "see you", "ciao"], in_chans=["hello", "hi_there"])
>>> print(tf.to_string())
| bye | | bye_hello bye_hi_there | | hello |
| see you | = | see you_hello see you_hi_there | | hi_there |
| ciao | | ciao_hello ciao_hi_there |
Combining the impedance tensor and the tipper into one TransferFunction
>>> tf = TransferFunction(variation="combined", out_chans=["Ex", "Ey"], in_chans=["Hx", "Hy", "Hz"])
>>> print(tf.to_string())
| Ex | | Ex_Hx Ex_Hy Ex_Hz | | Hx |
| Ey | = | Ey_Hx Ey_Hy Ey_Hz | | Hy |
| Hz |
"""
_types: Dict[str, type] = {}
"""Store types which will help automatic instantiation"""
name: Optional[str] = None
"""The name of the transfer function, this will be set automatically"""
variation: constr(max_length=16) = "generic"
"""A short additional bit of information about this variation"""
out_chans: List[str]
"""The output channels"""
in_chans: List[str]
"""The input channels"""
cross_chans: Optional[List[str]] = None
"""The channels to use for calculating the cross spectra"""
n_out: Optional[int] = None
"""The number of output channels"""
n_in: Optional[int] = None
"""The number of input channels"""
n_cross: Optional[int] = None
"""The number of cross power channels"""
def __init_subclass__(cls) -> None:
"""
Used to automatically register child transfer functions in `_types`
When a TransferFunction child class is imported, it is added to the base
TransferFunction _types variable. Later, this dictionary of class types
can be used to initialise a specific child transfer function from a
dictonary as long as that specific child transfer fuction has already
been imported and it is called from a pydantic class that will validate
the inputs.
The intention of this method is to support initialising transfer
functions from JSON files. This is a similar approach to
ResisticsProcess.
"""
cls._types[cls.__name__] = cls
@classmethod
def __get_validators__(cls):
"""Get the validators that will be used by pydantic"""
yield cls.validate
[docs] @classmethod
def validate(
cls, value: Union["TransferFunction", Dict[str, Any]]
) -> "TransferFunction":
"""
Validate a TransferFunction
Parameters
----------
value : Union[TransferFunction, Dict[str, Any]]
A TransferFunction child class or a dictionary
Returns
-------
TransferFunction
A TransferFunction or TransferFunction child class
Raises
------
ValueError
If the value is neither a TransferFunction or a dictionary
KeyError
If name is not in the dictionary
ValueError
If initialising from dictionary fails
Examples
--------
The following example will show how a child TransferFunction class
can be instantiated using a dictionary and the parent TransferFunction
(but only as long as that child class has been imported).
>>> from resistics.transfunc import TransferFunction
Show known TransferFunction types in built into resistics
>>> for entry in TransferFunction._types.items():
... print(entry)
('ImpedanceTensor', <class 'resistics.transfunc.ImpedanceTensor'>)
('Tipper', <class 'resistics.transfunc.Tipper'>)
Now let's initialise an ImpedanceTensor from the base TransferFunction
and a dictionary.
>>> mytf = {"name": "ImpedanceTensor", "variation": "ecross", "cross_chans": ["Ex", "Ey"]}
>>> test = TransferFunction(**mytf)
Traceback (most recent call last):
...
KeyError: 'out_chans'
This is not quite what we were expecting. The generic TransferFunction
requires out_chans to be defined, but they are not in the dictionary as
the ImpedanceTensor child class defaults these. To get this to work,
instead use the validate class method. This is the class method used by
pydantic when instantiating.
>>> mytf = {"name": "ImpedanceTensor", "variation": "ecross", "cross_chans": ["Ex", "Ey"]}
>>> test = TransferFunction.validate(mytf)
>>> test.summary()
{
'name': 'ImpedanceTensor',
'variation': 'ecross',
'out_chans': ['Ex', 'Ey'],
'in_chans': ['Hx', 'Hy'],
'cross_chans': ['Ex', 'Ey'],
'n_out': 2,
'n_in': 2,
'n_cross': 2
}
That's more like it. This will raise errors if an unknown type of
TransferFunction is received.
>>> mytf = {"name": "NewTF", "cross_chans": ["Ex", "Ey"]}
>>> test = TransferFunction.validate(mytf)
Traceback (most recent call last):
...
ValueError: Unable to initialise NewTF from dictionary
Or if the dictionary does not have a name key
>>> mytf = {"cross_chans": ["Ex", "Ey"]}
>>> test = TransferFunction.validate(mytf)
Traceback (most recent call last):
...
KeyError: 'No name provided for initialisation of TransferFunction'
Unexpected inputs will also raise an error
>>> test = TransferFunction.validate(5)
Traceback (most recent call last):
...
ValueError: TransferFunction unable to initialise from <class 'int'>
"""
if isinstance(value, TransferFunction):
return value
if not isinstance(value, dict):
raise ValueError(
f"TransferFunction unable to initialise from {type(value)}"
)
if "name" not in value:
raise KeyError("No name provided for initialisation of TransferFunction")
# check if it is a TransferFunction
name = value.pop("name")
if name == "TransferFunction":
return cls(**value)
# check other known Transfer Functions
try:
return cls._types[name](**value)
except Exception:
raise ValueError(f"Unable to initialise {name} from dictionary")
@validator("name", always=True)
def validate_name(cls, value: Union[str, None]) -> str:
"""Inialise the name attribute of the transfer function"""
if value is None:
return cls.__name__
return value
@validator("cross_chans", always=True)
def validate_cross_chans(
cls, value: Union[None, List[str]], values: Dict[str, Any]
) -> List[str]:
"""Validate cross spectra channels"""
if value is None:
return values["in_chans"]
return value
@validator("n_out", always=True)
def validate_n_out(cls, value: Union[None, int], values: Dict[str, Any]) -> int:
"""Validate number of output channels"""
if value is None:
return len(values["out_chans"])
return value
@validator("n_in", always=True)
def validate_n_in(cls, value: Union[None, int], values: Dict[str, Any]) -> int:
"""Validate number of input channels"""
if value is None:
return len(values["in_chans"])
return value
@validator("n_cross", always=True)
def validate_n_cross(cls, value: Union[None, int], values: Dict[str, Any]) -> int:
"""Validate number of cross channels"""
if value is None:
return len(values["cross_chans"])
return value
[docs] def n_eqns_per_output(self) -> int:
"""Get the number of equations per output"""
return len(self.cross_chans)
[docs] def n_regressors(self) -> int:
"""Get the number of regressors"""
return self.n_in
[docs] def to_string(self):
"""Get the transfer function as as string"""
n_lines = max(len(self.in_chans), len(self.out_chans))
lens = [len(x) for x in self.in_chans] + [len(x) for x in self.out_chans]
max_len = max(lens)
line_equals = (n_lines - 1) // 2
outstr = ""
for il in range(n_lines):
out_chan = self._out_chan_string(il, max_len)
in_chan = self._in_chan_string(il, max_len)
tensor = self._tensor_string(il, max_len)
eq = "=" if il == line_equals else " "
outstr += f"{out_chan} {eq} {tensor} {in_chan}\n"
return outstr.rstrip("\n")
def _out_chan_string(self, il: int, max_len: int) -> str:
"""Get the out channels string"""
if il >= self.n_out:
empty_len = max_len + 4
return f"{'':{empty_len}s}"
return f"| { self.out_chans[il]:{max_len}s} |"
def _in_chan_string(self, il: int, max_len: int) -> str:
"""Get the in channel string"""
if il >= self.n_in:
return ""
return f"| { self.in_chans[il]:{max_len}s} |"
def _tensor_string(self, il: int, max_len: int) -> str:
"""Get the tensor string"""
if il >= self.n_out:
element_len = ((max_len * 2 + 1) + 1) * self.n_in + 3
return f"{'':{element_len}s}"
elements = "| "
for chan in self.in_chans:
component = f"{self.out_chans[il]}_{chan}"
elements += f"{component:{2*max_len + 1}s} "
elements += "|"
return elements
[docs]class ImpedanceTensor(TransferFunction):
"""
Standard magnetotelluric impedance tensor
Notes
-----
Information about data units
- Magnetic permeability in nT . m / A
- Electric (E) data is in mV/m
- Magnetic (H) data is in nT
- Z = E/H is in mV / m . nT
- Units of resistance = Ohm = V / A
Examples
--------
>>> from resistics.transfunc import ImpedanceTensor
>>> tf = ImpedanceTensor()
>>> print(tf.to_string())
| Ex | = | Ex_Hx Ex_Hy | | Hx |
| Ey | | Ey_Hx Ey_Hy | | Hy |
"""
variation: constr(max_length=16) = "default"
out_chans: List[str] = ["Ex", "Ey"]
in_chans: List[str] = ["Hx", "Hy"]
[docs] @staticmethod
def get_resistivity(periods: np.ndarray, component: Component) -> np.ndarray:
"""
Get apparent resistivity for a component
Parameters
----------
periods : np.ndarray
The periods of the component
component : Component
The component values
Returns
-------
np.ndarray
Apparent resistivity
"""
squared = np.power(np.absolute(component.to_numpy()), 2)
return 0.2 * periods * squared
[docs] @staticmethod
def get_phase(key: str, component: Component) -> np.ndarray:
"""
Get the phase for the component
.. note::
Components ExHx and ExHy are wrapped around in [0,90]
Parameters
----------
key : str
The component name
component : Component
The component values
Returns
-------
np.ndarray
The phase values
"""
phase = np.angle(component.to_numpy())
# unwrap into specific quadrant and convert to degrees
phase = np.unwrap(phase) * 180 / np.pi
if key == "ExHx" or key == "ExHy":
phase = np.mod(phase, 360) - 180
return phase
[docs] @staticmethod
def get_fig(
x_lim: Optional[List[float]] = None,
res_lim: Optional[List[float]] = None,
phs_lim: Optional[List[float]] = None,
) -> go.Figure:
"""
Get a figure for plotting the ImpedanceTensor
Parameters
----------
x_lim : Optional[List[float]], optional
The x limits, to be provided as powers of 10, by default None. For
example, for 0.001, use -3
res_lim : Optional[List[float]], optional
The y limits for resistivity, to be provided as powers of 10, by
default None. For example, for 1000, use 3
phs_lim : Optional[List[float]], optional
The phase limits, by default None
Returns
-------
go.Figure
Plotly figure
"""
from resistics.plot import PLOTLY_MARGIN, PLOTLY_TEMPLATE
fig = make_subplots(
rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0.08,
subplot_titles=["Apparent resistivity", "Phase"],
)
# apparent resistivity axes
fig.update_xaxes(type="log", showticklabels=True, row=1, col=1)
fig.update_yaxes(title_text="App. resistivity (Ohm m)", row=1, col=1)
fig.update_yaxes(type="log", row=1, col=1)
if x_lim is not None:
fig.update_xaxes(range=x_lim, row=1, col=1)
if res_lim is not None:
fig.update_yaxes(range=res_lim, row=1, col=1)
# phase axes
fig.update_xaxes(title_text="Period (s)", type="log", row=2, col=1)
fig.update_xaxes(showticklabels=True, row=2, col=1)
# fig.update_yaxes(scaleanchor="x", scaleratio=1, row=1, col=1)
fig.update_yaxes(title_text="Phase (degrees)", row=2, col=1)
if phs_lim is not None:
fig.update_yaxes(range=phs_lim, row=2, col=1)
# update the layout
fig.update_layout(template=PLOTLY_TEMPLATE, margin=dict(PLOTLY_MARGIN))
return fig
[docs] @staticmethod
def plot(
freqs: List[float],
components: Dict[str, Component],
fig: Optional[go.Figure] = None,
to_plot: Optional[List[str]] = None,
legend: str = "Impedance tensor",
x_lim: Optional[List[float]] = None,
res_lim: Optional[List[float]] = None,
phs_lim: Optional[List[float]] = None,
symbol: Optional[str] = "circle",
) -> go.Figure:
"""
Plot the Impedance tensor
Parameters
----------
freqs : List[float]
The frequencies where the impedance tensor components have been
calculated
components : Dict[str, Component]
The component data
fig : Optional[go.Figure], optional
Figure to add to, by default None
to_plot : Optional[List[str]], optional
The components to plot, by default all of the components of the
impedance tensor
legend : str, optional
Legend prefix for the components, by default "Impedance tensor"
x_lim : Optional[List[float]], optional
The x limits, to be provided as powers of 10, by default None. For
example, for 0.001, use -3. Only used when a figure is not provided.
res_lim : Optional[List[float]], optional
The y limits for resistivity, to be provided as powers of 10, by
default None. For example, for 1000, use 3. Only used when a figure
is not provided.
phs_lim : Optional[List[float]], optional
The phase limits, by default None. Only used when a figure is not
provided.
symbol : Optional[str], optional
The marker symbol to use, by default "circle"
Returns
-------
go.Figure
[description]
"""
if fig is None:
fig = ImpedanceTensor.get_fig(x_lim=x_lim, res_lim=res_lim, phs_lim=phs_lim)
if to_plot is None:
to_plot = ["ExHy", "EyHx", "ExHx", "EyHy"]
periods = np.reciprocal(freqs)
colors = {"ExHx": "orange", "EyHy": "green", "ExHy": "red", "EyHx": "blue"}
for comp in to_plot:
res = ImpedanceTensor.get_resistivity(periods, components[comp])
phs = ImpedanceTensor.get_phase(comp, components[comp])
comp_legend = f"{legend} - {comp}"
scatter = go.Scatter(
x=periods,
y=res,
mode="lines+markers",
marker=dict(color=colors[comp], symbol=symbol),
line=dict(color=colors[comp]),
name=comp_legend,
legendgroup=comp_legend,
)
fig.add_trace(scatter, row=1, col=1)
scatter = go.Scatter(
x=periods,
y=phs,
mode="lines+markers",
marker=dict(color=colors[comp], symbol=symbol),
line=dict(color=colors[comp]),
name=comp_legend,
legendgroup=comp_legend,
showlegend=False,
)
fig.add_trace(scatter, row=2, col=1)
return fig
[docs]class Tipper(TransferFunction):
"""
Magnetotelluric tipper
The tipper components are Tx = HzHx and Ty = HzHy
The tipper length is sqrt(Re(Tx)^2 + Re(Ty)^2)
The tipper angle is arctan (Re(Ty)/Re(Tx))
Notes
-----
Information about units
- Tipper T = H/H is dimensionless
Examples
--------
>>> from resistics.transfunc import Tipper
>>> tf = Tipper()
>>> print(tf.to_string())
| Hz | = | Hz_Hx Hz_Hy | | Hx |
| Hy |
"""
variation: constr(max_length=16) = "default"
out_chans: List[str] = ["Hz"]
in_chans: List[str] = ["Hx", "Hy"]
[docs] def get_length(self, components: Dict[str, Component]) -> np.ndarray:
"""Get the tipper length"""
txRe = components["HzHx"].real
tyRe = components["HzHy"].real
return np.sqrt(np.power(txRe, 2) + np.power(tyRe, 2))
[docs] def get_real_angle(self, components: Dict[str, Component]) -> np.ndarray:
"""Get the real angle"""
txRe = np.array(components["HzHx"].real)
tyRe = np.array(components["HzHy"].real)
return np.arctan(tyRe / txRe) * 180 / np.pi
[docs] def get_imag_angle(self, components: Dict[str, Component]) -> np.ndarray:
"""Get the imaginary angle"""
txIm = np.array(components["HzHx"].imag)
tyIm = np.array(components["HzHy"].imag)
return np.arctan(tyIm / txIm) * 180 / np.pi
[docs] def plot(
self,
freqs: List[float],
components: Dict[str, Component],
x_lim: Optional[List[float]] = None,
len_lim: Optional[List[float]] = None,
ang_lim: Optional[List[float]] = None,
) -> go.Figure:
"""
Plot the impedance tensor
.. warning::
This probably needs further checking and verification
Parameters
----------
freqs : List[float]
The x axis frequencies
components : Dict[str, Component]
The component data
x_lim : Optional[List[float]], optional
The x limits, to be provided as powers of 10, by default None. For
example, for 0.001, use -3
len_lim : Optional[List[float]], optional
The y limits for tipper length, to be provided as powers of 10, by
default None. For example, for 1000, use 3
ang_lim : Optional[List[float]], optional
The angle limits, by default None
Returns
-------
go.Figure
Plotly figure
"""
import warnings
from plotly.subplots import make_subplots
warnings.warn("Plotting of tippers needs further verification")
periods = np.reciprocal(freqs)
if x_lim is None:
x_lim = [-3, 5]
if len_lim is None:
len_lim = [-2, 6]
if ang_lim is None:
ang_lim = [-10, 100]
fig = make_subplots(
rows=2,
cols=1,
shared_xaxes=True,
vertical_spacing=0.08,
subplot_titles=["Length", "Angles"],
)
fig.update_layout(width=1000, autosize=True)
# x axes
fig.update_xaxes(title_text="Period (s)", type="log", range=x_lim, row=1, col=1)
fig.update_xaxes(showticklabels=True, row=1, col=1)
fig.update_xaxes(title_text="Period (s)", type="log", range=x_lim, row=2, col=1)
fig.update_xaxes(showticklabels=True, row=2, col=1)
# y axes
fig.update_yaxes(title_text="Tipper length", row=1, col=1)
# fig.update_yaxes(type="log", row=1, col=1)
# fig.update_yaxes(scaleanchor="x", scaleratio=1, row=1, col=1)
fig.update_yaxes(title_text="Angle (degrees)", row=2, col=1)
# plot the tipper length
scatter = go.Scatter(
x=periods,
y=self.get_length(components),
mode="lines+markers",
marker=dict(color="red"),
line=dict(color="red"),
name="Tipper length",
)
fig.add_trace(scatter, row=1, col=1)
# plot the real angle
scatter = go.Scatter(
x=periods,
y=self.get_real_angle(components),
mode="lines+markers",
marker=dict(color="green"),
line=dict(color="green"),
name="Real angle",
)
fig.add_trace(scatter, row=2, col=1)
# plot the imag angle
scatter = go.Scatter(
x=periods,
y=self.get_imag_angle(components),
mode="lines+markers",
marker=dict(color="blue"),
line=dict(color="blue"),
name="Imag angle",
)
fig.add_trace(scatter, row=2, col=1)
return fig