pado.optical_element.OpticalElement
- class OpticalElement(dim, pitch, wvl, field_change=None, device='cpu', name='not defined', polar='non')
- __init__(dim, pitch, wvl, field_change=None, device='cpu', name='not defined', polar='non')
Base class for optical elements that modify incident light wavefront.
The wavefront modification is stored as amplitude and phase tensors. Note that the number of channels is one for wavefront modulation.
- Parameters:
dim (tuple) – Dimensions (B, 1, R, C) for batch size, channels, rows, columns
pitch (float) – Pixel pitch in meters
wvl (float) – Wavelength of light in meters
field_change (torch.Tensor, optional) – Complex wavefront modification tensor of shape (B, 1, R, C)
device (str) – Device to store wavefront (‘cpu’, ‘cuda:0’, etc.)
name (str) – Name identifier for this optical element
polar (str) – Polarization mode (‘non’: scalar, ‘polar’: vector)
- Return type:
None
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.field_change.shape
torch.Size([1, 1, 100, 100])
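The class description treats the modulation as a complex quantity with an amplitude part and a phase part. The NumPy sketch below illustrates that representation on a (B, 1, R, C) array: it builds t = A·exp(iφ) and recovers A and φ with `np.abs` and `np.angle`, which is presumably what `get_amplitude_change()` and `get_phase_change()` report. This is an illustration under that assumption, not pado's implementation; pado works on torch tensors, and whether it keeps one complex tensor or separate amplitude and phase tensors internally is not shown here.

```python
import numpy as np

# Illustrative (B, 1, R, C) modulation; pado itself would hold torch tensors
B, R, C = 1, 4, 4
amplitude = np.full((B, 1, R, C), 0.5)                          # transmitted magnitude
phase = np.linspace(0.0, np.pi / 2, R * C).reshape(B, 1, R, C)  # phase delay in radians

# Combine both polar components into one complex modulation t = A * exp(i * phi)
field_change = amplitude * np.exp(1j * phase)

# Recover the components; np.angle returns values in (-pi, pi], so phases
# outside that interval would come back wrapped
amp_back = np.abs(field_change)
phase_back = np.angle(field_change)

print(field_change.shape)                # (1, 1, 4, 4)
print(np.allclose(amp_back, amplitude))  # True
print(np.allclose(phase_back, phase))    # True, since phase stays within [0, pi/2]
```

Because `np.angle` wraps to (-π, π], a phase profile spanning more than 2π would not round-trip exactly through a purely complex representation.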
- forward(light, interp_mode='nearest')
Propagate incident light through the optical element.
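- Parameters:
light (Light) – Incident light field
interp_mode (str) – Interpolation mode used if the element must be resized to the light’s pixel pitch
- Returns:
Light field after interaction with optical element
- Return type:
Light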
Examples
>>> element = OpticalElement(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
>>> light = Light(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
>>> output = element.forward(light)
- forward_non_polar(light, interp_mode='nearest')
Propagate non-polarized light through the optical element.
Handles resolution matching between light and optical element by resizing and padding as needed. Applies the optical element’s field modulation to the input light.
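- Parameters:
light (Light) – Incident non-polarized light field
interp_mode (str) – Interpolation mode used when resizing the element to the light’s pixel pitch
- Returns:
Modified light field after interaction with optical element
- Return type:
Light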
- Raises:
ValueError – If wavelengths of light and element don’t match
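forward_non_polar() is described as checking the wavelength, matching resolution, and then applying the element's modulation. The NumPy sketch below covers the wavelength check, padding, and multiplication, assuming the light and the element already share the same pixel pitch (so no resizing is needed) and that the element is centered and zero-padded up to the light's size. `apply_modulation` is a hypothetical helper, not part of pado, and pado's own padding strategy may differ.

```python
import numpy as np

def apply_modulation(light_field, light_wvl, field_change, element_wvl):
    """Hypothetical sketch of a scalar (non-polarized) modulation step.

    light_field:  complex array of shape (B, Ch, R, C)
    field_change: complex array of shape (B, 1, R', C') with R' <= R, C' <= C,
                  assumed to share the light's pixel pitch
    """
    # Refuse to combine fields defined at different wavelengths
    if not np.isclose(light_wvl, element_wvl, rtol=1e-9, atol=0.0):
        raise ValueError("wavelengths of light and element don't match")

    # Center the element and zero-pad it to the light's spatial size
    dr = light_field.shape[-2] - field_change.shape[-2]
    dc = light_field.shape[-1] - field_change.shape[-1]
    widths = ((0, 0), (0, 0), (dr // 2, dr - dr // 2), (dc // 2, dc - dc // 2))
    padded = np.pad(field_change, widths, mode="constant", constant_values=0)

    # Element-wise complex multiplication; the single element channel broadcasts
    return light_field * padded

light = np.ones((1, 1, 6, 6), dtype=complex)
aperture = np.ones((1, 1, 4, 4), dtype=complex)
out = apply_modulation(light, 500e-9, aperture, 500e-9)
print(out.shape)             # (1, 1, 6, 6)
print(out[0, 0].real.sum())  # 16.0: only the central 4x4 region is transmitted
```

Zero-padding the element makes the added border opaque, so light outside the element's footprint is blocked.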
- get_amplitude_change()
Return amplitude change of the wavefront.
- Returns:
Amplitude change of the wavefront
- Return type:
torch.Tensor
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> amp = element.get_amplitude_change()
- get_device()
Returns the device on which tensors are stored.
- Returns:
The device identifier (e.g., ‘cpu’, ‘cuda:0’).
- Return type:
str
- get_field_change()
Returns the field_change tensor.
- Returns:
The field_change tensor representing amplitude and phase changes.
- Return type:
torch.Tensor
- get_name()
Returns the name of the optical element.
- Returns:
The name identifier.
- Return type:
str
- get_phase_change()
Return phase change of the wavefront.
- Returns:
Phase change of the wavefront
- Return type:
torch.Tensor
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> phase = element.get_phase_change()
- get_polar()
Returns the polarization mode.
- Returns:
The polarization mode (‘non’ for scalar, ‘polar’ for vector).
- Return type:
str
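- pad(pad_width, padval=0)
Pad the wavefront change with a constant value.
- Parameters:
pad_width (tuple) – Number of pixels to add on each side of the spatial dimensions
padval (float) – Padding value; only 0 is currently supported
- Raises:
NotImplementedError – If padval is not 0
- Return type:
None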
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.pad((10,10,10,10))  # Add 10 pixels padding on all sides
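A NumPy sketch of zero padding a (B, 1, R, C) array, mirroring the documented restriction that only `padval=0` is accepted. The (left, right, top, bottom) ordering of `pad_width` is an assumption borrowed from `torch.nn.functional.pad`; check pado's source for the order it actually uses. `pad_field_change` is a hypothetical name.

```python
import numpy as np

def pad_field_change(field_change, pad_width, padval=0):
    """Hypothetical zero padding of a (B, 1, R, C) array.

    pad_width is assumed to be (left, right, top, bottom), as in
    torch.nn.functional.pad for the last two dimensions.
    """
    if padval != 0:
        # Mirrors the documented restriction: only zero padding is supported
        raise NotImplementedError("only padval=0 is supported")
    left, right, top, bottom = pad_width
    widths = ((0, 0), (0, 0), (top, bottom), (left, right))
    return np.pad(field_change, widths, mode="constant", constant_values=0)

t = np.ones((1, 1, 100, 100), dtype=complex)
padded = pad_field_change(t, (10, 10, 10, 10))
print(padded.shape)  # (1, 1, 120, 120)
```

A zero-valued border transmits no light, so padding enlarges the grid without changing what the element does inside its original footprint.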
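- resize(target_pitch, interp_mode='nearest')
Resample the wavefront change to a new pixel pitch.
- Parameters:
target_pitch (float) – Target pixel pitch in meters
interp_mode (str) – Interpolation mode used for resampling
- Return type:
None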
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.resize(1e-6)  # Resize to 1μm pitch
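If the element's physical size is preserved, changing the pitch scales the number of samples by pitch / target_pitch: 100 pixels at 2 μm become 200 pixels at 1 μm. The sketch below does this with nearest-neighbour sampling (matching the default `interp_mode='nearest'`). `resize_nearest` and its pixel-centre index mapping are illustrative assumptions in NumPy, not pado's code.

```python
import numpy as np

def resize_nearest(field_change, pitch, target_pitch):
    """Hypothetical nearest-neighbour resampling of a (B, 1, R, C) array.

    Keeps the physical extent (R * pitch, C * pitch) approximately fixed.
    """
    R, C = field_change.shape[-2:]
    scale = pitch / target_pitch  # > 1 upsamples, < 1 downsamples
    new_R, new_C = round(R * scale), round(C * scale)
    # Map each output pixel centre to the input pixel that contains it
    rows = np.clip(np.floor((np.arange(new_R) + 0.5) / scale).astype(int), 0, R - 1)
    cols = np.clip(np.floor((np.arange(new_C) + 0.5) / scale).astype(int), 0, C - 1)
    return field_change[..., rows[:, None], cols[None, :]]

t = np.ones((1, 1, 100, 100), dtype=complex)
print(resize_nearest(t, 2e-6, 1e-6).shape)  # (1, 1, 200, 200)
```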
- set_amplitude_change(amplitude, c=None)
Set amplitude change for specific or all channels.
- Parameters:
amplitude (torch.Tensor) – Amplitude change in polar representation
c (int, optional) – Channel index. If None, applies to all channels
- Return type:
None
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> amp = torch.ones((1,1,100,100))
>>> element.set_amplitude_change(amp)
- set_field_change(field_change, c=None)
Set field change for specific or all channels.
- Parameters:
field_change (torch.Tensor) – Field change in complex tensor
c (int, optional) – Channel index. If None, applies to all channels
- Return type:
None
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> field = torch.ones((1,1,100,100), dtype=torch.cfloat)
>>> element.set_field_change(field)
- set_phase_change(phase, c=None)
Set phase change for specific or all channels.
- Parameters:
phase (torch.Tensor) – Phase change in polar representation
c (int, optional) – Channel index. If None, applies to all channels
- Return type:
None
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> phase = torch.zeros((1,1,100,100))
>>> element.set_phase_change(phase)
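set_amplitude_change() and set_phase_change() each replace one polar component of the modulation. If the other component is kept, as the method names suggest, the update amounts to the NumPy sketch below. The helper names `set_amplitude` and `set_phase`, and the assumption that a per-channel value passed with `c` has shape (B, R, C), are hypothetical.

```python
import numpy as np

def set_amplitude(field_change, amplitude, c=None):
    """Hypothetical: replace |t| and keep arg(t). Expects a complex array."""
    out = field_change.copy()
    idx = np.s_[:] if c is None else np.s_[:, c]  # all channels or channel c only
    out[idx] = amplitude * np.exp(1j * np.angle(out[idx]))
    return out

def set_phase(field_change, phase, c=None):
    """Hypothetical: replace arg(t) and keep |t|. Expects a complex array."""
    out = field_change.copy()
    idx = np.s_[:] if c is None else np.s_[:, c]
    out[idx] = np.abs(out[idx]) * np.exp(1j * phase)
    return out

t = 2.0 * np.exp(1j * 0.3) * np.ones((1, 1, 4, 4))  # |t| = 2, arg(t) = 0.3
t_amp = set_amplitude(t, np.ones((1, 1, 4, 4)))      # |t| = 1, arg(t) = 0.3
t_phase = set_phase(t, np.zeros((1, 1, 4, 4)))       # |t| = 2, arg(t) = 0
```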
- set_pitch(pitch)
Set the pixel pitch, in meters, of the field-change tensor.
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.set_pitch(1e-6)  # Set 1μm pitch
- set_polar(polar)
Set polarization mode for the optical element.
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.set_polar('polar')  # Set vector field mode
- shape()
Return the shape of the wavefront modulation.
The number of channels is one for wavefront modulation.
- Returns:
Dimensions (B, 1, R, C) for batch size, channels, rows, columns
- Return type:
tuple
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.shape()
(1, 1, 100, 100)
- visualize(b=0, c=None)
Visualize the wavefront modulation of the optical element.
Plots the amplitude change and the phase change as subplots for batch index b and the selected channel(s).
- Parameters:
b (int) – Batch index to visualize
c (int, optional) – Channel index. If None, all channels are visualized
- Return type:
None
Examples
>>> lens = RefractiveLens((1,1,512,512), 2e-6, 0.1, 633e-9, 'cpu')
>>> lens.visualize()  # Visualize first batch, all channels
>>> lens.visualize(b=0, c=0)  # Visualize first batch, first channel