pado.optical_element.OpticalElement

class OpticalElement(dim, pitch, wvl, field_change=None, device='cpu', name='not defined', polar='non')
__init__(dim, pitch, wvl, field_change=None, device='cpu', name='not defined', polar='non')

Base class for optical elements that modify the wavefront of incident light.

The wavefront modification is stored as amplitude and phase changes. Note that a wavefront modulation always has exactly one channel.

Parameters:
  • dim (tuple) – Dimensions (B, 1, R, C): batch size, channels (always 1), rows, columns

  • pitch (float) – Pixel pitch in meters

  • wvl (float) – Wavelength of light in meters

  • field_change (torch.Tensor, optional) – Complex wavefront modification tensor of shape [B, 1, R, C]

  • device (str) – Device to store wavefront (‘cpu’, ‘cuda:0’, etc.)

  • name (str) – Name identifier for this optical element

  • polar (str) – Polarization mode (‘non’: scalar, ‘polar’: vector)

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.field_change.shape
torch.Size([1, 1, 100, 100])
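The amplitude and phase changes can be pictured as one complex modulation t = A·exp(iφ). The sketch below assumes that layout (set_field_change accepts a complex tensor, but the internal storage is not spelled out in this reference) and uses only standard PyTorch calls on a bare tensor, not the pado API:

```python
import math

import torch

# Amplitude and phase changes on a (B, 1, R, C) grid: batch 1, one channel, 4 x 4 pixels.
amplitude = torch.full((1, 1, 4, 4), 0.5)
phase = torch.full((1, 1, 4, 4), math.pi / 4)

# Complex modulation t = A * exp(i * phi); torch.polar takes (magnitude, angle).
field_change = torch.polar(amplitude, phase)

# Recover the two components from the complex tensor.
amp_back = field_change.abs()
phase_back = field_change.angle()

print(field_change.dtype, tuple(field_change.shape))  # torch.complex64 (1, 1, 4, 4)
```

torch.polar builds the complex tensor from magnitude and angle, and abs() and angle() invert it.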
forward(light, interp_mode='nearest')

Propagate incident light through the optical element.

Parameters:
  • light (Light) – Input light field

  • interp_mode (str) – Interpolation method for resizing (‘bilinear’, ‘nearest’)

Returns:

Light field after interaction with the optical element

Return type:

Light

Examples

>>> element = OpticalElement(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
>>> light = Light(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
>>> output = element.forward(light)
forward_non_polar(light, interp_mode='nearest')

Propagate non-polarized light through the optical element.

Matches the resolution of the light and the optical element by resizing and padding as needed, then applies the element’s field modulation to the input light.

Parameters:
  • light (Light) – Input light field to propagate through the element

  • interp_mode (str) – Interpolation method for resizing (‘bilinear’, ‘nearest’)

Returns:

Modified light field after interaction with the optical element

Return type:

Light

Raises:

ValueError – If the wavelengths of the light and the element do not match
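After resolution matching, the interaction itself can be sketched as an element-wise product of complex fields. This is a minimal sketch under stated assumptions: both fields already share one (B, 1, R, C) grid, the resize and pad steps are omitted, and apply_modulation is a hypothetical helper rather than a pado function:

```python
import math

import torch

def apply_modulation(field, wvl_light, field_change, wvl_element):
    """Hypothetical sketch: modulate a light field by an element's complex field change.

    Both tensors are assumed to be complex (B, 1, R, C) tensors on the same grid;
    the resizing and padding performed by forward_non_polar are not reproduced.
    """
    # Mirror the documented ValueError for mismatched wavelengths.
    if wvl_light != wvl_element:
        raise ValueError(f"wavelength mismatch: {wvl_light} m (light) vs {wvl_element} m (element)")
    if field.shape[-2:] != field_change.shape[-2:]:
        raise ValueError("this sketch expects light and element on the same spatial grid")
    # Thin-element model: output = input * complex modulation, pixel by pixel.
    return field * field_change

# Plane wave of unit amplitude on a 64 x 64 grid.
light = torch.ones((1, 1, 64, 64), dtype=torch.cfloat)
# Pure phase element: unit amplitude and a constant pi/2 phase delay.
element = torch.polar(torch.ones(1, 1, 64, 64), torch.full((1, 1, 64, 64), math.pi / 2))

out = apply_modulation(light, 500e-9, element, 500e-9)
```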

get_amplitude_change()

Returns the amplitude change of the wavefront.

Returns:

Amplitude change of the wavefront

Return type:

torch.Tensor

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> amp = element.get_amplitude_change()
get_device()

Returns the device on which tensors are stored.

Returns:

The device identifier (e.g., ‘cpu’, ‘cuda:0’).

Return type:

str

get_field_change()

Returns the field_change tensor.

Returns:

The field_change tensor representing amplitude and phase changes.

Return type:

torch.Tensor

get_name()

Returns the name of the optical element.

Returns:

The name identifier.

Return type:

str

get_phase_change()

Returns the phase change of the wavefront.

Returns:

Phase change of the wavefront

Return type:

torch.Tensor

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> phase = element.get_phase_change()
get_pitch()

Returns the pixel pitch.

Returns:

The pixel pitch in meters.

Return type:

float

get_polar()

Returns the polarization mode.

Returns:

The polarization mode (‘non’ for scalar, ‘polar’ for vector).

Return type:

str

get_wvl()

Returns the wavelength.

Returns:

The wavelength in meters.

Return type:

float

pad(pad_width, padval=0)

Pads the wavefront change with a constant value.

Parameters:
  • pad_width (tuple) – Padding widths in torch.nn.functional.pad format, for example (left, right, top, bottom) to pad columns and rows

  • padval (float) – Value to pad with; currently only 0 is supported

Raises:

NotImplementedError – If padval is not 0

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.pad((10,10,10,10))  # Add 10 pixels padding on all sides
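Because pad_width follows the torch.nn.functional.pad convention, the first pair of values pads the last dimension (columns) and the second pair pads the rows. The sketch below shows that ordering directly with F.pad on a real-valued stand-in tensor; it illustrates the convention and is not pado’s implementation:

```python
import torch
import torch.nn.functional as F

t = torch.ones((1, 1, 100, 100))

# (left, right, top, bottom): the first pair pads columns, the second pair pads rows.
padded = F.pad(t, (10, 10, 10, 10), mode="constant", value=0.0)
print(padded.shape)  # torch.Size([1, 1, 120, 120])

# Asymmetric: 5 zero columns on the left, 20 zero rows at the bottom.
asym = F.pad(t, (5, 0, 0, 20), mode="constant", value=0.0)
print(asym.shape)  # torch.Size([1, 1, 120, 105])
```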
resize(target_pitch, interp_mode='nearest')

Resamples the wavefront change to a new pixel pitch.

Parameters:
  • target_pitch (float) – New pixel pitch in meters

  • interp_mode (str) – Interpolation method used in torch.nn.functional.interpolate: ‘bilinear’ (bilinear interpolation) or ‘nearest’ (nearest-neighbor interpolation)

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.resize(1e-6)  # Resize to 1μm pitch
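Resampling to a new pitch presumably keeps the physical extent (pixel count × pitch) fixed, so the pixel count scales by pitch / target_pitch: 100 pixels at 2 μm become 200 pixels at 1 μm. A minimal sketch with torch.nn.functional.interpolate, assuming the complex field is resampled through its real and imaginary parts (interpolating amplitude and phase would be an equally plausible design, and which one pado uses is not stated here); resample_field is a hypothetical helper:

```python
import torch
import torch.nn.functional as F

def resample_field(field, pitch, target_pitch, interp_mode="nearest"):
    """Hypothetical sketch: resample a complex (B, 1, R, C) field to a new pixel pitch."""
    scale = pitch / target_pitch
    # Keep the physical extent fixed: new count = old count * (old pitch / new pitch).
    size = (round(field.shape[-2] * scale), round(field.shape[-1] * scale))
    # F.interpolate is applied to the real and imaginary parts as ordinary float tensors.
    real = F.interpolate(field.real.contiguous(), size=size, mode=interp_mode)
    imag = F.interpolate(field.imag.contiguous(), size=size, mode=interp_mode)
    return torch.complex(real, imag)

field = torch.ones((1, 1, 100, 100), dtype=torch.cfloat)
resized = resample_field(field, 2e-6, 1e-6)
print(resized.shape)  # torch.Size([1, 1, 200, 200])
```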
set_amplitude_change(amplitude, c=None)

Sets the amplitude change for a specific channel or for all channels.

Parameters:
  • amplitude (torch.Tensor) – Amplitude change in polar representation

  • c (int, optional) – Channel index. If None, applies to all channels

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> amp = torch.ones((1,1,100,100))
>>> element.set_amplitude_change(amp)
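Setting only the amplitude presumably leaves the existing phase untouched. A minimal sketch of that behavior on a bare complex tensor, where set_amplitude is a hypothetical helper and channel selection via c is left out:

```python
import math

import torch

def set_amplitude(field_change, amplitude):
    """Hypothetical sketch: replace the amplitude of a complex modulation, keeping its phase."""
    return torch.polar(amplitude, field_change.angle())

# Start from a pure phase mask: unit amplitude, phase pi/3 everywhere.
field = torch.polar(torch.ones(1, 1, 100, 100), torch.full((1, 1, 100, 100), math.pi / 3))

# Halve the transmitted amplitude; the phase pattern is unchanged.
field = set_amplitude(field, torch.full((1, 1, 100, 100), 0.5))
```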
set_field_change(field_change, c=None)

Sets the field change for a specific channel or for all channels.

Parameters:
  • field_change (torch.Tensor) – Field change as a complex tensor

  • c (int, optional) – Channel index. If None, applies to all channels

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> field = torch.ones((1,1,100,100), dtype=torch.cfloat)
>>> element.set_field_change(field)
set_name(name)

Sets the name of the optical element.

Parameters:

name (str) – The name identifier.

Return type:

None

set_phase_change(phase, c=None)

Sets the phase change for a specific channel or for all channels.

Parameters:
  • phase (torch.Tensor) – Phase change in polar representation

  • c (int, optional) – Channel index. If None, applies to all channels

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> phase = torch.zeros((1,1,100,100))
>>> element.set_phase_change(phase)
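Conversely, setting the phase should leave the amplitude untouched. One detail worth knowing when phases are read back from a complex tensor, as in this hypothetical set_phase sketch: torch.angle wraps values into (-π, π], so a phase of 3π/2 comes back as -π/2:

```python
import math

import torch

def set_phase(field_change, phase):
    """Hypothetical sketch: replace the phase of a complex modulation, keeping its amplitude."""
    return torch.polar(field_change.abs(), phase)

# Uniform 0.8 amplitude, zero phase.
field = 0.8 * torch.ones((1, 1, 100, 100), dtype=torch.cfloat)
field = set_phase(field, torch.full((1, 1, 100, 100), 3 * math.pi / 2))

# Reading the phase back wraps it into (-pi, pi]: 3*pi/2 is returned as -pi/2.
wrapped = field.angle()
```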
set_pitch(pitch)

Sets the pixel pitch of the field-change tensor.

Parameters:

pitch (float) – Pixel pitch in meters

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.set_pitch(1e-6)  # Set 1μm pitch
set_polar(polar)

Sets the polarization mode of the optical element.

Parameters:

polar (str) – Polarization mode (‘non’: scalar, ‘polar’: vector)

Return type:

None

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.set_polar('polar')  # Set vector field mode
set_wvl(wvl)

Sets the wavelength.

Parameters:

wvl (float) – The wavelength in meters.

Return type:

None

shape()

Returns the shape of the wavefront modulation.

A wavefront modulation always has exactly one channel.

Returns:

Dimensions (B, 1, R, C) for batch size, channels, rows, columns

Return type:

tuple

Examples

>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.shape()
(1, 1, 100, 100)
visualize(b=0, c=None)

Visualizes the wavefront modulation of the optical element.

Creates subplots showing the amplitude change and the phase change of the element’s wavefront modulation for the specified channels.

Parameters:
  • b (int, optional) – Batch index to visualize. Defaults to 0.

  • c (int, optional) – Channel index to visualize. If None, visualizes all channels. Defaults to None.

Return type:

None

Examples

>>> lens = RefractiveLens((1,1,512,512), 2e-6, 0.1, 633e-9, 'cpu')
>>> lens.visualize()  # Visualize first batch, all channels
>>> lens.visualize(b=0, c=0)  # Visualize first batch, first channel
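For a similar figure outside pado, the sketch below plots the amplitude and phase of a (B, 1, R, C) complex tensor with matplotlib. plot_modulation is a hypothetical helper, and the layout (one row per channel, amplitude on the left, phase on the right) is inferred from the description above rather than taken from pado:

```python
import math

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch

def plot_modulation(field_change, b=0, c=None):
    """Hypothetical sketch: plot amplitude and phase of field_change[b], one row per channel."""
    channels = list(range(field_change.shape[1])) if c is None else [c]
    fig, axes = plt.subplots(len(channels), 2, figsize=(8, 4 * len(channels)), squeeze=False)
    for row, ch in enumerate(channels):
        data = field_change[b, ch].detach().cpu()
        axes[row, 0].imshow(data.abs().numpy(), cmap="gray")
        axes[row, 0].set_title(f"Amplitude change (channel {ch})")
        # Phase is shown on a cyclic colormap over its full (-pi, pi] range.
        axes[row, 1].imshow(data.angle().numpy(), cmap="twilight", vmin=-math.pi, vmax=math.pi)
        axes[row, 1].set_title(f"Phase change (channel {ch})")
    fig.tight_layout()
    return fig

# Example input: a quadratic (lens-like) phase profile on a 128 x 128 grid.
y, x = torch.meshgrid(torch.linspace(-1, 1, 128), torch.linspace(-1, 1, 128), indexing="ij")
field = torch.polar(torch.ones(128, 128), -10 * (x**2 + y**2)).reshape(1, 1, 128, 128)
fig = plot_modulation(field)
```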