pado package
Submodules
pado.light module
- class Light(dim, pitch, wvl, field=None, device='cpu')
Bases: object
Light wave with complex field wavefront.
Represent a light wave with complex field wavefront that can be manipulated through various optical operations.
- __init__(dim, pitch, wvl, field=None, device='cpu')
Create light wave instance with complex field wavefront.
- Parameters:
dim (tuple) – Field dimensions (B, Ch, R, C) for batch, channels, rows, cols
pitch (float) – Pixel pitch in meters
wvl (float or list) – Wavelength in meters. Can be a single value or a list for multi-wavelength fields
field (torch.Tensor, optional) – Initial complex field [B, Ch, R, C]
device (str) – Device for computation (‘cpu’, ‘cuda:0’, etc.)
Examples
>>> light = Light(dim=(1, 1, 1024, 1024), pitch=6.4e-6, wvl=633e-9)
>>> light = Light(dim=(2, 3, 512, 512), pitch=2e-6, wvl=[450e-9, 550e-9, 650e-9])
- clone()
Create deep copy of light instance.
- Returns:
New light instance with copied attributes
- Return type:
Examples
>>> light_copy = light.clone()
- pad(pad_width, padval=0)
Pad light field with constant value.
- Parameters:
- Return type:
Examples
>>> light.pad((16, 16, 16, 16)) # Add 16 pixels padding on all sides
- set_real(real, c=None)
Set real part of light wavefront.
- Parameters:
real (torch.Tensor) – Real part in rectangular representation. If c is None, the expected shape is (B, Ch, R, C), matching self.field.real.shape. If c is provided, the expected shape is (B, R, C), matching self.field[:, c, ...].real.shape. Here B = batch size, Ch = channels, R = rows, C = columns.
c (int, optional) – Channel index to modify. If provided, only that specific channel will be updated.
- Return type:
Examples
>>> # Set real part for all channels (B=1, Ch=1, R=1024, C=1024)
>>> real_part = torch.ones((1, 1, 1024, 1024))  # (B,Ch,R,C)
>>> light.set_real(real_part)
>>>
>>> # Set real part for only channel 0 (B=1, R=1024, C=1024)
>>> real_part_channel0 = torch.ones((1, 1024, 1024))  # (B,R,C)
>>> light.set_real(real_part_channel0, c=0)
- set_imag(imag, c=None)
Set imaginary part of light wavefront.
- Parameters:
imag (torch.Tensor) – Imaginary part in rectangular representation. If c is None, the expected shape is (B, Ch, R, C), matching self.field.imag.shape. If c is provided, the expected shape is (B, R, C), matching self.field[:, c, ...].imag.shape. Here B = batch size, Ch = channels, R = rows, C = columns.
c (int, optional) – Channel index to modify. If provided, only that specific channel will be updated.
- Return type:
Examples
>>> # Set imaginary part for all channels (B=1, Ch=1, R=1024, C=1024)
>>> imag_part = torch.ones((1, 1, 1024, 1024))  # (B,Ch,R,C)
>>> light.set_imag(imag_part)
>>>
>>> # Set imaginary part for only channel 0 (B=1, R=1024, C=1024)
>>> imag_part_channel0 = torch.ones((1, 1024, 1024))  # (B,R,C)
>>> light.set_imag(imag_part_channel0, c=0)
- set_amplitude(amplitude, c=None)
Set the amplitude of the light wavefront while keeping its phase unchanged. The update does not maintain the computation graph.
- Parameters:
amplitude (torch.Tensor) – Amplitude in polar representation.
c (int, optional) – Channel index to modify.
- Return type:
- set_phase(phase, c=None)
Set phase of light wavefront (keeps amplitude unchanged).
- Parameters:
phase (torch.Tensor) – Phase in polar representation (in radians). If c is None, the expected shape is (B, Ch, R, C), matching self.field.shape. If c is provided, the expected shape is (B, R, C), matching self.field[:, c, ...].shape. Here B = batch size, Ch = channels, R = rows, C = columns.
c (int, optional) – Channel index to modify. If provided, only that specific channel will be updated.
- Return type:
Examples
>>> # Set phase for all channels (B=1, Ch=1, R=1024, C=1024)
>>> phase = torch.zeros((1, 1, 1024, 1024))  # (B,Ch,R,C)
>>> light.set_phase(phase)
>>>
>>> # Set phase for only channel 0 (B=1, R=1024, C=1024)
>>> phase_channel0 = torch.zeros((1, 1024, 1024))  # (B,R,C)
>>> light.set_phase(phase_channel0, c=0)
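The polar setters above recombine amplitude and phase into one complex field. The NumPy sketch below shows that arithmetic on a plain complex array of shape (B, Ch, R, C); the helper names set_amplitude_keep_phase and set_phase_keep_amplitude are hypothetical and are not pado's implementation.

```python
import numpy as np

def set_amplitude_keep_phase(field, amplitude, c=None):
    """Return a copy of `field` with |field| replaced by `amplitude` and angle(field) kept."""
    out = field.copy()
    if c is None:
        out[...] = amplitude * np.exp(1j * np.angle(field))
    else:
        # amplitude has shape (B, R, C) when a single channel is targeted
        out[:, c] = amplitude * np.exp(1j * np.angle(field[:, c]))
    return out

def set_phase_keep_amplitude(field, phase, c=None):
    """Return a copy of `field` with angle(field) replaced by `phase` and |field| kept."""
    out = field.copy()
    if c is None:
        out[...] = np.abs(field) * np.exp(1j * phase)
    else:
        out[:, c] = np.abs(field[:, c]) * np.exp(1j * phase)
    return out

field = np.full((1, 2, 4, 4), 2.0 * np.exp(1j * 0.3))
field = set_phase_keep_amplitude(field, np.zeros((1, 4, 4)), c=0)
intensity = np.abs(field) ** 2  # intensity = squared amplitude, as in get_intensity
```

Only channel 0 receives the new phase; channel 1 keeps its original 0.3 rad.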
- set_field(field, c=None)
Set complex field of light wavefront.
- Parameters:
field (torch.Tensor) – Complex field tensor. If c is None, the expected shape is (B, Ch, R, C), matching self.field.shape. If c is provided, the expected shape is (B, R, C), matching self.field[:, c, ...].shape. Here B = batch size, Ch = channels, R = rows, C = columns.
c (int, optional) – Channel index to modify. If provided, only that specific channel will be updated.
- Return type:
Examples
>>> # Set field for all channels (B=1, Ch=1, R=1024, C=1024)
>>> field = torch.ones((1, 1, 1024, 1024), dtype=torch.complex64)  # (B,Ch,R,C)
>>> light.set_field(field)
>>>
>>> # Set field for only channel 0 (B=1, R=1024, C=1024)
>>> field_channel0 = torch.ones((1, 1024, 1024), dtype=torch.complex64)  # (B,R,C)
>>> light.set_field(field_channel0, c=0)
- set_pitch(pitch)
Set pixel pitch of light field.
Examples
>>> light.set_pitch(6.4e-6) # Set 6.4μm pitch
- get_channel()
Return number of channels in light field.
- Returns:
Number of channels
- Return type:
Examples
>>> channels = light.get_channel()
- get_amplitude(c=None)
Get amplitude of light wavefront.
- Parameters:
c (int, optional) – Channel index to retrieve. If provided, only that specific channel will be returned.
- Returns:
- Amplitude in polar representation.
If c is None, the shape is (B, Ch, R, C); if c is provided, the shape is (B, R, C). Here B = batch size, Ch = channels, R = rows, C = columns.
- Return type:
Examples
>>> # Get amplitude for all channels (B=1, Ch=1, R=1024, C=1024)
>>> amp = light.get_amplitude()  # Shape: (1, 1, 1024, 1024)
>>>
>>> # Get amplitude for only channel 0 (B=1, R=1024, C=1024)
>>> amp_channel0 = light.get_amplitude(c=0)  # Shape: (1, 1024, 1024)
- get_phase(c=None)
Get phase of light wavefront.
- Parameters:
c (int, optional) – Channel index to retrieve. If provided, only that specific channel will be returned.
- Returns:
- Phase in polar representation (in radians).
If c is None, the shape is (B, Ch, R, C); if c is provided, the shape is (B, R, C). Here B = batch size, Ch = channels, R = rows, C = columns.
- Return type:
Examples
>>> # Get phase for all channels (B=1, Ch=1, R=1024, C=1024)
>>> phase = light.get_phase()  # Shape: (1, 1, 1024, 1024)
>>>
>>> # Get phase for only channel 0 (B=1, R=1024, C=1024)
>>> phase_channel0 = light.get_phase(c=0)  # Shape: (1, 1024, 1024)
- get_intensity(c=None)
Get intensity (amplitude squared) of light wavefront.
- Parameters:
c (int, optional) – Channel index to retrieve. If provided, only that specific channel will be returned.
- Returns:
- Intensity values.
If c is None, the shape is (B, Ch, R, C); if c is provided, the shape is (B, R, C). Here B = batch size, Ch = channels, R = rows, C = columns.
- Return type:
Examples
>>> # Get intensity for all channels (B=1, Ch=1, R=1024, C=1024)
>>> intensity = light.get_intensity()  # Shape: (1, 1, 1024, 1024)
>>>
>>> # Get intensity for only channel 0 (B=1, R=1024, C=1024)
>>> intensity_channel0 = light.get_intensity(c=0)  # Shape: (1, 1024, 1024)
- get_field(c=None)
Get complex field of light wavefront.
- Parameters:
c (int, optional) – Channel index to retrieve. If provided, only that specific channel will be returned.
- Returns:
- Complex field tensor.
If c is None, the shape is (B, Ch, R, C); if c is provided, the shape is (B, R, C). Here B = batch size, Ch = channels, R = rows, C = columns.
- Return type:
Examples
>>> # Get field for all channels (B=1, Ch=1, R=1024, C=1024)
>>> field = light.get_field()  # Shape: (1, 1, 1024, 1024), dtype=torch.complex64
>>>
>>> # Get field for only channel 0 (B=1, R=1024, C=1024)
>>> field_channel0 = light.get_field(c=0)  # Shape: (1, 1024, 1024), dtype=torch.complex64
- get_device()
Return device of light field.
- Returns:
Device name (‘cpu’, ‘cuda:0’, etc.)
- Return type:
Examples
>>> device = light.get_device()
- get_bandwidth()
Return spatial bandwidth of light wavefront.
- Returns:
Spatial height and width of wavefront in meters
- Return type:
Examples
>>> height, width = light.get_bandwidth()
- get_ideal_angle_limit()
Return ideal angle limit of light wavefront based on optical axis.
Calculate the maximum diffraction angle supported by the current sampling, based on the Nyquist sampling criterion: sin(θ_max) = λ/(2·pitch). For multi-wavelength fields, the shortest wavelength is used.
- Returns:
Ideal angle limit in degrees
- Return type:
Examples
>>> angle_limit = light.get_ideal_angle_limit()
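The Nyquist criterion above reduces to a one-line computation. The sketch below evaluates it with the standard library; the function name ideal_angle_limit_deg is hypothetical, and clamping to 90° when λ/(2·pitch) ≥ 1 is an assumption rather than documented pado behaviour.

```python
import math

def ideal_angle_limit_deg(pitch, wvl):
    """Nyquist-limited diffraction angle in degrees: sin(theta_max) = wvl / (2 * pitch).

    wvl may be a float or a list of wavelengths; the shortest one is used.
    """
    wvl_min = min(wvl) if isinstance(wvl, (list, tuple)) else wvl
    s = wvl_min / (2.0 * pitch)
    # Assumption: if s >= 1, every propagating angle is sampled, so clamp to 90 degrees.
    return math.degrees(math.asin(min(s, 1.0)))

print(ideal_angle_limit_deg(6.4e-6, 633e-9))  # about 2.83 degrees
print(ideal_angle_limit_deg(2e-6, [450e-9, 550e-9, 650e-9]))  # set by 450 nm
```

With a 6.4 μm pitch at 633 nm, only a few degrees of tilt can be represented without aliasing.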
- magnify(scale_factor, interp_mode='nearest', c=None)
Change wavefront resolution without changing pixel pitch.
- Parameters:
- Return type:
Examples
>>> light.magnify(2.0)  # Double resolution
>>> light.magnify(0.5, interp_mode='bilinear')  # Halve resolution
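Because magnify keeps the pixel pitch fixed, changing the number of samples also changes the physical extent R·pitch × C·pitch. The NumPy sketch below shows nearest-neighbour rescaling for integer factors n and 1/n only; magnify_nearest is a hypothetical helper, and pado may use a general interpolation routine instead.

```python
import numpy as np

def magnify_nearest(field, scale_factor):
    """Nearest-neighbour rescale of the last two axes for scale_factor = n or 1/n.

    The pitch is unchanged, so the physical extent grows or shrinks by scale_factor.
    """
    if scale_factor >= 1:
        n = int(round(scale_factor))
        # Each pixel is replicated n times along rows and columns.
        return field.repeat(n, axis=-2).repeat(n, axis=-1)
    n = int(round(1 / scale_factor))
    # Keep every n-th pixel along rows and columns.
    return field[..., ::n, ::n]

f = np.arange(16).reshape(1, 1, 4, 4)
print(magnify_nearest(f, 2.0).shape)  # (1, 1, 8, 8)
```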
- resize(target_pitch, interp_mode='nearest')
Resize wavefront by changing pixel pitch.
- Parameters:
- Return type:
Examples
>>> light.resize(4e-6) # Change pitch to 4μm
- set_spherical_light(z, dx=0.0, dy=0.0)
Set spherical wavefront from point source.
- Parameters:
- Return type:
Examples
>>> light.set_spherical_light(z=0.1)  # Source at 10 cm
>>> light.set_spherical_light(z=0.05, dx=1e-3)  # Offset source
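A spherical wavefront from a point source can be sketched as a unit-amplitude field whose phase is k·r, where r is the distance from the source to each pixel. In the NumPy sketch below, the helper spherical_wavefront, the grid centred on the optical axis, the unit amplitude (no 1/r falloff), and the sign of the phase are all assumptions; pado's conventions may differ.

```python
import numpy as np

def spherical_wavefront(R, C, pitch, wvl, z, dx=0.0, dy=0.0):
    """Unit-amplitude spherical wave from a point source at lateral offset (dx, dy), distance z."""
    k = 2 * np.pi / wvl
    # Pixel centres on a grid centred at the optical axis (assumed convention).
    x = (np.arange(C) - (C - 1) / 2) * pitch
    y = (np.arange(R) - (R - 1) / 2) * pitch
    X, Y = np.meshgrid(x, y)  # both of shape (R, C)
    r = np.sqrt((X - dx) ** 2 + (Y - dy) ** 2 + z ** 2)
    return np.exp(1j * k * r)

u = spherical_wavefront(64, 64, 2e-6, 633e-9, z=0.1)
```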
- set_plane_light(theta=0)
Set plane wave with unit amplitude.
Examples
>>> light.set_plane_light(theta=15) # 15° incident angle
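A tilted plane wave is a linear phase ramp whose step between neighbouring pixels is k·pitch·sin θ. That step must stay below π to avoid aliasing, which is the same condition as θ ≤ get_ideal_angle_limit(). The sketch below is hypothetical (tilted_plane_wave, with the tilt assumed to lie in the x-z plane) and does not reproduce pado's code.

```python
import numpy as np

def tilted_plane_wave(R, C, pitch, wvl, theta_deg=0.0):
    """Unit-amplitude plane wave tilted by theta_deg, assumed to lie in the x-z plane."""
    k = 2 * np.pi / wvl
    x = np.arange(C) * pitch
    phase = k * np.sin(np.deg2rad(theta_deg)) * x  # linear phase ramp along x
    return np.exp(1j * np.broadcast_to(phase, (R, C)))

u = tilted_plane_wave(8, 8, 6.4e-6, 633e-9, theta_deg=1.0)
step = np.angle(u[0, 1] / u[0, 0])  # phase increment per pixel, about 1.11 rad
```

At 15° with a 6.4 μm pitch and 633 nm light, the per-pixel step would be roughly 16 rad, far beyond π, so the sampled ramp would alias.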
- set_amplitude_ones()
Set amplitude to ones.
- Return type:
None
Examples
>>> light.set_amplitude_ones()
- set_amplitude_zeros()
Set amplitude to zeros.
- Return type:
None
Examples
>>> light.set_amplitude_zeros()
- set_phase_zeros()
Set phase to zeros.
- Return type:
None
Examples
>>> light.set_phase_zeros()
- set_phase_random(std=1.5707963267948966, distribution='gaussian', c=None)
Set random phase with specified distribution.
- Parameters:
std (float, optional) – Standard deviation for phase randomness. Its meaning depends on the distribution:
For gaussian: standard deviation in radians.
For uniform: half-width of the uniform distribution in radians.
For von_mises: inverse of the concentration parameter (1/κ).
Defaults to π/2.
distribution (str, optional) – Type of random distribution. Must be one of [‘gaussian’, ‘uniform’, ‘von_mises’]. Defaults to ‘gaussian’.
c (int, optional) – Channel index to modify. If None, all channels are modified.
- Raises:
ValueError – If distribution is not one of the supported types.
- Return type:
Examples
>>> light.set_phase_random()  # Default gaussian with π/2 std for all channels
>>> light.set_phase_random(std=np.pi/4)  # Reduced randomness
>>> light.set_phase_random(distribution='uniform')  # Uniform distribution
>>> light.set_phase_random(std=0.25, distribution='von_mises')  # von Mises with κ=4
>>> light.set_phase_random(c=0)  # Only modify first channel
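The three std conventions listed above map directly onto NumPy's random generators. The sketch below is a hypothetical helper (random_phase) that draws a phase map the same way the parameter description reads; it is not pado's implementation, which operates on torch tensors.

```python
import numpy as np

def random_phase(shape, std=np.pi / 2, distribution="gaussian", rng=None):
    """Draw a random phase map (radians) following the documented std conventions."""
    rng = np.random.default_rng() if rng is None else rng
    if distribution == "gaussian":
        return rng.normal(0.0, std, size=shape)  # std is the standard deviation
    if distribution == "uniform":
        return rng.uniform(-std, std, size=shape)  # std is the half-width
    if distribution == "von_mises":
        return rng.vonmises(0.0, 1.0 / std, size=shape)  # concentration kappa = 1/std
    raise ValueError(f"Unsupported distribution: {distribution!r}")

rng = np.random.default_rng(0)
phase = random_phase((1, 1, 32, 32), std=0.25, distribution="von_mises", rng=rng)
```

With std=0.25 the von Mises concentration is κ = 4, matching the example above.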
- save(fn)
Save light field to file.
Examples
>>> light.save("field.pt")  # Save as PyTorch format
>>> light.save("field.npy")  # Save as NumPy format
>>> light.save("field.mat")  # Save as MATLAB format
- adjust_amplitude_to_other_light(other_light)
Scale amplitude to match average of another light field.
Examples
>>> target = Light(dim=(1,1,1024,1024), pitch=6.4e-6, wvl=633e-9)
>>> light.adjust_amplitude_to_other_light(target)
- load_image(image_path, random_phase=False, std=3.141592653589793, distribution='uniform', batch_idx=None)
Load image as amplitude pattern with optional random phase.
- Parameters:
image_path (str) – Path to image file.
random_phase (bool, optional) – Whether to apply random phase. Defaults to False.
std (float, optional) – Standard deviation for phase randomness. Its meaning depends on the distribution:
For gaussian: standard deviation in radians.
For uniform: half-width of the uniform distribution in radians.
For von_mises: inverse of the concentration parameter (1/κ).
Defaults to π.
distribution (str, optional) – Type of random distribution. Must be one of [‘gaussian’, ‘uniform’, ‘von_mises’]. Defaults to ‘uniform’.
batch_idx (int, optional) – Specific batch index to load the image into. If None, loads the image into all batches. Defaults to None.
- Return type:
Examples
>>> light.load_image("target.png")  # No random phase, all batches
>>> light.load_image("target.png", random_phase=True)  # Default uniform, all batches
>>> light.load_image("target.png", random_phase=True, std=np.pi/4)
>>> light.load_image("target.png", random_phase=True, distribution='gaussian')
>>> light.load_image("target.png", batch_idx=0)  # Load only into first batch
- visualize(b=0, c=None, uniform_scale=False, vmin=None, vmax=None, fix_noise=True, amp_threshold=1e-05)
Visualize the amplitude, phase, and intensity of the light field, optionally suppressing numerical noise in the amplitude plot.
- Parameters:
b (int) – Batch index
c (int, optional) – Channel index
uniform_scale (bool) – Use same scale for all channels
vmin (float, optional) – Lower limit of the intensity plot range
vmax (float, optional) – Upper limit of the intensity plot range
fix_noise (bool) – Whether to fix numerical noise in amplitude visualization
amp_threshold (float) – Threshold for detecting uniform amplitude with noise
- Return type:
Examples
>>> light.visualize()  # Show all channels
>>> light.visualize(c=0)  # Show only first channel
>>> light.visualize(uniform_scale=True)  # Use uniform scaling
>>> light.visualize(fix_noise=False)  # Don't fix numerical noise
>>> light.visualize(b=1)  # Visualize the second batch
- visualize_image(b=0)
Visualize amplitude, phase, and intensity as RGB images.
Examples
>>> light.visualize_image()
>>> light.visualize_image(b=1)  # Visualize the second batch
- class PolarizedLight(dim, pitch, wvl, fieldX=None, fieldY=None, device='cuda:0')
Bases: Light
Light wave with polarized complex field wavefront.
Represent a light wave with X and Y polarization components that can be manipulated through various optical operations.
- __init__(dim, pitch, wvl, fieldX=None, fieldY=None, device='cuda:0')
Create polarized light wave instance with X and Y field components.
- Parameters:
dim (tuple) – Field dimensions (B, Ch, R, C) for batch, channels, rows, cols
pitch (float) – Pixel pitch in meters
wvl (float) – Wavelength in meters
fieldX (torch.Tensor, optional) – Initial X-polarized field [B, Ch, R, C]
fieldY (torch.Tensor, optional) – Initial Y-polarized field [B, Ch, R, C]
device (str) – Device for computation (‘cpu’, ‘cuda:0’, etc.)
- Return type:
None
Examples
>>> light = PolarizedLight(dim=(1, 1, 1024, 1024), pitch=6.4e-6, wvl=633e-9)
>>> light = PolarizedLight(dim=(2, 1, 512, 512), pitch=2e-6, wvl=532e-9,
...                        fieldX=torch.ones(2,1,512,512),
...                        fieldY=torch.zeros(2,1,512,512))
- clone()
Create deep copy of polarized light instance.
- Returns:
New polarized light instance with copied attributes
- Return type:
Examples
>>> light_copy = light.clone()
- get_amplitude()
Return total amplitude of polarized field.
- Returns:
Total amplitude
- Return type:
- get_amplitudeX()
Return amplitude of X-polarized field.
- Returns:
Amplitude of X-polarized field
- Return type:
- get_amplitudeY()
Return amplitude of Y-polarized field.
- Returns:
Amplitude of Y-polarized field
- Return type:
- get_field()
Return complex field.
- Returns:
Complex field values for X and Y components
- Return type:
Examples
>>> field = light.get_field()
- get_fieldX()
Return X-polarized field component.
- Returns:
Complex field tensor for X polarization
- Return type:
Examples
>>> x_field = light.get_fieldX()
- get_fieldY()
Return Y-polarized field component.
- Returns:
Complex field tensor for Y polarization
- Return type:
Examples
>>> y_field = light.get_fieldY()
- get_imag()
Get imaginary part of field for both components.
- Returns:
Imaginary part values for X and Y components
- Return type:
Examples
>>> imag_parts = light.get_imag()
- get_intensity()
Return total intensity of polarized field.
- Returns:
Total intensity
- Return type:
- get_intensityX()
Return intensity of X-polarized field.
- Returns:
Intensity of X-polarized field
- Return type:
- get_intensityY()
Return intensity of Y-polarized field.
- Returns:
Intensity of Y-polarized field
- Return type:
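Because the X and Y components are orthogonal, their intensities add without an interference term. The sketch below shows that relation; polarized_intensity and polarized_amplitude are hypothetical helpers, and defining "total amplitude" as the square root of the total intensity is an assumption about pado's get_amplitude.

```python
import numpy as np

def polarized_intensity(field_x, field_y):
    """Total intensity |Ex|^2 + |Ey|^2 of a field split into X and Y components."""
    return np.abs(field_x) ** 2 + np.abs(field_y) ** 2

def polarized_amplitude(field_x, field_y):
    # Assumed definition of "total amplitude": square root of the total intensity.
    return np.sqrt(polarized_intensity(field_x, field_y))

ex = np.full((1, 1, 4, 4), 3.0 + 0j)
ey = np.full((1, 1, 4, 4), 4.0j)
total = polarized_intensity(ex, ey)  # 25 everywhere
```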
- get_lightX()
Return X-polarized Light instance.
- Returns:
Light instance for X component
- Return type:
Examples
>>> x_light = light.get_lightX()
- get_lightY()
Return Y-polarized Light instance.
- Returns:
Light instance for Y component
- Return type:
Examples
>>> y_light = light.get_lightY()
- get_phase()
Return phase of field.
- Returns:
Phase values for X and Y components
- Return type:
Examples
>>> phase = light.get_phase()
- get_phaseX()
Return phase of X-polarized field.
- Returns:
Phase of X-polarized field
- Return type:
- get_phaseY()
Return phase of Y-polarized field.
- Returns:
Phase of Y-polarized field
- Return type:
- get_real()
Get real part of field for both components.
- Returns:
Real part values for X and Y components
- Return type:
Examples
>>> real_parts = light.get_real()
- magnify(scale_factor, interp_mode='nearest')
Change wavefront resolution without changing pixel pitch.
- Parameters:
- Return type:
Examples
>>> light.magnify(2.0, 'bilinear')
- pad(pad_width, padval=0)
Pad light field.
- Parameters:
- Return type:
Examples
>>> light.pad((10, 10, 10, 10))
- resize(new_pitch, interp_mode='nearest')
Resize light field to match new pixel pitch.
- Parameters:
- Return type:
Examples
>>> light.resize(8e-6)
- set_amplitude(amplitude)
Set amplitude for both X and Y components.
- Parameters:
amplitude (torch.Tensor or tuple) – Amplitude values for both components or tuple of (amplitudeX, amplitudeY)
- Return type:
Examples
>>> light.set_amplitude(torch.ones(1,1,512,512))
>>> light.set_amplitude((x_amplitude, y_amplitude))
- set_amplitudeX(amplitude)
Set amplitude for X component.
- Parameters:
amplitude (torch.Tensor) – Amplitude values for X component
- Return type:
Examples
>>> light.set_amplitudeX(torch.ones(1,1,512,512))
- set_amplitudeY(amplitude)
Set amplitude for Y component.
- Parameters:
amplitude (torch.Tensor) – Amplitude values for Y component
- Return type:
Examples
>>> light.set_amplitudeY(torch.ones(1,1,512,512))
- set_field(field)
Set both X and Y field components.
Examples
>>> light.set_field((x_field, y_field))
- set_fieldX(field)
Set X-polarized field.
- Parameters:
field (torch.Tensor) – Complex field values
- Return type:
Examples
>>> light.set_fieldX(torch.ones(1,1,512,512, dtype=torch.cfloat))
- set_fieldY(field)
Set Y-polarized field.
- Parameters:
field (torch.Tensor) – Complex field values
- Return type:
Examples
>>> light.set_fieldY(torch.ones(1,1,512,512, dtype=torch.cfloat))
- set_imag(imag)
Set imaginary part for both X and Y components.
- Parameters:
imag (torch.Tensor or tuple) – Imaginary values for both components or tuple of (imagX, imagY)
- Return type:
Examples
>>> light.set_imag(torch.zeros(1,1,512,512))
>>> light.set_imag((x_imag, y_imag))
- set_imagX(imag)
Set imaginary part of X-polarized field.
- Parameters:
imag (torch.Tensor) – Imaginary component values
- Return type:
Examples
>>> light.set_imagX(torch.zeros(1,1,512,512))
- set_imagY(imag)
Set imaginary part of Y-polarized field.
- Parameters:
imag (torch.Tensor) – Imaginary component values
- Return type:
Examples
>>> light.set_imagY(torch.zeros(1,1,512,512))
- set_phase(phase)
Set phase for both X and Y components.
- Parameters:
phase (torch.Tensor or tuple) – Phase values for both components or tuple of (phaseX, phaseY)
- Return type:
Examples
>>> light.set_phase(torch.zeros(1,1,512,512))
>>> light.set_phase((x_phase, y_phase))
- set_phaseX(phase)
Set phase of X-polarized field.
- Parameters:
phase (torch.Tensor) – Phase values in radians
- Return type:
Examples
>>> light.set_phaseX(torch.zeros(1,1,512,512))
- set_phaseY(phase)
Set phase of Y-polarized field.
- Parameters:
phase (torch.Tensor) – Phase values in radians
- Return type:
Examples
>>> light.set_phaseY(torch.zeros(1,1,512,512))
- set_plane_light()
Set plane wave with unit amplitude and zero phase.
- Return type:
None
Examples
>>> light.set_plane_light()
- set_real(real)
Set real part for both X and Y components.
- Parameters:
real (torch.Tensor or tuple) – Real values for both components or tuple of (realX, realY)
- Return type:
Examples
>>> light.set_real(torch.ones(1,1,512,512))
>>> light.set_real((x_real, y_real))
- set_realX(real)
Set real part of X-polarized field.
- Parameters:
real (torch.Tensor) – Real component values
- Return type:
Examples
>>> light.set_realX(torch.ones(1,1,512,512))
- set_realY(real)
Set real part of Y-polarized field.
- Parameters:
real (torch.Tensor) – Real component values
- Return type:
Examples
>>> light.set_realY(torch.ones(1,1,512,512))
- set_spherical_light(z, dx=0, dy=0)
Set spherical wavefront from point source.
- Parameters:
- Return type:
Examples
>>> light.set_spherical_light(0.1, dx=1e-3)
pado.material module
- class Material(material_name)
Bases: object
- Parameters:
material_name (Literal['PDMS', 'FUSED_SILICA', 'VACUUM'])
pado.math module
- wrap_phase(phase_u, stay_positive=False)
Wrap phase values to [-π, π] or [0, 2π] range.
- Parameters:
phase_u (torch.Tensor) – Unwrapped phase values tensor
stay_positive (bool) – If True, output range is [0, 2π]. If False, [-π, π]
- Returns:
Wrapped phase values tensor
- Return type:
Examples
>>> phase = torch.tensor([3.5 * np.pi, -2.5 * np.pi])
>>> wrapped = wrap_phase(phase)  # both values wrap to -0.5π (about -1.5708)
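Wrapping is a shifted modulo operation: 3.5π − 4π = −0.5π and −2.5π + 2π = −0.5π. The NumPy sketch below shows the usual formula; it is an illustration rather than pado's source, and with this formula the upper endpoint is open, so the ranges are [−π, π) and [0, 2π).

```python
import numpy as np

def wrap_phase(phase_u, stay_positive=False):
    """Wrap phase into [-pi, pi) or, with stay_positive=True, into [0, 2*pi)."""
    if stay_positive:
        return np.mod(phase_u, 2 * np.pi)
    # Shift by pi, wrap into [0, 2*pi), then shift back.
    return np.mod(phase_u + np.pi, 2 * np.pi) - np.pi

wrapped = wrap_phase(np.array([3.5 * np.pi, -2.5 * np.pi]))  # both about -pi/2
```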
- fft(arr_c, normalized='backward', pad_width=None, padval=0, shift=True)
Compute 2D FFT of a complex tensor with optional padding and frequency shifting.
- Parameters:
arr_c (torch.Tensor) – Complex tensor [B, Ch, H, W]
normalized (str) – FFT normalization mode: “backward”, “forward”, or “ortho”
pad_width (tuple) – Padding as (left, right, top, bottom)
padval (int) – Padding value (only 0 supported)
shift (bool) – If True, center zero-frequency component
- Returns:
FFT result tensor
- Return type:
Examples
>>> light = Light(dim=(1, 1, 100, 100), pitch=2e-6, wvl=500e-9)
>>> field_fft = fft(light.field)
- ifft(arr_c, normalized='backward', pad_width=None, shift=True)
Compute 2D inverse FFT of a complex tensor with optional padding and shifting.
- Parameters:
arr_c (torch.Tensor) – Complex tensor [B, Ch, H, W]
normalized (str) – IFFT normalization mode: “backward”, “forward”, or “ortho”
pad_width (tuple) – Padding as (left, right, top, bottom)
shift (bool) – If True, center zero-frequency component
- Returns:
IFFT result tensor
- Return type:
Examples
>>> field = torch.ones((1, 1, 64, 64), dtype=torch.complex64)
>>> field_freq = fft(field)
>>> field_restored = ifft(field_freq)
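A centred 2D FFT with optional zero padding can be sketched with NumPy as below. The helper names are hypothetical, and the ifftshift/fftshift pairing around the transform is an assumed convention; pado's shift=True may only shift the output.

```python
import numpy as np

def centered_fft2(arr, norm="backward", pad_width=None):
    """2D FFT over the last two axes with zero frequency at the array centre.

    pad_width follows (left, right, top, bottom), as documented for fft above.
    """
    if pad_width is not None:
        left, right, top, bottom = pad_width
        pads = [(0, 0)] * (arr.ndim - 2) + [(top, bottom), (left, right)]
        arr = np.pad(arr, pads)  # constant zero padding
    spec = np.fft.fft2(np.fft.ifftshift(arr, axes=(-2, -1)), norm=norm)
    return np.fft.fftshift(spec, axes=(-2, -1))

def centered_ifft2(spec, norm="backward"):
    """Inverse of centered_fft2 (padding is not removed)."""
    arr = np.fft.ifft2(np.fft.ifftshift(spec, axes=(-2, -1)), norm=norm)
    return np.fft.fftshift(arr, axes=(-2, -1))

field = np.ones((1, 1, 64, 64), dtype=np.complex64)
restored = centered_ifft2(centered_fft2(field))  # equals field up to rounding
```

Using the same norm for both directions guarantees that the round trip returns the original field.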
- calculate_psnr(img1, img2, data_range=1.0)
Calculate Peak Signal-to-Noise Ratio between multi-channel tensors.
- Parameters:
img1 (torch.Tensor) – First tensor [B, Channel, R, C]
img2 (torch.Tensor) – Second tensor [B, Channel, R, C]
data_range (float, optional) – The data range of the input image (e.g., 1.0 for normalized images, 255 for uint8 images). If None, uses the maximum value from images.
- Returns:
PSNR value in dB, infinity if images are identical
- Return type:
Examples
>>> intensity1 = light1.get_intensity()  # [B, Channel, R, C]
>>> intensity2 = light2.get_intensity()  # [B, Channel, R, C]
>>> psnr = calculate_psnr(intensity1, intensity2)
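PSNR compares the squared data range with the mean squared error: PSNR = 10·log10(data_range² / MSE). The NumPy sketch below averages the error over every element of the tensors, which is an assumption; pado may instead average per image or per channel.

```python
import numpy as np

def psnr(img1, img2, data_range=1.0):
    """PSNR in dB over all elements: 10 * log10(data_range**2 / MSE)."""
    a = np.asarray(img1, dtype=float)
    b = np.asarray(img2, dtype=float)
    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    if data_range is None:
        # Mirrors the documented fallback: the largest value present in either image.
        data_range = max(a.max(), b.max())
    return float(10.0 * np.log10(data_range ** 2 / mse))

# A uniform error of 0.1 on [0, 1] data gives MSE = 0.01, i.e. 20 dB.
print(psnr(np.zeros((1, 3, 8, 8)), np.full((1, 3, 8, 8), 0.1)))
```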
- calculate_ssim(img1, img2, window_size=21, sigma=None, data_range=1.0)
Calculate Structural Similarity Index between multi-channel tensors.
- Parameters:
img1 (torch.Tensor) – First tensor [B, Channel, H, W]
img2 (torch.Tensor) – Second tensor [B, Channel, H, W]
window_size (int) – Size of Gaussian window (odd number)
sigma (float, optional) – Standard deviation of Gaussian window. If None, defaults to window_size/6
data_range (float) – Dynamic range of images
- Returns:
SSIM score (-1 to 1, where 1 indicates identical images)
- Return type:
Examples
>>> intensity1 = light1.get_intensity()  # [B, Channel, R, C]
>>> intensity2 = light2.get_intensity()  # [B, Channel, R, C]
>>> similarity = calculate_ssim(intensity1, intensity2)
- gaussian_window(size, sigma)
Create normalized 2D Gaussian window.
- Parameters:
- Returns:
Normalized 2D Gaussian window [size, size]
- Return type:
Examples
>>> window = gaussian_window(11, 1.5)
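A normalized 2D Gaussian window is separable: it is the outer product of a 1D Gaussian with itself, scaled so its entries sum to 1. The NumPy sketch below is an illustrative re-implementation, not pado's code. With calculate_ssim's defaults, the corresponding call would be gaussian_window(21, 21 / 6).

```python
import numpy as np

def gaussian_window(size, sigma):
    """Normalized 2D Gaussian window of shape (size, size) whose entries sum to 1."""
    coords = np.arange(size) - (size - 1) / 2  # sample positions centred on zero
    g1d = np.exp(-coords ** 2 / (2 * sigma ** 2))
    g2d = np.outer(g1d, g1d)  # separable: G(x, y) = g(x) * g(y)
    return g2d / g2d.sum()

window = gaussian_window(11, 1.5)
```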
- sc_dft_1d(g, M, delta_x, delta_fx)
Compute 1D scaled DFT for optical field propagation.
- Parameters:
g (torch.Tensor) – Input complex field [M]
M (int) – Number of sample points
delta_x (float) – Spatial sampling interval (m)
delta_fx (float) – Frequency sampling interval (1/m)
- Returns:
Transformed complex field [M]
- Return type:
Examples
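>>> M = 1000
>>> pitch = 2e-6
>>> x = (torch.arange(M) - M // 2) * pitch  # centred spatial grid in meters
>>> g = torch.exp(-x**2 / (2 * (100e-6)**2)).to(torch.complex64)  # Gaussian, 100 μm width
>>> G = sc_dft_1d(g, M, pitch, 1/(M*pitch))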
- sc_idft_1d(G, M, delta_fx, delta_x)
Compute 1D scaled inverse DFT for optical field reconstruction.
- Parameters:
G (torch.Tensor) – Frequency domain input [M]
M (int) – Number of samples
delta_fx (float) – Frequency sampling interval (1/m)
delta_x (float) – Spatial sampling interval (m)
- Returns:
Spatial domain output [M]
- Return type:
Examples
>>> M = 1000
>>> G = torch.ones(M, dtype=torch.complex64)
>>> field = sc_idft_1d(G, M, 1/(M*2e-6), 4e-6)
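A scaled DFT evaluates the continuous Fourier integral on independently chosen spatial and frequency grids, so delta_fx does not have to equal 1/(M·delta_x). The direct O(M²) sum below is only a reference sketch, assuming both grids are centred on index M // 2; pado's sc_dft_1d and sc_idft_1d may use a faster formulation and different conventions. The helper names are hypothetical.

```python
import numpy as np

def scaled_dft_1d(g, delta_x, delta_fx):
    """G(f_k) = sum_n g(x_n) * exp(-2j*pi*f_k*x_n) * delta_x on centred grids."""
    n = np.arange(g.shape[-1]) - g.shape[-1] // 2
    kernel = np.exp(-2j * np.pi * np.outer(n * delta_fx, n * delta_x))  # (M, M)
    return kernel @ g * delta_x

def scaled_idft_1d(G, delta_fx, delta_x):
    """g(x_n) = sum_k G(f_k) * exp(+2j*pi*f_k*x_n) * delta_fx on centred grids."""
    n = np.arange(G.shape[-1]) - G.shape[-1] // 2
    kernel = np.exp(2j * np.pi * np.outer(n * delta_x, n * delta_fx))
    return kernel @ G * delta_fx

M, pitch = 64, 2e-6
x = (np.arange(M) - M // 2) * pitch
g = np.exp(-x ** 2 / (2 * (20e-6) ** 2)).astype(complex)
G = scaled_dft_1d(g, pitch, 1 / (M * pitch))
```

When delta_fx = 1/(M·delta_x), the sum reduces to a centred FFT scaled by delta_x, which makes a convenient correctness check.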
- sc_dft_2d(u, Mx, My, delta_x, delta_y, delta_fx, delta_fy)
Perform 2D scaled DFT using separable 1D transforms.
- Parameters:
u (torch.Tensor) – Input field [My, Mx]
Mx (int) – Number of samples in the x direction
My (int) – Number of samples in the y direction
delta_x (float) – Spatial sampling interval in x (m)
delta_y (float) – Spatial sampling interval in y (m)
delta_fx (float) – Frequency sampling interval in x (1/m)
delta_fy (float) – Frequency sampling interval in y (1/m)
- Returns:
Transformed field [My, Mx]
- Return type:
Examples
>>> field = light.get_field().squeeze()
>>> U = sc_dft_2d(field, 1024, 1024, pitch, pitch, 1/(pitch*1024), 1/(pitch*1024))
- sc_idft_2d(U, Mx, My, delta_x, delta_y, delta_fx, delta_fy)
Perform 2D scaled inverse DFT using separable 1D transforms.
- Parameters:
U (torch.Tensor) – Frequency domain input [My, Mx]
Mx (int) – Number of samples in the x direction
My (int) – Number of samples in the y direction
delta_x (float) – Target spatial sampling interval in x (m)
delta_y (float) – Target spatial sampling interval in y (m)
delta_fx (float) – Frequency sampling interval in x (1/m)
delta_fy (float) – Frequency sampling interval in y (1/m)
- Returns:
Spatial domain output [My, Mx]
- Return type:
Examples
>>> U = sc_dft_2d(field, Mx, My, dx, dy, dfx, dfy)
>>> field_recovered = sc_idft_2d(U, Mx, My, dx, dy, dfx, dfy)
- compute_scasm_transfer_function(Mx, My, delta_fx, delta_fy, λ, z)
Compute transfer function for Scaled Angular Spectrum Method propagation.
- Parameters:
- Returns:
Transfer function H(fx,fy) [My, Mx]
- Return type:
Examples
>>> H = compute_scasm_transfer_function(1024, 1024, 1/(1024*2e-6), 1/(1024*2e-6), 633e-9, 0.1)
>>> U_prop = torch.fft.fft2(light.get_field()) * H
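For reference, the standard (unscaled) angular spectrum transfer function is H(fx, fy) = exp(j·2π·z·sqrt(1/λ² − fx² − fy²)) for propagating components. The sketch below builds it on a centred frequency grid and zeroes the evanescent components. Both choices are assumptions, and the scaled variant computed by compute_scasm_transfer_function may include additional factors. A centred H like this one has to multiply a centred (fftshifted) spectrum.

```python
import numpy as np

def asm_transfer_function(Mx, My, delta_fx, delta_fy, wvl, z):
    """Angular-spectrum transfer function on a centred frequency grid, shape (My, Mx)."""
    fx = (np.arange(Mx) - Mx // 2) * delta_fx
    fy = (np.arange(My) - My // 2) * delta_fy
    FX, FY = np.meshgrid(fx, fy)  # both of shape (My, Mx)
    arg = 1.0 / wvl ** 2 - FX ** 2 - FY ** 2
    H = np.exp(2j * np.pi * z * np.sqrt(np.maximum(arg, 0.0)))
    # Evanescent components (fx^2 + fy^2 > 1/wvl^2) are discarded in this sketch.
    return np.where(arg > 0, H, 0.0)

H = asm_transfer_function(64, 32, 1 / (64 * 2e-6), 1 / (32 * 2e-6), 633e-9, 0.1)
```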
pado.optical_element module
- class OpticalElement(dim, pitch, wvl, field_change=None, device='cpu', name='not defined', polar='non')
Bases: object
- __init__(dim, pitch, wvl, field_change=None, device='cpu', name='not defined', polar='non')
Base class for optical elements that modify incident light wavefront.
The wavefront modification is stored as amplitude and phase tensors. Note that the number of channels is one for wavefront modulation.
- Parameters:
dim (tuple) – Dimensions (B, 1, R, C) for batch size, channels, rows, columns
pitch (float) – Pixel pitch in meters
wvl (float) – Wavelength of light in meters
field_change (torch.Tensor, optional) – Wavefront modification tensor [B, 1, R, C]
device (str) – Device to store wavefront (‘cpu’, ‘cuda:0’, etc.)
name (str) – Name identifier for this optical element
polar (str) – Polarization mode (‘non’: scalar, ‘polar’: vector)
- Return type:
None
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.field_change.shape
torch.Size([1, 1, 100, 100])
- forward(light, interp_mode='nearest')
Propagate incident light through the optical element.
- Parameters:
- Returns:
Light field after interaction with optical element
- Return type:
Examples
>>> element = OpticalElement(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
>>> light = Light(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
>>> output = element.forward(light)
- forward_non_polar(light, interp_mode='nearest')
Propagate non-polarized light through the optical element.
Handles resolution matching between light and optical element by resizing and padding as needed. Applies the optical element’s field modulation to the input light.
- Parameters:
- Returns:
Modified light field after interaction with optical element
- Return type:
- Raises:
ValueError – If wavelengths of light and element don’t match
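Once pitch and resolution match, the modulation itself is an element-wise product, and the element's single modulation channel broadcasts over all light channels. The sketch below assumes the resizing and padding step has already been done. apply_element is a hypothetical helper, and the relative tolerance used for the wavelength check is an assumption.

```python
import numpy as np

def apply_element(light_field, light_wvl, field_change, element_wvl):
    """Multiply an incident field (B, Ch, R, C) by an element modulation (B, 1, R, C)."""
    if not np.allclose(light_wvl, element_wvl, rtol=1e-9, atol=0.0):
        raise ValueError("wavelengths of light and element don't match")
    # The single modulation channel broadcasts over every light channel.
    return light_field * field_change

light = np.ones((2, 3, 4, 4), dtype=complex)
phase_plate = np.exp(1j * np.full((2, 1, 4, 4), 0.5))
out = apply_element(light, 500e-9, phase_plate, 500e-9)  # every pixel gains 0.5 rad
```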
- get_amplitude_change()
Return amplitude change of the wavefront.
- Returns:
Amplitude change of the wavefront
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> amp = element.get_amplitude_change()
- get_device()
Return the device on which tensors are stored.
- Returns:
The device identifier (e.g., ‘cpu’, ‘cuda:0’).
- Return type:
- get_field_change()
Return the field_change tensor.
- Returns:
The field_change tensor representing amplitude and phase changes.
- Return type:
- get_name()
Return the name of the optical element.
- Returns:
The name identifier.
- Return type:
- get_phase_change()
Return phase change of the wavefront.
- Returns:
Phase change of the wavefront
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> phase = element.get_phase_change()
- get_polar()
Return the polarization mode.
- Returns:
The polarization mode (‘non’ for scalar, ‘polar’ for vector).
- Return type:
- pad(pad_width, padval=0)
Pad the wavefront change with constant value.
- Parameters:
- Raises:
NotImplementedError – If padval is not 0
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.pad((10,10,10,10))  # Add 10 pixels padding on all sides
- resize(target_pitch, interp_mode='nearest')
Resize the wavefront change by changing the pixel pitch.
- Parameters:
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.resize(1e-6)  # Resize to 1 μm pitch
- set_amplitude_change(amplitude, c=None)
Set amplitude change for specific or all channels.
- Parameters:
amplitude (torch.Tensor) – Amplitude change in polar representation
c (int, optional) – Channel index. If None, applies to all channels
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> amp = torch.ones((1,1,100,100))
>>> element.set_amplitude_change(amp)
- set_field_change(field_change, c=None)
Set field change for specific or all channels.
- Parameters:
field_change (torch.Tensor) – Field change in complex tensor
c (int, optional) – Channel index. If None, applies to all channels
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> field = torch.ones((1,1,100,100), dtype=torch.cfloat)
>>> element.set_field_change(field)
- set_phase_change(phase, c=None)
Set phase change for specific or all channels.
- Parameters:
phase (torch.Tensor) – Phase change in polar representation
c (int, optional) – Channel index. If None, applies to all channels
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> phase = torch.zeros((1,1,100,100))
>>> element.set_phase_change(phase)
- set_pitch(pitch)
Set the pixel pitch of the wavefront change.
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.set_pitch(1e-6)  # Set 1 μm pitch
- set_polar(polar)
Set polarization mode for the optical element.
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.set_polar('polar') # Set vector field mode
- shape()
Return the shape of the wavefront modulation.
The number of channels is one for wavefront modulation.
- Returns:
Dimensions (B, 1, R, C) for batch size, channels, rows, columns
- Return type:
Examples
>>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
>>> element.shape()
(1, 1, 100, 100)
- visualize(b=0, c=None)
Visualize the wavefront modulation of the optical element.
Displays amplitude and phase changes of the optical element’s wavefront modulation. Creates subplots showing amplitude change and phase change for specified channels.
- Parameters:
- Return type:
Examples
>>> lens = RefractiveLens((1,1,512,512), 2e-6, 0.1, 633e-9, 'cpu')
>>> lens.visualize() # Visualize first batch, all channels
>>> lens.visualize(b=0, c=0) # Visualize first batch, first channel
- class RefractiveLens(dim, pitch, focal_length, wvl, device, polar='non', designated_wvl=None)
Bases:
OpticalElement
- Parameters:
- __init__(dim, pitch, focal_length, wvl, device, polar='non', designated_wvl=None)
Create a thin refractive lens optical element.
Simulates a thin refractive lens that modifies the phase of incident light based on its focal length and wavelength.
- Parameters:
dim (tuple) – Shape of the lens field (B, Ch, R, C), where B is the batch size, Ch the number of channels, R the number of rows, and C the number of columns
pitch (float) – Pixel pitch in meters
focal_length (float) – Focal length of the lens in meters
wvl (float or list) – Wavelength(s) of light in meters. Can be single value or list for multi-channel
device (str) – Device to store the lens field (‘cpu’, ‘cuda:0’, etc.)
polar (str, optional) – Polarization mode. Defaults to ‘non’
designated_wvl (float, optional) – Override wavelength for all channels. Defaults to None
- Return type:
None
Examples
>>> # Create single channel lens
>>> lens = RefractiveLens((1,1,512,512), 2e-6, 0.1, 633e-9, 'cpu')
>>> # Create multi-channel lens with different wavelengths
>>> lens = RefractiveLens((1,3,512,512), 2e-6, 0.1, [633e-9,532e-9,450e-9], 'cuda:0')
- set_focal_length(focal_length)
Set the focal length of the lens.
Examples
>>> lens.set_focal_length(0.2) # Set 20cm focal length
- compute_phase(wvl, shift_x=0, shift_y=0)
Compute the phase modulation for the lens.
Calculates the phase change introduced by the lens based on its focal length, wavelength and any lateral shifts.
- Parameters:
- Returns:
Phase modulation pattern of the lens
- Return type:
Examples
>>> phase = lens.compute_phase(633e-9) # Centered lens
>>> phase = lens.compute_phase(633e-9, shift_x=10e-6) # Shifted lens
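A thin lens is usually modeled with the paraxial quadratic phase φ(x, y) = -π / (λ f) · ((x - shift_x)² + (y - shift_y)²). The NumPy sketch below computes that profile. It assumes pixel-centered coordinates, with the optical axis at the grid center, and the negative-sign (converging) convention of Goodman's thin-lens transmittance. thin_lens_phase is hypothetical, and pado's grid origin and sign convention may differ.

```python
import numpy as np


def thin_lens_phase(rows, cols, pitch, focal_length, wvl, shift_x=0.0, shift_y=0.0):
    """Paraxial thin-lens phase (radians) on a rows x cols grid with the given pitch."""
    # Pixel-centre coordinates with the optical axis at the grid centre
    x = (np.arange(cols) - (cols - 1) / 2) * pitch
    y = (np.arange(rows) - (rows - 1) / 2) * pitch
    xx, yy = np.meshgrid(x, y)  # each of shape (rows, cols)
    r2 = (xx - shift_x) ** 2 + (yy - shift_y) ** 2
    return -np.pi / (wvl * focal_length) * r2


phase = thin_lens_phase(512, 512, 2e-6, 0.1, 633e-9)
shifted = thin_lens_phase(512, 512, 2e-6, 0.1, 633e-9, shift_x=10e-6)
```

A lateral shift only moves the vertex of the parabola, which steers the focus off-axis.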
- class CosineSquaredLens(dim, pitch, focal_length, wvl, device, polar='non')
Bases:
OpticalElement
- Parameters:
- __init__(dim, pitch, focal_length, wvl, device, polar='non')
Lens with cosine squared phase distribution.
Creates a lens with a phase distribution of the form [1+cos(k*r^2)]/2.
- Parameters:
dim (tuple) – Field dimensions (B, 1, R, C) - batch size, channels, rows, columns
pitch (float) – Pixel pitch in meters
focal_length (float) – Focal length in meters
wvl (float) – Wavelength in meters
device (str) – Device to store wavefront (‘cpu’, ‘cuda:0’, …)
polar (str) – Polarization mode (‘non’: scalar, ‘polar’: vector)
- Return type:
None
Examples
>>> # Create basic cosine squared lens
>>> lens = CosineSquaredLens((1,1,1024,1024), 2e-6, 0.1, 633e-9, 'cpu')
- height2phase(height, wvl, RI, wrap=True)
Convert material height to corresponding phase shift.
Calculates phase shift from material height using wavelength and refractive index.
- Parameters:
- Returns:
Phase change induced by material height
- Return type:
Examples
>>> height = 500e-9 # 500nm height
>>> phase = height2phase(height, 633e-9, 1.5)
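For a thin element in air, a layer of height h and refractive index RI delays the phase by φ = 2π (RI - 1) h / λ relative to the surrounding air. The NumPy sketch below implements that relation. It assumes wrap=True means wrapping into [0, 2π). height_to_phase is a hypothetical stand-in, not pado's implementation.

```python
import numpy as np


def height_to_phase(height, wvl, ri, wrap=True):
    """Phase delay (radians) of a thin layer of height `height` (m) and index `ri` in air."""
    phase = 2 * np.pi * (ri - 1) * np.asarray(height, dtype=float) / wvl
    if wrap:
        phase = np.mod(phase, 2 * np.pi)  # wrap into [0, 2*pi)
    return phase


phi = height_to_phase(500e-9, 633e-9, 1.5)  # about 2.48 rad
```

A height of λ / (RI - 1) adds exactly one full wave (2π), so with wrapping enabled it maps back to zero phase.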
- phase2height(phase_u, wvl, RI, minh=0)
Convert phase change to material height.
Note that the phase-to-height mapping is not one-to-one. There exists an integer phase-wrapping factor i:
height = wvl / (2π (RI - 1)) * (phase_u + 2π i), where i is an integer
This function uses the minimum height minh to constrain the conversion: the smallest height satisfying height >= minh is chosen.
- Parameters:
phase_u (torch.Tensor) – Phase change of light
wvl (float) – Wavelength of light in meters
RI (float) – Refractive index of material at given wavelength
minh (float) – Minimum height constraint in meters
- Returns:
Material height that induces the phase change
- Return type:
Examples
>>> phase = torch.ones((1,1,1024,1024)) * np.pi
>>> height = phase2height(phase, 633e-9, 1.5, minh=100e-9) # 100nm min height
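The thin-element relation φ = 2π (RI - 1) h / λ gives the admissible heights h = λ / (2π (RI - 1)) · (φ_u + 2π i) for integer i. The NumPy sketch below picks the smallest of these that is at least minh. It wraps the phase into [0, 2π), converts it to a base height, and then adds whole full-wave steps λ / (RI - 1). phase_to_height is hypothetical; it illustrates the constraint described above and is not pado's code.

```python
import numpy as np


def phase_to_height(phase_u, wvl, ri, minh=0.0):
    """Smallest height >= minh whose phase delay equals phase_u modulo 2*pi."""
    full_wave = wvl / (ri - 1)  # height that adds exactly 2*pi of phase
    base = np.mod(phase_u, 2 * np.pi) / (2 * np.pi) * full_wave  # in [0, full_wave)
    # Number of extra full-wave steps needed so that height >= minh (never negative)
    i = np.maximum(np.ceil((minh - base) / full_wave), 0)
    return base + i * full_wave


phase = np.full((1, 1, 4, 4), np.pi)
height = phase_to_height(phase, 633e-9, 1.5, minh=100e-9)  # 633 nm everywhere
```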
- class DOE(dim, pitch, material, wvl, device, height=None, phase_change=None, polar='non')
Bases:
OpticalElement
- Parameters:
- __init__(dim, pitch, material, wvl, device, height=None, phase_change=None, polar='non')
Diffractive optical element (DOE) that modifies incident light wavefront.
The wavefront modification is determined by the material height profile. Supports both height and phase change specifications.
- Parameters:
dim (tuple) – Dimensions (B, 1, R, C) for batch size, channels, rows, columns
pitch (float) – Pixel pitch in meters
material (Material) – Material properties of the DOE
wvl (float) – Wavelength of light in meters
device (str) – Device to store wavefront (‘cpu’, ‘cuda:0’, etc.)
height (torch.Tensor, optional) – Height profile in meters
phase_change (torch.Tensor, optional) – Phase change profile
polar (str) – Polarization mode (‘non’: scalar, ‘polar’: vector)
Examples
>>> # Create DOE with specified height profile
>>> height = torch.ones((1,1,100,100)) * 500e-9 # 500nm height
>>> doe = DOE(height.shape, 2e-6, material, 500e-9, 'cpu', height=height)
>>> # Create DOE with specified phase profile
>>> phase = torch.ones((1,1,100,100)) * np.pi # π phase
>>> doe = DOE(phase.shape, 2e-6, material, 500e-9, 'cpu', phase_change=phase)
- visualize(b=0, c=0)
Visualize the DOE wavefront modulation.
Displays amplitude change, phase change and height profile.
- Parameters:
- Return type:
Examples
>>> doe = DOE((2,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.visualize() # Shows modulation plots
>>> doe.visualize(b=1, c=0) # Shows plots for batch index 1
- set_diffraction_grating_1d(slit_width, minh, maxh)
Set the wavefront modulation as a 1D diffraction grating.
Create alternating height regions to form a binary phase grating.
- Parameters:
- Return type:
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_diffraction_grating_1d(10e-6, 0, 500e-9) # 10μm slits
>>> doe.visualize() # Shows 1D grating pattern
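The NumPy sketch below builds a binary 1D grating height map. It assumes stripes of width slit_width along the column axis (so the period is 2 · slit_width), that slit_width is a whole number of pixels, and that the first stripe has height minh. The stripe orientation, the starting level, and the name grating_1d are assumptions for illustration only.

```python
import numpy as np


def grating_1d(rows, cols, pitch, slit_width, minh, maxh):
    """Binary height map of vertical stripes, each slit_width wide, alternating minh/maxh."""
    px = max(1, int(round(slit_width / pitch)))  # stripe width in pixels
    stripe = (np.arange(cols) // px) % 2  # 0 for minh stripes, 1 for maxh stripes
    return np.tile(np.where(stripe == 0, minh, maxh), (rows, 1))


h = grating_1d(100, 100, 2e-6, 10e-6, 0.0, 500e-9)  # 5-pixel stripes
```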
- set_diffraction_grating_2d(slit_width, minh, maxh)
Set the wavefront modulation as a 2D diffraction grating.
Create a checkerboard pattern of alternating height regions.
- Parameters:
- Return type:
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_diffraction_grating_2d(10e-6, 0, 500e-9) # 10μm slits
>>> doe.visualize() # Shows 2D grating pattern
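The checkerboard can be sketched in the same way: it alternates minh and maxh whenever the sum of the row-cell and column-cell indices changes parity. The sketch assumes square cells slit_width on a side and that the top-left cell has height minh. grating_2d is hypothetical.

```python
import numpy as np


def grating_2d(rows, cols, pitch, slit_width, minh, maxh):
    """Checkerboard height map with square cells slit_width wide, alternating minh/maxh."""
    px = max(1, int(round(slit_width / pitch)))  # cell size in pixels
    r = np.arange(rows)[:, None] // px  # cell index along rows, shape (rows, 1)
    c = np.arange(cols)[None, :] // px  # cell index along columns, shape (1, cols)
    return np.where((r + c) % 2 == 0, minh, maxh)


h = grating_2d(100, 100, 2e-6, 10e-6, 0.0, 500e-9)
```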
- set_Fresnel_lens(focal_length, wvl, shift_x=0, shift_y=0)
Set the wavefront modulation as a Fresnel lens.
Create a phase profile that focuses light to a point.
- Parameters:
- Return type:
Examples
>>> doe = DOE((1,1,1024,1024), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_Fresnel_lens(0.1, 500e-9) # f=10cm lens
>>> doe.set_Fresnel_lens(0.1, 500e-9, shift_x=50e-6) # Shifted lens
- set_Fresnel_zone_plate_lens(focal_length, wvl, shift_x=0, shift_y=0)
Set binary Fresnel zone plate pattern.
Creates alternating opaque and transparent zones that focus light.
- Parameters:
- Return type:
Examples
>>> doe = DOE((1,1,1024,1024), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_Fresnel_zone_plate_lens(0.1, 500e-9) # f=10cm lens
>>> doe.set_Fresnel_zone_plate_lens(0.1, 500e-9, shift_x=50e-6) # Shifted lens
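Both the Fresnel lens and the zone plate can be derived from the paraxial lens phase -π r² / (λ f). A Fresnel lens wraps that phase into [0, 2π). A binary amplitude zone plate instead keeps alternate zones, whose boundaries fall at r_n² = n λ f. The NumPy sketch below is centered (no shift) and makes the central zone transparent. That choice and the function name are assumptions for illustration; pado may realize the zone plate as a height or phase pattern rather than an amplitude mask.

```python
import numpy as np


def fresnel_lens_and_zone_plate(rows, cols, pitch, focal_length, wvl):
    """Return (wrapped Fresnel-lens phase, binary zone-plate amplitude) on a centred grid."""
    x = (np.arange(cols) - (cols - 1) / 2) * pitch
    y = (np.arange(rows) - (rows - 1) / 2) * pitch
    xx, yy = np.meshgrid(x, y)
    r2 = xx ** 2 + yy ** 2
    # Fresnel lens: paraxial lens phase wrapped into [0, 2*pi)
    phase = np.mod(-np.pi * r2 / (wvl * focal_length), 2 * np.pi)
    # Zone index n with boundaries at r_n^2 = n * wvl * f; even zones are open
    zone = np.floor(r2 / (wvl * focal_length)).astype(int)
    amplitude = (zone % 2 == 0).astype(float)
    return phase, amplitude


phase, amp = fresnel_lens_and_zone_plate(1024, 1024, 2e-6, 0.1, 500e-9)
```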
- change_wvl(wvl)
Change the wavelength and update phase change.
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.change_wvl(633e-9) # Change to 633nm wavelength
- sync_height_with_phase()
Synchronize height profile with current phase profile.
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_phase_change(phase, sync_height=False)
>>> doe.sync_height_with_phase() # Update height to match phase
- Return type:
None
- sync_phase_with_height()
Synchronize phase profile with current height profile.
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_height(height, sync_phase=False)
>>> doe.sync_phase_with_height() # Update phase to match height
- Return type:
None
- resize(target_pitch)
Resize DOE with a new pixel pitch.
The field change is resized to the new pitch, and the DOE height is then recomputed from it.
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.resize(1e-6) # Change pitch to 1μm
- get_height()
Return the height map of the DOE.
- Returns:
Height map in meters
- Return type:
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> height = doe.get_height() # Get current height profile
- set_phase_change(phase_change, sync_height=True)
Set phase change induced by the DOE.
- Parameters:
phase_change (torch.Tensor) – Phase change profile
sync_height (bool) – If True, syncs height profile
- Return type:
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> phase = torch.ones((1,1,100,100)) * np.pi
>>> doe.set_phase_change(phase, sync_height=True)
- set_field_change(field_change, sync_height=True)
Set the field change of the DOE.
- Parameters:
field_change (torch.Tensor) – Complex field change tensor
sync_height (bool) – If True, syncs height profile
- Return type:
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> field = torch.exp(1j * torch.ones((1,1,100,100)))
>>> doe.set_field_change(field, sync_height=True)
- set_height(height, sync_phase=True)
Set the height map of the DOE.
- Parameters:
height (torch.Tensor) – Height map in meters
sync_phase (bool) – If True, syncs phase profile
- Return type:
Examples
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> height = torch.ones((1,1,100,100)) * 500e-9
>>> doe.set_height(height, sync_phase=True)
- class SLM(dim, pitch, wvl, device, polar='non')
Bases:
OpticalElement
- __init__(dim, pitch, wvl, device, polar='non')
Spatial Light Modulator (SLM) optical element.
- Parameters:
Examples
>>> slm = SLM(dim=(1,1,1024,1024), pitch=6.4e-6, wvl=633e-9, device='cuda:0')
- set_lens(focal_length, shift_x=0, shift_y=0)
Set phase profile to implement a thin lens.
- Parameters:
- Return type:
Examples
>>> slm.set_lens(focal_length=0.5, shift_x=100e-6) # 500mm focal length, 100μm x-shift
- set_amplitude_change(amplitude, wvl)
Set amplitude modulation profile of the SLM.
- Parameters:
amplitude (torch.Tensor) – Amplitude modulation profile [B, 1, R, C]
wvl (float) – Operating wavelength in meters
- Return type:
Examples
>>> amp = torch.ones((1,1,1024,1024)) * 0.8 # 80% amplitude transmission
>>> slm.set_amplitude_change(amp, wvl=633e-9)
- set_phase_change(phase_change, wvl)
Set phase modulation profile of the SLM.
- Parameters:
phase_change (torch.Tensor) – Phase modulation profile [B, 1, R, C] in radians
wvl (float) – Operating wavelength in meters
- Return type:
Examples
>>> phase = torch.ones((1,1,1024,1024)) * np.pi # π phase shift
>>> slm.set_phase_change(phase, wvl=633e-9)
- class PolarizedSLM(dim, pitch, wvl, device)
Bases:
OpticalElement
- __init__(dim, pitch, wvl, device)
SLM that can control the phase and amplitude of each polarization component.
- Parameters:
Examples
>>> slm = PolarizedSLM(dim=(1,1,1024,1024), pitch=6.4e-6, wvl=633e-9, device='cuda:0')
- set_amplitude_change(amplitude, wvl)
Set amplitude change for both polarization components.
- Parameters:
amplitude (torch.Tensor) – Amplitude change [B, 1, R, C, 2] in polar representation
wvl (float) – Wavelength in meters
- Return type:
Examples
>>> amp = torch.ones((1,1,1024,1024,2)) * 0.8 # 80% amplitude transmission for both polarizations
>>> slm.set_amplitude_change(amp, wvl=633e-9)
- set_phase_change(phase_change, wvl)
Set phase change for both polarization components.
- Parameters:
phase_change (torch.Tensor) – Phase change [B, 1, R, C, 2] in polar representation
wvl (float) – Wavelength in meters
- Return type:
Examples
>>> phase = torch.ones((1,1,1024,1024,2)) * np.pi # π phase shift for both polarizations
>>> slm.set_phase_change(phase, wvl=633e-9)
- set_amplitudeX_change(amplitude, wvl)
Set amplitude change for X polarization component.
- Parameters:
amplitude (torch.Tensor) – Amplitude change [B, 1, R, C] for X component
wvl (float) – Wavelength in meters
- Return type:
Examples
>>> ampX = torch.ones((1,1,1024,1024)) * 0.8 # 80% amplitude transmission for X polarization
>>> slm.set_amplitudeX_change(ampX, wvl=633e-9)
- set_amplitudeY_change(amplitude, wvl)
Set amplitude change for Y polarization component.
- Parameters:
amplitude (torch.Tensor) – Amplitude change [B, 1, R, C] for Y component
wvl (float) – Wavelength in meters
- Return type:
Examples
>>> ampY = torch.ones((1,1,1024,1024)) * 0.6 # 60% amplitude transmission for Y polarization
>>> slm.set_amplitudeY_change(ampY, wvl=633e-9)
- set_phaseX_change(phase_change, wvl)
Set phase change for X polarization component.
- Parameters:
phase_change (torch.Tensor) – Phase change [B, 1, R, C] for X component
wvl (float) – Wavelength in meters
- Return type:
Examples
>>> phaseX = torch.ones((1,1,1024,1024)) * np.pi/2 # π/2 phase shift for X polarization
>>> slm.set_phaseX_change(phaseX, wvl=633e-9)
- set_phaseY_change(phase_change, wvl)
Set phase change for Y polarization component.
- Parameters:
phase_change (torch.Tensor) – Phase change [B, 1, R, C] for Y component
wvl (float) – Wavelength in meters
- Return type:
Examples
>>> phaseY = torch.ones((1,1,1024,1024)) * np.pi # π phase shift for Y polarization
>>> slm.set_phaseY_change(phaseY, wvl=633e-9)
- get_phase_changeX()
Return phase change for X polarization component.
- Returns:
Phase change [B, 1, R, C] for X component
- Return type:
Examples
>>> phaseX = slm.get_phase_changeX() # Get X polarization phase profile
- get_phase_changeY()
Return phase change for Y polarization component.
- Returns:
Phase change [B, 1, R, C] for Y component
- Return type:
Examples
>>> phaseY = slm.get_phase_changeY() # Get Y polarization phase profile
- get_amplitude_changeX()
Return amplitude change for X polarization component.
- Returns:
Amplitude change [B, 1, R, C] for X component
- Return type:
Examples
>>> ampX = slm.get_amplitude_changeX() # Get X polarization amplitude profile
- get_amplitude_changeY()
Return amplitude change for Y polarization component.
- Returns:
Amplitude change [B, 1, R, C] for Y component
- Return type:
Examples
>>> ampY = slm.get_amplitude_changeY() # Get Y polarization amplitude profile
- forward(light, interp_mode='nearest')
Apply polarization-dependent modulation to input light.
- Parameters:
- Returns:
Modulated light field
- Return type:
Examples
>>> modulated_light = slm.forward(input_light) # Apply polarization modulation
>>> modulated_light = slm.forward(input_light, interp_mode='bilinear') # Use bilinear interpolation
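Conceptually, a polarization-dependent modulation multiplies each polarization component of the incident field by its own complex factor. This is a diagonal Jones matrix diag(A_x e^{iφ_x}, A_y e^{iφ_y}). The NumPy sketch below assumes that the last axis of size 2 holds the (x, y) components, as in the [B, 1, R, C, 2] shapes above. It also assumes the field and the SLM already share the same pitch, so no interpolation is performed. apply_polarized_modulation is hypothetical.

```python
import numpy as np


def apply_polarized_modulation(field_xy, amplitude_xy, phase_xy):
    """Multiply x and y field components by their own amplitude and phase change.

    All arrays have shape (B, Ch, R, C, 2); the last axis holds the x and y components.
    """
    return field_xy * amplitude_xy * np.exp(1j * phase_xy)


field = np.ones((1, 1, 8, 8, 2), dtype=complex)  # equal x and y components
amp = np.stack([np.full((1, 1, 8, 8), 0.8), np.full((1, 1, 8, 8), 0.6)], axis=-1)
phase = np.stack([np.full((1, 1, 8, 8), np.pi / 2), np.full((1, 1, 8, 8), np.pi)], axis=-1)
out = apply_polarized_modulation(field, amp, phase)
```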
- class Aperture(dim, pitch, aperture_diameter, aperture_shape, wvl, device='cpu', polar='non')
Bases:
OpticalElement
Aperture optical element for amplitude modulation.
Implement a square or circular aperture that modulates light amplitude. Support both polarized and non-polarized light.
- Parameters:
- __init__(dim, pitch, aperture_diameter, aperture_shape, wvl, device='cpu', polar='non')
Create an aperture optical element instance.
- Parameters:
dim (tuple) – Field dimensions (B, 1, R, C) for batch, channels, rows, cols
pitch (float) – Pixel pitch in meters
aperture_diameter (float) – Diameter of aperture in meters
aperture_shape (str) – Shape of aperture (‘square’ or ‘circle’)
wvl (float) – Wavelength in meters
device (str) – Device for computation (‘cpu’, ‘cuda:0’, etc.)
polar (str) – Polarization mode (‘non’, ‘x’, ‘y’, ‘xy’)
Examples
>>> aperture = Aperture(dim=(1,1,1024,1024), pitch=6.4e-6,
...                     aperture_diameter=1e-3, aperture_shape='circle',
...                     wvl=633e-9)
- set_square()
Set square aperture amplitude modulation.
Create a square aperture mask centered on the optical axis.
Examples
>>> aperture.set_square()
- Return type:
None
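A square or circular aperture is a binary amplitude mask centered on the optical axis. The NumPy sketch below assumes pixel-centered coordinates. It treats pixels exactly on the boundary as open and uses the "diameter" as the side length of the square. aperture_mask is hypothetical and is not pado's implementation.

```python
import numpy as np


def aperture_mask(rows, cols, pitch, diameter, shape="circle"):
    """Binary amplitude mask (1 inside, 0 outside) centred on the optical axis."""
    x = (np.arange(cols) - (cols - 1) / 2) * pitch
    y = (np.arange(rows) - (rows - 1) / 2) * pitch
    xx, yy = np.meshgrid(x, y)
    half = diameter / 2
    if shape == "circle":
        inside = xx ** 2 + yy ** 2 <= half ** 2
    elif shape == "square":
        inside = (np.abs(xx) <= half) & (np.abs(yy) <= half)
    else:
        raise ValueError("shape must be 'circle' or 'square'")
    return inside.astype(float)


mask = aperture_mask(1024, 1024, 6.4e-6, 1e-3, "circle")
```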
- quantize(x, levels, vmin=None, vmax=None, include_vmax=True)
Quantize a floating-point array.
Discretize the input array into the specified number of levels.
- Parameters:
x (torch.Tensor or np.ndarray) – Input array to quantize
levels (int) – Number of quantization levels
vmin (float, optional) – Minimum value for quantization
vmax (float, optional) – Maximum value for quantization
include_vmax (bool) – Whether to include the max value in quantization. True: quantize with a spacing of 1/levels. False: quantize with a spacing of 1/(levels-1).
- Returns:
Quantized array
- Return type:
torch.Tensor or np.ndarray
Examples
>>> x = torch.randn(100)
>>> x_quant = quantize(x, levels=8)
>>> x_quant = quantize(x, levels=16, vmin=-1, vmax=1)
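The NumPy sketch below shows uniform quantization over [vmin, vmax], which defaults to the data range, into `levels` evenly spaced values that include both endpoints. It is one common variant chosen for illustration. It does not reproduce the include_vmax switch, and uniform_quantize is a hypothetical name.

```python
import numpy as np


def uniform_quantize(x, levels, vmin=None, vmax=None):
    """Snap x to the nearest of `levels` evenly spaced values spanning [vmin, vmax]."""
    if levels < 2:
        raise ValueError("levels must be at least 2")
    x = np.asarray(x, dtype=float)
    vmin = x.min() if vmin is None else vmin
    vmax = x.max() if vmax is None else vmax
    step = (vmax - vmin) / (levels - 1)  # spacing between adjacent levels
    idx = np.round((np.clip(x, vmin, vmax) - vmin) / step)  # nearest level index
    return vmin + idx * step


q = uniform_quantize(np.linspace(-1, 1, 101), levels=5)  # values in {-1, -0.5, 0, 0.5, 1}
```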
pado.propagator module
- compute_pad_width(field, linear)
Compute padding width for FFT-based convolution.
- Parameters:
field (torch.Tensor) – Complex tensor of shape (B, Ch, R, C)
linear (bool) – Flag for linear convolution (with padding) or circular convolution (no padding)
- Returns:
Padding width tuple
- Return type:
- unpad(field_padded, pad_width)
Remove padding from a padded complex tensor.
- Parameters:
field_padded (torch.Tensor) – Padded complex tensor of shape (B, Ch, R, C)
pad_width (tuple) – Padding width tuple
- Returns:
Unpadded complex tensor
- Return type:
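FFT-based propagation computes a circular convolution. Zero-padding each spatial axis to roughly twice its length before the FFT, and cropping afterwards, turns it into an approximately linear one; this is what the linear flag selects. The NumPy sketch below shows such a pad/unpad pair. It assumes half-size padding on each side of the last two axes and a tuple order of (top, bottom, left, right). The order and the helper names are assumptions, and pado's own pad-width convention may differ.

```python
import numpy as np


def pad_for_linear(field):
    """Zero-pad the last two axes by half their size on each side (about 2x in total)."""
    rows, cols = field.shape[-2:]
    pad_width = (rows // 2, rows // 2, cols // 2, cols // 2)  # (top, bottom, left, right)
    top, bottom, left, right = pad_width
    spec = [(0, 0)] * (field.ndim - 2) + [(top, bottom), (left, right)]
    return np.pad(field, spec), pad_width  # default mode pads with zeros


def unpad_linear(field_padded, pad_width):
    """Crop away the padding added by pad_for_linear."""
    top, bottom, left, right = pad_width
    rows, cols = field_padded.shape[-2:]
    return field_padded[..., top:rows - bottom, left:cols - right]


field = np.random.default_rng(0).standard_normal((1, 1, 6, 5)) + 0j
padded, pw = pad_for_linear(field)
```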
- class Propagator(mode, polar='non')
Bases:
object
- __init__(mode, polar='non')
Light propagator for simulating wave propagation through free space.
Implement common diffraction methods including Fraunhofer, Fresnel, ASM and RS. Support complex field calculations.
- Parameters:
Examples
>>> # Create ASM propagator for scalar field
>>> prop = Propagator(mode="ASM", polar="non")
>>> # Create Fresnel propagator for vector field
>>> prop = Propagator(mode="Fresnel", polar="polar")
>>> # Propagate a uniform field with ASM
>>> prop = Propagator(mode="ASM")
>>> field = torch.ones((1, 1, 1000, 1000), dtype=torch.cfloat)
>>> light = Light(dim=(1, 1, 1000, 1000), pitch=2e-6, wvl=660e-9, field=field)
>>> light_prop = prop.forward(light, z=0.05)
- forward(light, z, offset=(0, 0), linear=True, band_limit=True, b=1, target_plane=None, sampling_ratio=1, vectorized=False, steps=100)
Propagate incident light through the propagator.
- Parameters:
light (Light) – Incident light field
z (float) – Propagation distance in meters
offset (tuple) – Lateral shift (y, x) in meters for off-axis propagation
linear (bool) – Flag for linear convolution (with padding) or circular convolution (no padding)
band_limit (bool) – If True, apply band-limiting for ASM
b (float) – Scaling factor for observation plane (b>1: expansion, b<1: focusing)
target_plane (tuple, optional) – (x, y, z) coordinates for RS diffraction
sampling_ratio (int) – Spatial sampling ratio for RS computation
vectorized (bool) – If True, use vectorized implementation for RS (better performance but higher memory usage)
steps (int) – Number of computation steps for vectorized RS (higher values use less memory)
- Returns:
Propagated light field
- Return type:
Examples
>>> # Basic propagation
>>> prop = Propagator(mode="ASM")
>>> light_prop = prop.forward(light, z=0.1)
>>> # Propagation with padding
>>> light_prop = prop.forward(light, z=0.1, linear=True)
>>> # Vector field propagation
>>> prop = Propagator(mode="Fresnel", polar="polar")
>>> light_prop = prop.forward(light, z=0.05)
>>> # Access x,y components
>>> x_component = light_prop.get_lightX()
>>> y_component = light_prop.get_lightY()
>>> # Vectorized RS propagation for better performance
>>> prop = Propagator(mode="RS")
>>> light_prop = prop.forward(light, z=0.1, vectorized=True, steps=50)
- forward_non_polar(light, z, offset=(0, 0), linear=True, band_limit=True, b=1, target_plane=None, sampling_ratio=1, vectorized=False, steps=100)
Propagate non-polarized light field using selected propagation method.
- Parameters:
light (Light) – Input light field
z (float) – Propagation distance in meters
offset (tuple) – Lateral shift (y, x) in meters for off-axis propagation
linear (bool) – If True, use linear convolution with zero-padding
band_limit (bool) – If True, apply band-limiting for ASM
b (float) – Scaling factor for observation plane (b>1: expansion, b<1: focusing)
target_plane (tuple, optional) – (x, y, z) coordinates for RS diffraction
sampling_ratio (int) – Spatial sampling ratio for RS computation
vectorized (bool) – If True, use vectorized implementation for RS
steps (int) – Number of computation steps for vectorized RS
- Returns:
Propagated light field
- Return type:
- forward_Fraunhofer(light, z, linear=True)
Propagate light using Fraunhofer diffraction.
Implement far-field diffraction with multi-wavelength support. The computed field pattern does not depend on the distance z; z only affects the output pixel pitch.
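In an FFT-based Fraunhofer model, z enters only through the output sampling. An N-sample input with pitch Δx maps to an observation-plane pitch of λ z / (N Δx), where N counts the samples actually transformed, including any zero-padding. fraunhofer_output_pitch is a hypothetical helper that does this bookkeeping.

```python
def fraunhofer_output_pitch(n_samples, pitch, wvl, z):
    """Observation-plane pixel pitch for an FFT-based Fraunhofer propagation."""
    # FFT frequency spacing is 1 / (N * pitch); observation coordinate is x = wvl * z * f
    return wvl * z / (n_samples * pitch)


dx_out = fraunhofer_output_pitch(1000, 2e-6, 633e-9, 0.1)  # 3.165e-5 m, i.e. 31.65 μm
```

Doubling z doubles the output pitch, so the same pattern is spread over a physically larger observation plane.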
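- forward_Fresnel(light, z, linear)
Propagate light using Fresnel diffraction.
Implement the Fresnel approximation with multi-wavelength support. A commonly quoted sufficient condition for its validity (Goodman) is z³ >> π/(4λ) · [(x - ξ)² + (y - η)²]²_max, where (ξ, η) and (x, y) are source-plane and observation-plane coordinates.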
- forward_FFT(light, z=None)
Propagate light using simple FFT-based propagation.
Apply exp(1j * phase) before FFT without considering propagation distance z or padding. Used for basic Fourier transform of the input field.
- forward_ASM(light, z, offset=(0, 0), linear=True, band_limit=True, b=1)
Select appropriate ASM propagation method based on parameters.
Automatically choose between standard ASM, band-limited ASM, and scaled ASM depending on the scaling factor b and offset requirements.
- Parameters:
- Returns:
Propagated light field using selected ASM method
- Return type:
- forward_standard_ASM(light, z, offset=(0, 0), linear=True)
Propagate light using standard Angular Spectrum Method.
Implement basic ASM propagation with optional off-axis shift. Support multi-wavelength channels and linear/circular convolution.
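The standard angular spectrum method multiplies the field's 2D spectrum by the free-space transfer function H(fx, fy) = exp(i 2π z sqrt(1/λ² - fx² - fy²)). It discards evanescent components, where the square root would be imaginary. The NumPy sketch below covers a single wavelength with circular convolution only (no padding, no offset). asm_propagate is hypothetical; it illustrates the method and is not pado's implementation.

```python
import numpy as np


def asm_propagate(field, pitch, wvl, z):
    """Angular-spectrum propagation of a 2D complex field over distance z (circular)."""
    rows, cols = field.shape
    fy = np.fft.fftfreq(rows, d=pitch)
    fx = np.fft.fftfreq(cols, d=pitch)
    fxx, fyy = np.meshgrid(fx, fy)
    arg = 1.0 / wvl ** 2 - fxx ** 2 - fyy ** 2
    propagating = arg > 0
    kz = 2 * np.pi * np.sqrt(np.where(propagating, arg, 0.0))
    transfer = np.where(propagating, np.exp(1j * kz * z), 0.0)  # drop evanescent waves
    return np.fft.ifft2(np.fft.fft2(field) * transfer)


rng = np.random.default_rng(0)
u0 = np.exp(1j * rng.uniform(0, 2 * np.pi, (64, 64)))
u1 = asm_propagate(u0, 2e-6, 633e-9, 1e-3)
```

With a 2 μm pitch every sampled frequency propagates, so the transfer function is unitary. Total power is conserved, and propagating by -z undoes the step.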
- forward_shifted_BL_ASM(light, z, offset=(0, 0), linear=True)
Propagate light using band-limited Angular Spectrum Method.
Implement shifted band-limited ASM for improved numerical stability in off-axis propagation. Based on Matsushima’s method.
- Parameters:
- Returns:
Propagated light field
- Return type:
References
Matsushima, “Shifted angular spectrum method for off-axis numerical propagation”
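Band limiting zeroes frequencies whose transfer-function phase would be undersampled on the frequency grid. For orientation, the sketch below implements the on-axis criterion of Matsushima and Shimobaba's band-limited ASM (2009), the non-shifted precursor of the method cited above. For a sampling window of width W (after any zero-padding) the frequency spacing is Δu = 1/W, and the kept band is |fx| < 1 / (λ sqrt((2 Δu z)² + 1)), likewise for fy. It does not implement the off-axis shift. band_limit_mask is hypothetical.

```python
import numpy as np


def band_limit_mask(rows, cols, pitch, wvl, z):
    """Boolean mask of frequencies kept by the on-axis band-limited ASM criterion."""
    fy = np.fft.fftfreq(rows, d=pitch)
    fx = np.fft.fftfreq(cols, d=pitch)
    du = 1.0 / (cols * pitch)  # frequency spacing along x (window width W = cols * pitch)
    dv = 1.0 / (rows * pitch)  # frequency spacing along y
    u_limit = 1.0 / (wvl * np.sqrt((2 * du * z) ** 2 + 1))
    v_limit = 1.0 / (wvl * np.sqrt((2 * dv * z) ** 2 + 1))
    fxx, fyy = np.meshgrid(fx, fy)
    return (np.abs(fxx) < u_limit) & (np.abs(fyy) < v_limit)


mask = band_limit_mask(1024, 1024, 2e-6, 633e-9, 0.5)  # long distance: strong limiting
```

At short distances the limit exceeds the grid's Nyquist frequency and nothing is removed. At long distances only a small central band survives.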
- forward_ScASM(light, z, b, linear=True)
Propagate light using scaled Angular Spectrum Method.
This function performs a scaled forward angular spectrum propagation. It takes an input optical field ‘light’, propagates it over a distance ‘z’, and scales the observation plane by a factor ‘b’ relative to the source plane. If ‘linear’ is True, zero-padding is applied to avoid wrap-around effects from FFT-based convolutions. See M. Abedi, H. Saghafifar, and L. Rahimi, “Improvement of optical wave propagation simulations: the scaled angular spectrum method for far-field and focal analysis,” Opt. Continuum 3, 935-947 (2024).
- Parameters:
- Returns:
Propagated light field with scaled observation plane
- Return type:
References
Abedi et al., “Improvement of optical wave propagation simulations: the scaled angular spectrum method for far-field and focal analysis”
- forward_ScASM_focusing(light, z, b, linear=True)
Propagate light using scaled ASM for focusing.
Implement scaled ASM propagation for focusing onto a smaller observation plane. Use an FFT to transform to the frequency domain, apply the transfer function, then use a scaled inverse DFT (Sc-IDFT) to resample the field on the smaller observation plane. Support multi-wavelength channels and linear/circular convolution.
- forward_RayleighSommerfeld(light, z, target_plane=None, sampling_ratio=1, vectorized=False, steps=100)
Propagate light using Rayleigh-Sommerfeld diffraction.
Implement an exact scalar diffraction calculation. It is computationally intensive but accurate in both the near and far field. Support arbitrary target-plane geometry.
- Parameters:
light (Light) – Input light field
z (float) – Propagation distance in meters
target_plane (tuple, optional) – (x, y, z) coordinates of target plane points
sampling_ratio (int) – Spatial sampling ratio for computation (>1 for faster calculation)
vectorized (bool) – If True, use vectorized implementation for better performance
steps (int) – Number of computation steps for vectorized mode (higher values use less memory)
- Returns:
Propagated light field
- Return type:
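Direct summation makes the cost of Rayleigh-Sommerfeld propagation explicit: every target point sums contributions from every source pixel. The sketch below evaluates the first Rayleigh-Sommerfeld integral in Goodman's r ≫ λ form, U(x, y) = 1/(iλ) · Σ U(ξ, η) · (z / r) · exp(ikr) / r · Δξ Δη. The target grid has the same size and pitch as the source and lies at distance z. Processing target points in chunks bounds memory, in the spirit of the vectorized and steps options. However, the chunking scheme and the name rs_propagate are assumptions, not pado's implementation.

```python
import numpy as np


def rs_propagate(field, pitch, wvl, z, chunk=256):
    """First Rayleigh-Sommerfeld integral by direct summation onto a same-sized grid at z."""
    rows, cols = field.shape
    k = 2 * np.pi / wvl
    x = (np.arange(cols) - (cols - 1) / 2) * pitch
    y = (np.arange(rows) - (rows - 1) / 2) * pitch
    xs, ys = [a.ravel() for a in np.meshgrid(x, y)]  # source (and target) coordinates
    src = field.ravel()
    out = np.empty(rows * cols, dtype=complex)
    for start in range(0, out.size, chunk):  # chunking bounds memory use
        xt = xs[start:start + chunk, None]
        yt = ys[start:start + chunk, None]
        r = np.sqrt((xt - xs) ** 2 + (yt - ys) ** 2 + z ** 2)  # (chunk, rows*cols)
        kernel = (z / r) * np.exp(1j * k * r) / r / (1j * wvl)
        out[start:start + chunk] = (kernel @ src) * pitch ** 2
    return out.reshape(rows, cols)


u0 = np.zeros((32, 32), dtype=complex)
u0[12:20, 12:20] = 1.0  # square aperture centred on the grid
u1 = rs_propagate(u0, 4e-6, 633e-9, 2e-3)
```

The cost scales with the product of source and target point counts. This is why subsampling (sampling_ratio) and chunked evaluation matter for large fields.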