########################################################
# The MIT License (MIT)
#
# PADO (Pytorch Automatic Differentiable Optics)
# Copyright (c) 2023 by POSTECH Computer Graphics Lab
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#
# Contact:
# Lead Developer: Dong-Ha Shin (0218sdh@gmail.com)
# Corresponding Author: Seung-Hwan Baek (shwbaek@postech.ac.kr)
#
########################################################
import numpy as np
import matplotlib.pyplot as plt
from typing import Tuple, Optional, Union, List
import torch
import torch.nn.functional as F
from .math import wrap_phase
from .math import nm, um, mm, cm, m
from .light import Light
from .material import Material
[docs]
class OpticalElement:
    """Base class for optical elements that modulate an incident light wavefront.

    The modulation is stored as a complex tensor ``field_change`` whose modulus
    is the amplitude change and whose argument is the phase change.
    """

    def __init__(self, dim: Tuple[int, int, int, int], pitch: float, wvl: float,
                 field_change: Optional[torch.Tensor] = None, device: str = 'cpu',
                 name: str = "not defined", polar: str = 'non') -> None:
        """Create an optical element that modifies the incident light wavefront.

        The wavefront modification is stored as a complex tensor.
        Note that the number of channels is one for wavefront modulation.

        Args:
            dim (tuple): Dimensions (B, 1, R, C) for batch size, channels, rows, columns
            pitch (float): Pixel pitch in meters
            wvl (float): Wavelength of light in meters
            field_change (torch.Tensor, optional): Wavefront modification tensor [B, C, H, W].
                If None, a transparent element (complex ones) is created.
            device (str): Device to store wavefront ('cpu', 'cuda:0', etc.)
            name (str): Name identifier for this optical element
            polar (str): Polarization mode ('non': scalar, 'polar': vector)

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> element.field_change.shape
            torch.Size([1, 1, 100, 100])
        """
        self.name = name
        # Stored as a plain tuple; pad()/resize() rebuild it the same way.
        self.dim = tuple(dim)
        self.pitch = pitch
        self.device = device
        if field_change is None:
            self.field_change = torch.ones(dim, dtype=torch.cfloat, device=device)
        else:
            self.field_change = field_change
        self.wvl = wvl
        self.polar = polar

    def forward(self, light: 'Light', interp_mode: str = 'nearest') -> 'Light':
        """Propagate incident light through the optical element.

        Whichever of the light and the element has the coarser pixel pitch is
        resampled to the finer one before the modulation is applied.

        Args:
            light (Light): Input light field
            interp_mode (str): Interpolation method for resizing ('bilinear', 'nearest')

        Returns:
            Light: Light field after interaction with optical element

        Raises:
            NotImplementedError: If the polarization mode is neither 'non' nor 'polar'

        Examples:
            >>> element = OpticalElement(dim=(1, 1, 64, 64), pitch=2e-6, wvl=500e-9)
            >>> light = Light(dim=(1, 1, 64, 64), pitch=2e-6)
            >>> output = element.forward(light)
        """
        if light.pitch > self.pitch:
            light.resize(self.pitch, interp_mode)
            light.set_pitch(self.pitch)
        elif light.pitch < self.pitch:
            self.resize(light.pitch, interp_mode)
            self.set_pitch(light.pitch)

        if self.polar == 'non':
            return self.forward_non_polar(light, interp_mode)
        elif self.polar == 'polar':
            # Apply the same scalar modulation to each polarization component.
            x = self.forward_non_polar(light.get_lightX(), interp_mode)
            y = self.forward_non_polar(light.get_lightY(), interp_mode)
            light.set_lightX(x)
            light.set_lightY(y)
            return light
        else:
            raise NotImplementedError('Polar is not set.')

    @staticmethod
    def _split_padding(diff: int) -> Tuple[int, int]:
        """Split a size difference into (leading, trailing) pad amounts.

        Floor division is applied to the signed difference. For an odd
        difference, the extra pixel goes to the trailing side when diff > 0
        and to the leading side when diff < 0.
        """
        lead = abs(diff // 2)
        return lead, abs(diff) - lead

    def forward_non_polar(self, light: 'Light', interp_mode: str = 'nearest') -> 'Light':
        """Propagate non-polarized light through the optical element.

        Zero-pads whichever of the light field and the element is smaller, so
        both have the same number of rows and columns. Then it multiplies the
        light field by the element's field change.

        Args:
            light (Light): Input light field to propagate through the element
            interp_mode (str): Interpolation method for resizing ('bilinear', 'nearest');
                unused here, kept for interface symmetry with forward()

        Returns:
            Light: Modified light field after interaction with optical element

        Raises:
            ValueError: If wavelengths of light and element don't match
        """
        if light.wvl != self.wvl:
            raise ValueError(f'Wavelength mismatch: light wavelength {light.wvl} != element wavelength {self.wvl}')

        # Sizes are read from the tensors themselves, so this does not depend
        # on how either object updates its own `dim` bookkeeping when padded.
        # torch.nn.functional.pad order is (left, right, top, bottom):
        # the first pair pads columns and the second pair pads rows.
        row_diff = light.field.shape[-2] - self.field_change.shape[-2]
        top, bottom = self._split_padding(row_diff)
        if row_diff > 0:
            self.pad((0, 0, top, bottom))
        elif row_diff < 0:
            light.pad((0, 0, top, bottom))

        col_diff = light.field.shape[-1] - self.field_change.shape[-1]
        left, right = self._split_padding(col_diff)
        if col_diff > 0:
            self.pad((left, right, 0, 0))
        elif col_diff < 0:
            light.pad((left, right, 0, 0))

        light.set_field(light.field * self.field_change)
        return light

    def get_amplitude_change(self) -> torch.Tensor:
        """Return amplitude change of the wavefront.

        Returns:
            torch.Tensor: Amplitude change of the wavefront (modulus of field_change)

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> amp = element.get_amplitude_change()
        """
        return self.field_change.abs()

    def get_device(self) -> str:
        """Returns the device on which tensors are stored.

        Returns:
            str: The device identifier (e.g., 'cpu', 'cuda:0').
        """
        return self.device

    def get_field_change(self) -> torch.Tensor:
        """Returns the field_change tensor.

        Returns:
            torch.Tensor: The complex field_change tensor representing amplitude and phase changes.
        """
        return self.field_change

    def get_name(self) -> str:
        """Returns the name of the optical element.

        Returns:
            str: The name identifier.
        """
        return self.name

    def get_phase_change(self) -> torch.Tensor:
        """Return phase change of the wavefront.

        Returns:
            torch.Tensor: Phase change of the wavefront (argument of field_change)

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> phase = element.get_phase_change()
        """
        return self.field_change.angle()

    def get_pitch(self) -> float:
        """Returns the pixel pitch.

        Returns:
            float: The pixel pitch in meters.
        """
        return self.pitch

    def get_polar(self) -> str:
        """Returns the polarization mode.

        Returns:
            str: The polarization mode ('non' for scalar, 'polar' for vector).
        """
        return self.polar

    def get_wvl(self) -> float:
        """Returns the wavelength.

        Returns:
            float: The wavelength in meters.
        """
        return self.wvl

    def pad(self, pad_width: Tuple[int, int, int, int], padval: float = 0) -> None:
        """Pad the wavefront change with a constant value.

        Args:
            pad_width (tuple): Padding width in torch.nn.functional.pad order,
                i.e. (left, right, top, bottom): columns first, then rows
            padval (float): Value to pad with, only 0 supported currently

        Raises:
            NotImplementedError: If padval is not 0

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> element.pad((10,10,10,10))  # Add 10 pixels padding on all sides
        """
        if padval != 0:
            raise NotImplementedError('only zero padding supported')
        self.field_change = F.pad(self.field_change, pad_width)
        # Take the new size from the padded tensor so `dim` always matches the
        # field, whichever axes F.pad assigned each pad pair to.
        self.dim = (self.dim[0], self.dim[1],
                    self.field_change.shape[-2], self.field_change.shape[-1])

    def resize(self, target_pitch: float, interp_mode: str = 'nearest') -> None:
        """Resize the wavefront change by changing the pixel pitch.

        Args:
            target_pitch (float): New pixel pitch to use
            interp_mode (str): Interpolation method used in torch.nn.functional.interpolate
                - 'bilinear': Bilinear interpolation
                - 'nearest': Nearest neighbor interpolation

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> element.resize(1e-6)  # Resize to 1μm pitch
        """
        scale_factor = self.pitch / target_pitch
        field = self.field_change
        if torch.is_complex(field):
            # Interpolation is linear in the sample values, so resampling the
            # real and imaginary parts separately is equivalent to resampling
            # the complex field, and it works with real-only interpolation kernels.
            resized = torch.complex(
                F.interpolate(field.real, scale_factor=scale_factor, mode=interp_mode),
                F.interpolate(field.imag, scale_factor=scale_factor, mode=interp_mode),
            )
        else:
            resized = F.interpolate(field, scale_factor=scale_factor, mode=interp_mode)
        self.field_change = resized
        self.dim = (self.dim[0], self.dim[1], resized.shape[2], resized.shape[3])
        self.set_pitch(target_pitch)

    def set_amplitude_change(self, amplitude: torch.Tensor, c: Optional[int] = None) -> None:
        """Set amplitude change for specific or all channels, keeping the phase.

        Args:
            amplitude (torch.Tensor): Amplitude change in polar representation
            c (int, optional): Channel index. If None, applies to all channels

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> amp = torch.ones((1,1,100,100))
            >>> element.set_amplitude_change(amp)
        """
        if c is not None:
            phase = self.field_change[:, c, ...].angle()
            # Work on a copy so the previous tensor is never modified in place.
            new_field_change = self.field_change.clone()
            new_field_change[:, c, ...] = amplitude * torch.exp(phase * 1j)
            self.field_change = new_field_change
        else:
            phase = self.field_change.angle()
            self.field_change = amplitude * torch.exp(phase * 1j)

    def set_field_change(self, field_change: torch.Tensor, c: Optional[int] = None) -> None:
        """Set field change for specific or all channels.

        Args:
            field_change (torch.Tensor): Field change in complex tensor
            c (int, optional): Channel index. If None, the same field_change is
                written into every channel

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> field = torch.ones((1,1,100,100), dtype=torch.cfloat)
            >>> element.set_field_change(field)
        """
        # Work on a copy so the previous tensor is never modified in place.
        new_field_change = self.field_change.clone()
        if c is not None:
            new_field_change[:, c, ...] = field_change
        else:
            for chan in range(self.dim[1]):
                new_field_change[:, chan, ...] = field_change
        self.field_change = new_field_change

    def set_name(self, name: str) -> None:
        """Sets the name of the optical element.

        Args:
            name (str): The name identifier.
        """
        self.name = name

    def set_phase_change(self, phase: torch.Tensor, c: Optional[int] = None) -> None:
        """Set phase change for specific or all channels, keeping the amplitude.

        Args:
            phase (torch.Tensor): Phase change in polar representation
            c (int, optional): Channel index. If None, applies to all channels

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> phase = torch.zeros((1,1,100,100))
            >>> element.set_phase_change(phase)
        """
        new_field_change = self.field_change.clone()
        channels = [c] if c is not None else range(self.dim[1])
        for chan in channels:
            amplitude = self.field_change[:, chan, ...].abs()
            new_field_change[:, chan, ...] = amplitude * torch.exp(phase * 1j)
        self.field_change = new_field_change

    def set_pitch(self, pitch: float) -> None:
        """Set the pixel pitch of the complex tensor.

        Args:
            pitch (float): Pixel pitch in meters

        Raises:
            ValueError: If pitch is not positive

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> element.set_pitch(1e-6)  # Set 1μm pitch
        """
        if pitch <= 0:
            raise ValueError(f"Pitch must be positive, got {pitch}")
        self.pitch = pitch

    def set_polar(self, polar: str) -> None:
        """Set polarization mode for the optical element.

        Args:
            polar (str): Polarization mode ('non': scalar, 'polar': vector)

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> element.set_polar('polar')  # Set vector field mode
        """
        self.polar = polar

    def set_wvl(self, wvl: float) -> None:
        """Sets the wavelength.

        Args:
            wvl (float): The wavelength in meters.
        """
        self.wvl = wvl

    def shape(self) -> Tuple[int, int, int, int]:
        """Return shape of light-wavefront modulation.

        The number of channels is one for wavefront modulation.

        Returns:
            tuple: Dimensions (B, 1, R, C) for batch size, channels, rows, columns

        Examples:
            >>> element = OpticalElement((1,1,100,100), pitch=2e-6, wvl=500e-9)
            >>> element.shape()
            (1, 1, 100, 100)
        """
        return self.dim

    def visualize(self, b: int = 0, c: Optional[int] = None) -> None:
        """Visualize the wavefront modulation of the optical element.

        Creates one figure per channel with the amplitude change and the phase
        change side by side.

        Args:
            b (int, optional): Batch index to visualize. Defaults to 0.
            c (int, optional): Channel index to visualize. If None, visualizes all channels.
                Defaults to None.

        Examples:
            >>> lens = RefractiveLens((1,1,512,512), 2e-6, 0.1, 633e-9, 'cpu')
            >>> lens.visualize()  # Visualize first batch, all channels
            >>> lens.visualize(b=0, c=0)  # Visualize first batch, first channel
        """
        channels = [c] if c is not None else range(self.dim[1])
        for chan in channels:
            plt.figure(figsize=(13,6))
            plt.subplot(121)
            plt.imshow(self.get_amplitude_change().data.cpu()[b,chan,...].squeeze(), cmap='inferno', vmin=0, vmax=1)
            plt.title('amplitude change')
            plt.colorbar()
            plt.subplot(122)
            plt.imshow(self.get_phase_change().data.cpu()[b,chan,...].squeeze(), cmap='hsv', vmin=-np.pi, vmax=np.pi)
            plt.title('phase change')
            plt.colorbar()
            # Multi-channel elements may carry one wavelength per channel.
            wvl_text = f'{self.wvl[chan]/nm:.2f}[nm]' if isinstance(self.wvl, list) else f'{self.wvl/nm:.2f}[nm]'
            plt.suptitle(
                f'{self.name}, '
                f'({self.dim[2]},{self.dim[3]}), '
                f'pitch:{self.pitch/um:.2f}[um], '
                f'wvl:{wvl_text}, '
                f'device:{self.device}'
            )
[docs]
class RefractiveLens(OpticalElement):
    """Thin refractive lens with a quadratic phase profile."""

    def __init__(self, dim: Tuple[int, int, int, int], pitch: float, focal_length: float,
                 wvl: Union[float, List[float]], device: str, polar: str = 'non',
                 designated_wvl: Optional[float] = None) -> None:
        """Create a thin refractive lens optical element.

        Simulates a thin refractive lens that modifies the phase of incident light
        based on its focal length and wavelength. The lens has unit amplitude.

        Args:
            dim (tuple): Shape of the lens field (B, Ch, R, C) where:
                B: Batch size
                Ch: Number of channels
                R: Number of rows
                C: Number of columns
            pitch (float): Pixel pitch in meters
            focal_length (float): Focal length of the lens in meters
            wvl (float or list): Wavelength(s) of light in meters. Either a single value
                used for every channel, or a list with one entry per channel
            device (str): Device to store the lens field ('cpu', 'cuda:0', etc.)
            polar (str, optional): Polarization mode. Defaults to 'non'
            designated_wvl (float, optional): If given, the phase of every channel
                (single- or multi-channel) is computed with this wavelength
                instead of wvl. Defaults to None

        Raises:
            ValueError: If focal_length is None

        Examples:
            >>> # Create single channel lens
            >>> lens = RefractiveLens((1,1,512,512), 2e-6, 0.1, 633e-9, 'cpu')
            >>> # Create multi-channel lens with different wavelengths
            >>> lens = RefractiveLens((1,3,512,512), 2e-6, 0.1, [633e-9,532e-9,450e-9], 'cuda:0')
        """
        super().__init__(dim, pitch, wvl, None, device, name="refractive_lens", polar=polar)
        # Must be set before set_focal_length(), which rebuilds the field.
        self.designated_wvl: Optional[float] = designated_wvl
        self.focal_length: Optional[float] = None
        self.set_focal_length(focal_length)

    def _channel_wvl(self, chan: int) -> float:
        """Return the wavelength used to compute the phase of channel `chan`."""
        if self.designated_wvl is not None:
            return self.designated_wvl
        if np.ndim(self.wvl) > 0:
            return self.wvl[chan]
        return self.wvl

    def _update_field_change(self) -> None:
        """Rebuild the unit-amplitude lens field for every channel."""
        for chan in range(self.dim[1]):
            phase = self.compute_phase(self._channel_wvl(chan), shift_x=0, shift_y=0)
            self.set_field_change(torch.exp(1j * phase), c=chan)

    def set_focal_length(self, focal_length: float) -> None:
        """Set the focal length of the lens and recompute its phase profile.

        Args:
            focal_length (float): New focal length in meters

        Raises:
            ValueError: If focal_length is None

        Examples:
            >>> lens.set_focal_length(0.2)  # Set 20cm focal length
        """
        if focal_length is None:
            raise ValueError("focal_length cannot be None")
        self.focal_length = focal_length
        self._update_field_change()

    def compute_phase(self, wvl: float, shift_x: float = 0, shift_y: float = 0) -> torch.Tensor:
        """Compute the phase modulation for the lens.

        Returns -2π/λ · r² / (2f), wrapped to [-π, π). Here r is the distance from
        the (optionally shifted) lens center on the element's pixel grid.

        Args:
            wvl (float): Wavelength of light in meters
            shift_x (float, optional): Horizontal displacement of lens center in meters. Defaults to 0
            shift_y (float, optional): Vertical displacement of lens center in meters. Defaults to 0

        Returns:
            torch.Tensor: Phase modulation pattern of the lens, shape (1, 1, R, C)

        Examples:
            >>> phase = lens.compute_phase(633e-9)  # Centered lens
            >>> phase = lens.compute_phase(633e-9, shift_x=10e-6)  # Shifted lens
        """
        x = np.arange(-self.dim[3] / 2, self.dim[3] / 2) * self.pitch
        y = np.arange(-self.dim[2] / 2, self.dim[2] / 2) * self.pitch
        xx, yy = np.meshgrid(x, y, indexing='xy')
        r_squared = (xx - shift_x) ** 2 + (yy - shift_y) ** 2
        theta_change = torch.tensor((-2 * np.pi / wvl) * r_squared, device=self.device) / (2 * self.focal_length)
        theta_change = (theta_change + np.pi) % (np.pi * 2) - np.pi
        return theta_change.unsqueeze(0).unsqueeze(0)
[docs]
class CosineSquaredLens(OpticalElement):
    """Lens with a sinusoidal (cosine-squared) phase zone-plate profile."""

    def __init__(self, dim: Tuple[int, int, int, int], pitch: float, focal_length: float,
                 wvl: float, device: str, polar: str = 'non') -> None:
        """Lens with cosine squared phase distribution.

        Creates a lens with phase distribution π·[1+cos(k·r²)]/2, where
        k = π / (wvl · focal_length). The cosine changes sign at the Fresnel zone
        radii r_n² = n·wvl·focal_length, so the profile focuses at focal_length.

        Args:
            dim (tuple): Field dimensions (B, 1, R, C) - batch size, channels, rows, columns
            pitch (float): Pixel pitch in meters
            focal_length (float): Focal length in meters
            wvl (float): Wavelength in meters (or a list with one entry per channel)
            device (str): Device to store wavefront ('cpu', 'cuda:0', ...)
            polar (str): Polarization mode ('non': scalar, 'polar': vector)

        Examples:
            >>> # Create basic cosine squared lens
            >>> lens = CosineSquaredLens((1,1,1024,1024), 2e-6, 0.1, 633e-9, 'cpu')
        """
        super().__init__(dim, pitch, wvl, None, device, name="cosine_squared_lens", polar=polar)
        self.focal_length: float = focal_length
        self.compute_and_set_phase_change()

    def compute_and_set_phase_change(self) -> None:
        """Compute and set the phase change induced by the lens.

        Calculates phase change in range [0, π] and applies it to every channel,
        keeping each channel's amplitude.

        Examples:
            >>> lens.compute_and_set_phase_change()
        """
        x = np.arange(-self.dim[3] / 2, self.dim[3] / 2) * self.pitch
        y = np.arange(-self.dim[2] / 2, self.dim[2] / 2) * self.pitch
        xx, yy = np.meshgrid(x, y, indexing='xy')
        # Squared radius from the grid center.
        r_squared = torch.tensor(xx ** 2 + yy ** 2, device=self.device)
        for chan in range(self.dim[1]):
            wvl = self.wvl[chan] if np.ndim(self.wvl) > 0 else self.wvl
            # Zone-plate spatial frequency; uses focal_length so the lens
            # actually focuses at the requested distance.
            k = np.pi / (wvl * self.focal_length)
            # π·[1+cos(k r²)]/2 maps the cosine into the range [0, π].
            phase_change = np.pi * (1 + torch.cos(k * r_squared)) / 2
            self.set_phase_change(phase_change.unsqueeze(0).unsqueeze(0), c=chan)
[docs]
def height2phase(height: float, wvl: float, RI: float, wrap: bool = True) -> torch.Tensor:
    """Convert a material height into the phase delay it induces.

    The phase is 2π/wvl · (RI - 1) · height, i.e. the optical path difference
    relative to air expressed in radians.

    Args:
        height (float): Height of material in meters
        wvl (float): Wavelength of light in meters
        RI (float): Refractive index of material at given wavelength
        wrap (bool): If True, wraps phase to [0,2π] range via wrap_phase

    Returns:
        torch.Tensor: Phase change induced by material height

    Examples:
        >>> height = 500e-9  # 500nm height
        >>> phase = height2phase(height, 633e-9, 1.5)
    """
    index_contrast = RI - 1
    wavenumber = 2. * np.pi / wvl
    phase = wavenumber * index_contrast * height
    return wrap_phase(phase, stay_positive=True) if wrap else phase
[docs]
def phase2height(phase_u: torch.Tensor, wvl: float, RI: float, minh: float = 0) -> torch.Tensor:
    """Convert phase change to material height (inverse of height2phase).

    Note that phase to height mapping is not one-to-one.
    There exists an integer phase wrapping factor i:
        height = wvl/(RI-1) * (phase_u/(2π) + i), where i is integer
    This function uses minimum height minh to constrain the conversion.
    The smallest i is chosen such that height is always >= minh.

    Args:
        phase_u (torch.Tensor): Phase change of light (array-likes are converted to tensors)
        wvl (float): Wavelength of light in meters
        RI (float): Refractive index of material at given wavelength
        minh (float): Minimum height constraint in meters. If None, i = 0

    Returns:
        torch.Tensor: Material height that induces the phase change

    Examples:
        >>> phase = torch.ones((1,1,1024,1024)) * np.pi
        >>> height = phase2height(phase, 633e-9, 1.5, minh=100e-9)  # 100nm min height
    """
    phase_u = torch.as_tensor(phase_u)
    dRI = RI - 1
    # Phase expressed in wavelengths (full 2π cycles).
    cycles = phase_u / (2 * np.pi)
    if minh is not None:
        # Smallest integer i with wvl*(cycles + i)/dRI >= minh.
        i = torch.ceil(dRI * minh / wvl - cycles)
    else:
        i = 0
    return wvl * (cycles + i) / dRI
[docs]
class DOE(OpticalElement):
[docs]
def __init__(self, dim: tuple, pitch: float, material: 'Material', wvl: float, device: str, height: Optional[torch.Tensor] = None, phase_change: Optional[torch.Tensor] = None, polar: str = 'non'):
"""Diffractive optical element (DOE) that modifies incident light wavefront.
The wavefront modification is determined by the material height profile.
Supports both height and phase change specifications.
Args:
dim (tuple): Dimensions (B, 1, R, C) for batch size, channels, rows, columns
pitch (float): Pixel pitch in meters
material (Material): Material properties of the DOE
wvl (float): Wavelength of light in meters
device (str): Device to store wavefront ('cpu', 'cuda:0', etc.)
height (torch.Tensor, optional): Height profile in meters
phase_change (torch.Tensor, optional): Phase change profile
polar (str): Polarization mode ('non': scalar, 'polar': vector)
Examples:
>>> # Create DOE with specified height profile
>>> height = torch.ones((1,1,100,100)) * 500e-9 # 500nm height
>>> doe = DOE(height.shape, 2e-6, material, 500e-9, 'cpu', height=height)
>>> # Create DOE with specified phase profile
>>> phase = torch.ones((1,1,100,100)) * np.pi # π phase
>>> doe = DOE(phase.shape, 2e-6, material, 500e-9, 'cpu', phase_change=phase)
"""
super().__init__(dim=dim, pitch=pitch, wvl=wvl, device=device, name="doe", polar=polar)
self.material: 'Material' = material
self.height: Optional[torch.Tensor] = None
# initial DOE is tranparent and induces 0 phase delay
super().set_field_change(torch.ones(dim,device=device)*torch.exp(1*torch.zeros(dim,device=device)))
if (height is None) and (phase_change is not None):
self.set_phase_change(phase_change, sync_height=True)
elif (height is not None) and (phase_change is None):
self.set_height(height, sync_phase=True)
elif (height is None) and (phase_change is None):
phase = torch.zeros(dim, device=device)
self.set_phase_change(phase, sync_height=True)
[docs]
def visualize(self, b: int = 0, c: int = 0) -> None:
"""Visualize the DOE wavefront modulation.
Displays amplitude change, phase change and height profile.
Args:
b (int): Batch index to visualize, defaults to 0
c (int): Channel index to visualize, defaults to 0
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.visualize() # Shows modulation plots
>>> doe.visualize(b=1, c=0) # Shows plots for batch index 1
"""
plt.figure(figsize=(20,5))
plt.subplot(131)
plt.imshow(self.get_amplitude_change().data.cpu()[b,c,...].squeeze(),
cmap='inferno', vmin=0, vmax=1)
plt.title('amplitude change')
plt.colorbar()
plt.subplot(132)
plt.imshow(self.get_phase_change().data.cpu()[b,c,...].squeeze(),
cmap='hsv', vmin=-np.pi, vmax=np.pi)
plt.title('phase change')
plt.colorbar()
plt.subplot(133)
plt.imshow(self.get_height().data.cpu()[b,c,...].squeeze()*1e6,
cmap='hot')
plt.title('height [um]')
plt.colorbar()
plt.suptitle(
f'{self.name}, '
f'({self.dim[2]},{self.dim[3]}), '
f'pitch:{self.pitch/1e-6:.2f}[um], '
f'wvl:{self.wvl/1e-9:.2f}[nm], '
f'device:{self.device}'
)
plt.show()
[docs]
def set_diffraction_grating_1d(self, slit_width: float, minh: float, maxh: float) -> None:
"""Set the wavefront modulation as a 1D diffraction grating.
Create alternating height regions to form a binary phase grating.
Args:
slit_width (float): Width of each slit in meters
minh (float): Minimum height in meters
maxh (float): Maximum height in meters
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_diffraction_grating_1d(10e-6, 0, 500e-9) # 10μm slits
>>> doe.visualize() # Shows 1D grating pattern
"""
slit_width_px = np.round(slit_width / self.pitch)
slit_space_px = slit_width_px
dg = np.zeros((self.dim[2], self.dim[3]))
slit_num_r = self.dim[2] // (2 * slit_width_px)
slit_num_c = self.dim[3] // (2 * slit_width_px)
# Create a copy to avoid modifying in-place
dg_copy = dg.copy()
dg_copy[:] = minh
for i in range(int(slit_num_c)):
minc = int((slit_width_px + slit_space_px) * i)
maxc = int(minc + slit_width_px)
dg_copy[:, minc:maxc] = maxh
pc = torch.tensor(dg_copy.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
self.set_phase_change(1j*pc)
[docs]
def set_diffraction_grating_2d(self, slit_width: float, minh: float, maxh: float) -> None:
"""Set the wavefront modulation as a 2D diffraction grating.
Create a checkerboard pattern of alternating height regions.
Args:
slit_width (float): Width of each slit in meters
minh (float): Minimum height in meters
maxh (float): Maximum height in meters
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_diffraction_grating_2d(10e-6, 0, 500e-9) # 10μm slits
>>> doe.visualize() # Shows 2D grating pattern
"""
slit_width_px = np.round(slit_width / self.pitch)
slit_space_px = slit_width_px
dg = np.zeros((self.dim[2], self.dim[3]))
slit_num_r = self.dim[2] // (2 * slit_width_px)
slit_num_c = self.dim[3] // (2 * slit_width_px)
# Create a copy to avoid modifying in-place
dg_copy = dg.copy()
dg_copy[:] = minh
for i in range(int(slit_num_r)):
for j in range(int(slit_num_c)):
minc = int((slit_width_px + slit_space_px) * j)
maxc = int(minc + slit_width_px)
minr = int((slit_width_px + slit_space_px) * i)
maxr = int(minr + slit_width_px)
dg_copy[minr:maxr, minc:maxc] = maxh
pc = torch.tensor(dg_copy.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
self.set_phase_change(pc)
[docs]
def set_Fresnel_lens(self, focal_length: float, wvl: float, shift_x: float = 0, shift_y: float = 0) -> None:
"""Set the wavefront modulation as a Fresnel lens.
Create a phase profile that focus light to a point.
Args:
focal_length (float): Focal length in meters
wvl (float): Wavelength in meters
shift_x (float): Horizontal shift in meters. Defaults to 0
shift_y (float): Vertical shift in meters. Defaults to 0
Examples:
>>> doe = DOE((1,1,1024,1024), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_Fresnel_lens(0.1, 500e-9) # f=10cm lens
>>> doe.set_Fresnel_lens(0.1, 500e-9, shift_x=50e-6) # Shifted lens
"""
x = np.arange(
-self.dim[3] * self.pitch / 2,
self.dim[3] * self.pitch / 2,
self.pitch
)
x = x[:self.dim[3]]
y = np.arange(
-self.dim[2] * self.pitch / 2,
self.dim[2] * self.pitch / 2,
self.pitch
)
y = y[:self.dim[2]]
xx, yy = np.meshgrid(x, y)
xx = torch.tensor(xx, device=self.device)
yy = torch.tensor(yy, device=self.device)
phase_u = (-2 * np.pi / wvl) * (
torch.sqrt(
(xx - shift_x)**2 +
(yy - shift_y)**2 +
focal_length**2
) - focal_length
)
phase_w = wrap_phase(phase_u)
phase_w = phase_w.unsqueeze(0).unsqueeze(0)
self.set_phase_change(phase_w, sync_height=True)
[docs]
def set_Fresnel_zone_plate_lens(self, focal_length: float, wvl: float, shift_x: float = 0, shift_y: float = 0) -> None:
"""Set binary Fresnel zone plate pattern.
Creates alternating opaque and transparent zones that focus light.
Args:
focal_length (float): Focal length in meters
wvl (float): Wavelength in meters
shift_x (float): Horizontal shift in meters. Defaults to 0
shift_y (float): Vertical shift in meters. Defaults to 0
Examples:
>>> doe = DOE((1,1,1024,1024), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_Fresnel_zone_plate_lens(0.1, 500e-9) # f=10cm lens
>>> doe.set_Fresnel_zone_plate_lens(0.1, 500e-9, shift_x=50e-6) # Shifted lens
"""
x = np.arange(-self.dim[3]/2, self.dim[3]/2) * self.pitch
y = np.arange(-self.dim[2]/2, self.dim[2]/2) * self.pitch
xx, yy = np.meshgrid(x, y, indexing='xy')
# Calculate the radial distance from the center
r_squared = (xx - shift_x)**2 + (yy - shift_y)**2
# Original phase calculation for a thin lens
original_phase = (-2 * np.pi / wvl) * r_squared / (2 * focal_length)
# Fresnel zone plate phase calculation
# Map phase to 0 or pi based on the sign of the cosine of the original phase
fresnel_phase = np.pi * (np.cos(original_phase) >= 0).astype(np.float32)
fresnel_phase = torch.tensor(fresnel_phase, device=self.device)
fresnel_phase = torch.unsqueeze(torch.unsqueeze(fresnel_phase, axis=0), axis=0)
self.set_phase_change(fresnel_phase, sync_height=True)
[docs]
def change_wvl(self, wvl: float) -> None:
"""Change the wavelength and update phase change.
Args:
wvl (float): New wavelength in meters
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.change_wvl(633e-9) # Change to 633nm wavelength
"""
height = self.get_height()
# Store the new wavelength
self.wvl = wvl
# Calculate the phase change for the new wavelength
phase = height2phase(height, self.wvl, self.material.get_RI(self.wvl))
# Create a new field change tensor
new_field_change = torch.exp(phase*1j)
self.set_field_change(new_field_change, sync_height=False)
def set_height(self, height: torch.Tensor, sync_phase: bool = True) -> None:
"""Set the height map of the DOE.
Args:
height (torch.Tensor): Height map in meters
sync_phase (bool): If True, syncs phase profile
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> height = torch.ones((1,1,100,100)) * 500e-9
>>> doe.set_height(height, sync_phase=True)
"""
# Store a copy of the height tensor to avoid potential shared memory with input
self.height = height.clone() if height is not None else None
if sync_phase:
self.sync_phase_with_height()
[docs]
def sync_height_with_phase(self) -> None:
"""Synchronize height profile with current phase profile.
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_phase_change(phase, sync_height=False)
>>> doe.sync_height_with_phase() # Update height to match phase
"""
height = phase2height(self.get_phase_change(), self.wvl, self.material.get_RI(self.wvl))
self.set_height(height, sync_phase=False)
[docs]
def sync_phase_with_height(self) -> None:
"""Synchronize phase profile with current height profile.
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.set_height(height, sync_phase=False)
>>> doe.sync_phase_with_height() # Update phase to match height
"""
phase = height2phase(self.get_height(), self.wvl, self.material.get_RI(self.wvl))
self.set_phase_change(phase, sync_height=False)
[docs]
def resize(self, target_pitch: float) -> None:
"""Resize DOE with a new pixel pitch.
Resize field from which DOE height is recomputed.
Args:
target_pitch (float): New pixel pitch in meters
Examples:
>>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
>>> doe.resize(1e-6) # Change pitch to 1μm
"""
super().resize(target_pitch) # this changes the field change
self.sync_height_with_phase()
[docs]
def get_height(self) -> torch.Tensor:
    """Return the stored height map of the DOE (in meters).

    The stored object is returned as-is; no copy is made.

    Examples:
        >>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
        >>> height = doe.get_height()  # Get current height profile
    """
    height_map = self.height
    return height_map
[docs]
def set_phase_change(self, phase_change: torch.Tensor, sync_height: bool = True) -> None:
    """Store a new phase profile and optionally re-derive the height map.

    The phase is stored through the base-class setter. When ``sync_height``
    is True, ``sync_height_with_phase`` then recomputes the height from it.

    Args:
        phase_change (torch.Tensor): Phase change profile
        sync_height (bool): If True, syncs height profile
    Examples:
        >>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
        >>> phase = torch.ones((1,1,100,100)) * np.pi
        >>> doe.set_phase_change(phase, sync_height=True)
    """
    super().set_phase_change(phase_change)
    if not sync_height:
        return
    self.sync_height_with_phase()
[docs]
def set_field_change(self, field_change: torch.Tensor, sync_height: bool = True) -> None:
    """Store a new complex field change and optionally re-derive the height map.

    The field is stored through the base-class setter. When ``sync_height``
    is True, ``sync_height_with_phase`` then recomputes the height from the
    resulting phase.

    Args:
        field_change (torch.Tensor): Complex field change tensor
        sync_height (bool): If True, syncs height profile
    Examples:
        >>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
        >>> field = torch.exp(1j * torch.ones((1,1,100,100)))
        >>> doe.set_field_change(field, sync_height=True)
    """
    super().set_field_change(field_change)
    if not sync_height:
        return
    self.sync_height_with_phase()
[docs]
def set_height(self, height: torch.Tensor, sync_phase: bool = True) -> None:
    """Store the DOE height map (in meters) and optionally re-derive the phase.

    The tensor is stored by reference (no copy), so ``get_height()`` returns
    the very object passed in.

    NOTE(review): an earlier ``set_height`` in this class clones its input;
    this later definition overrides it. Consolidate the two.

    Args:
        height (torch.Tensor): Height map in meters
        sync_phase (bool): If True, syncs phase profile
    Examples:
        >>> doe = DOE((1,1,100,100), 2e-6, material, 500e-9, 'cpu')
        >>> height = torch.ones((1,1,100,100)) * 500e-9
        >>> doe.set_height(height, sync_phase=True)
    """
    self.height = height
    if not sync_phase:
        return
    self.sync_phase_with_height()
[docs]
class SLM(OpticalElement):
    """Spatial Light Modulator (SLM) optical element.

    Phase and amplitude profiles are stored through the ``OpticalElement``
    setters. Each setter can optionally update the operating wavelength; when
    the wavelength is omitted, the current ``self.wvl`` is kept.
    """

    def __init__(self, dim: tuple, pitch: float, wvl: float, device: str, polar: str = 'non'):
        """Spatial Light Modulator (SLM) optical element.

        Args:
            dim (tuple): Field dimensions (B, 1, R, C) for batch, channels, rows, cols
            pitch (float): Pixel pitch in meters
            wvl (float): Wavelength in meters
            device (str): Device for computation ('cpu', 'cuda:0', etc.)
            polar (str): Polarization mode ('non' or 'polar')
        Examples:
            >>> slm = SLM(dim=(1,1,1024,1024), pitch=6.4e-6, wvl=633e-9, device='cuda:0')
        """
        super().__init__(dim, pitch, wvl, device=device, name="SLM", polar=polar)

    def set_lens(self, focal_length: float, shift_x: float = 0, shift_y: float = 0) -> None:
        """Set phase profile to implement a thin lens.

        The quadratic phase
        ``(2*pi/wvl) * ((x - shift_x)**2 + (y - shift_y)**2) / (2*focal_length)``
        is wrapped with ``wrap_phase(..., stay_positive=False)``. It is stored
        as the phase change at the SLM's current wavelength.

        Args:
            focal_length (float): Focal length in meters
            shift_x (float): Lateral shift in x (column) direction in meters
            shift_y (float): Lateral shift in y (row) direction in meters
        Examples:
            >>> slm.set_lens(focal_length=0.5, shift_x=100e-6)  # 500mm focal length, 100μm x-shift
        """
        rows, cols = self.dim[2], self.dim[3]
        # Build coordinates from integer indices. np.arange with a float step can
        # return one sample too many through rounding, breaking the (R, C) shape.
        x = (np.arange(cols) - cols / 2) * self.pitch
        y = (np.arange(rows) - rows / 2) * self.pitch
        xx, yy = np.meshgrid(x, y)
        phase_u = (2 * np.pi / self.wvl) * ((xx - shift_x) ** 2 + (yy - shift_y) ** 2) / (2 * focal_length)
        phase_u = torch.tensor(phase_u.astype(np.float32), device=self.device).unsqueeze(0).unsqueeze(0)
        phase_w = wrap_phase(phase_u, stay_positive=False)
        # The lens was computed for the current wavelength, so keep it.
        self.set_phase_change(phase_w, self.wvl)

    def set_amplitude_change(self, amplitude: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set amplitude modulation profile of the SLM.

        Args:
            amplitude (torch.Tensor): Amplitude modulation profile [B, 1, R, C]
            wvl (float, optional): Operating wavelength in meters. When omitted,
                the current ``self.wvl`` is kept.
        Examples:
            >>> amp = torch.ones((1,1,1024,1024)) * 0.8  # 80% transmission
            >>> slm.set_amplitude_change(amp, wvl=633e-9)
        """
        if wvl is not None:
            self.wvl = wvl
        super().set_amplitude_change(amplitude)

    def set_phase_change(self, phase_change: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set phase modulation profile of the SLM.

        ``wvl`` is optional, so callers that use the single-argument
        ``set_phase_change(phase)`` form (as ``set_lens`` does) keep working.

        Args:
            phase_change (torch.Tensor): Phase modulation profile [B, 1, R, C] in radians
            wvl (float, optional): Operating wavelength in meters. When omitted,
                the current ``self.wvl`` is kept.
        Examples:
            >>> phase = torch.ones((1,1,1024,1024)) * np.pi  # π phase shift
            >>> slm.set_phase_change(phase, wvl=633e-9)
        """
        if wvl is not None:
            self.wvl = wvl
        super().set_phase_change(phase_change)
[docs]
class PolarizedSLM(OpticalElement):
    """SLM that modulates the X and Y polarization components independently.

    Amplitude and phase profiles carry a trailing axis of size 2 (index 0 is
    the X component, index 1 the Y component). ``__init__`` allocates them
    with shape ``(B, 1, R, C, 2)``.
    """

    def __init__(self, dim: tuple, pitch: float, wvl: float, device: str):
        """SLM which can control phase & amplitude of each polarization component.

        Args:
            dim (tuple): Field dimensions (B, 1, R, C) for batch, channels, rows, cols
            pitch (float): Pixel pitch in meters
            wvl (float): Wavelength in meters
            device (str): Device for computation ('cpu', 'cuda:0', etc.)
        Examples:
            >>> slm = PolarizedSLM(dim=(1,1,1024,1024), pitch=6.4e-6, wvl=633e-9, device='cuda:0')
        """
        # NOTE(review): the element name "Metasurface" (shown in visualize()
        # titles) looks copied from another element; confirm before renaming.
        super().__init__(dim, pitch, wvl, device=device, name="Metasurface", polar='polar')
        # Unit transmission and zero phase for both polarization components.
        self.amplitude_change = torch.ones((dim[0], 1, dim[2], dim[3], 2), device=self.device)
        self.phase_change = torch.zeros((dim[0], 1, dim[2], dim[3], 2), device=self.device)

    def _update_wvl(self, wvl: Optional[float]) -> None:
        """Record ``wvl`` as the operating wavelength unless it is None."""
        if wvl is not None:
            self.wvl = wvl

    @staticmethod
    def _replace_component(profile: torch.Tensor, index: int, values: torch.Tensor) -> torch.Tensor:
        """Return a copy of a (B, 1, R, C, 2) profile with component ``index`` replaced."""
        # Work on a clone rather than modifying the stored tensor in place.
        updated = profile.clone()
        updated[:, :, :, :, index] = values
        return updated

    def set_amplitude_change(self, amplitude: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set amplitude change for both polarization components.

        Args:
            amplitude (torch.Tensor): Amplitude change [B, 1, R, C, 2] in polar representation
            wvl (float, optional): Wavelength in meters; the current one is kept when omitted
        Examples:
            >>> amp = torch.ones((1,1,1024,1024,2)) * 0.8  # 80% transmission for both polarizations
            >>> slm.set_amplitude_change(amp, wvl=633e-9)
        """
        self._update_wvl(wvl)
        super().set_amplitude_change(amplitude)

    def set_phase_change(self, phase_change: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set phase change for both polarization components.

        Args:
            phase_change (torch.Tensor): Phase change [B, 1, R, C, 2] in polar representation
            wvl (float, optional): Wavelength in meters; the current one is kept when omitted
        Examples:
            >>> phase = torch.ones((1,1,1024,1024,2)) * np.pi  # π phase shift for both polarizations
            >>> slm.set_phase_change(phase, wvl=633e-9)
        """
        self._update_wvl(wvl)
        super().set_phase_change(phase_change)

    def set_amplitudeX_change(self, amplitude: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set amplitude change for X polarization component.

        Args:
            amplitude (torch.Tensor): Amplitude change [B, 1, R, C] for X component
            wvl (float, optional): Wavelength in meters; the current one is kept when omitted
        Examples:
            >>> ampX = torch.ones((1,1,1024,1024)) * 0.8  # 80% transmission for X polarization
            >>> slm.set_amplitudeX_change(ampX, wvl=633e-9)
        """
        self._update_wvl(wvl)
        super().set_amplitude_change(self._replace_component(self.get_amplitude_change(), 0, amplitude))

    def set_amplitudeY_change(self, amplitude: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set amplitude change for Y polarization component.

        Args:
            amplitude (torch.Tensor): Amplitude change [B, 1, R, C] for Y component
            wvl (float, optional): Wavelength in meters; the current one is kept when omitted
        Examples:
            >>> ampY = torch.ones((1,1,1024,1024)) * 0.6  # 60% transmission for Y polarization
            >>> slm.set_amplitudeY_change(ampY, wvl=633e-9)
        """
        self._update_wvl(wvl)
        super().set_amplitude_change(self._replace_component(self.get_amplitude_change(), 1, amplitude))

    def set_phaseX_change(self, phase_change: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set phase change for X polarization component.

        Args:
            phase_change (torch.Tensor): Phase change [B, 1, R, C] for X component
            wvl (float, optional): Wavelength in meters; the current one is kept when omitted
        Examples:
            >>> phaseX = torch.ones((1,1,1024,1024)) * np.pi/2  # π/2 phase shift for X polarization
            >>> slm.set_phaseX_change(phaseX, wvl=633e-9)
        """
        self._update_wvl(wvl)
        super().set_phase_change(self._replace_component(self.get_phase_change(), 0, phase_change))

    def set_phaseY_change(self, phase_change: torch.Tensor, wvl: Optional[float] = None) -> None:
        """Set phase change for Y polarization component.

        Args:
            phase_change (torch.Tensor): Phase change [B, 1, R, C] for Y component
            wvl (float, optional): Wavelength in meters; the current one is kept when omitted
        Examples:
            >>> phaseY = torch.ones((1,1,1024,1024)) * np.pi  # π phase shift for Y polarization
            >>> slm.set_phaseY_change(phaseY, wvl=633e-9)
        """
        self._update_wvl(wvl)
        super().set_phase_change(self._replace_component(self.get_phase_change(), 1, phase_change))

    def get_phase_changeX(self) -> torch.Tensor:
        """Return phase change [B, 1, R, C] for the X polarization component."""
        return self.get_phase_change()[:, :, :, :, 0]

    def get_phase_changeY(self) -> torch.Tensor:
        """Return phase change [B, 1, R, C] for the Y polarization component."""
        return self.get_phase_change()[:, :, :, :, 1]

    def get_amplitude_changeX(self) -> torch.Tensor:
        """Return amplitude change [B, 1, R, C] for the X polarization component."""
        return self.get_amplitude_change()[:, :, :, :, 0]

    def get_amplitude_changeY(self) -> torch.Tensor:
        """Return amplitude change [B, 1, R, C] for the Y polarization component."""
        return self.get_amplitude_change()[:, :, :, :, 1]

    def forward(self, light: 'Light', interp_mode: str = 'nearest') -> 'Light':
        """Apply polarization-dependent modulation to input light.

        The light and the element are first brought to a common pitch (the
        finer of the two) and a common size (the smaller one is zero-padded).
        Each polarization component then receives its own phase shift
        (wrapped to [-pi, pi)) and amplitude scaling.

        Args:
            light (Light): Input light field
            interp_mode (str): Interpolation mode for resizing ('nearest', 'bilinear', etc.)
        Returns:
            Light: Modulated light field (the same object, updated)
        Raises:
            ValueError: If the light's wavelength differs from the element's.
        Examples:
            >>> modulated_light = slm.forward(input_light)  # Apply polarization modulation
            >>> modulated_light = slm.forward(input_light, interp_mode='bilinear')  # Use bilinear interpolation
        """
        if light.wvl != self.wvl:
            raise ValueError(f'Wavelength mismatch: light wavelength {light.wvl} != element wavelength {self.wvl}')
        if light.pitch > self.pitch:
            light.resize(self.pitch, interp_mode)
            light.set_pitch(self.pitch)
        elif light.pitch < self.pitch:
            self.resize(light.pitch, interp_mode)
            self.set_pitch(light.pitch)
        # Padding tuples are (top, bottom, left, right): rows first, then columns.
        # NOTE(review): light.pad receives the same ordering; confirm Light.pad agrees.
        r1 = np.abs((light.dim[2] - self.dim[2])//2)
        r2 = np.abs(light.dim[2] - self.dim[2]) - r1
        pad_width = (r1, r2, 0, 0)
        if light.dim[2] > self.dim[2]:
            self.pad(pad_width)
        elif light.dim[2] < self.dim[2]:
            light.pad(pad_width)
        c1 = np.abs((light.dim[3] - self.dim[3])//2)
        c2 = np.abs(light.dim[3] - self.dim[3]) - c1
        pad_width = (0, 0, c1, c2)
        if light.dim[3] > self.dim[3]:
            self.pad(pad_width)
        elif light.dim[3] < self.dim[3]:
            light.pad(pad_width)
        # Ensure compatible shapes for broadcasting
        light_phase = light.get_phase()
        phase_change = self.get_phase_change()
        if light_phase.shape != phase_change.shape:
            if light_phase.dim() == 4 and phase_change.dim() == 5:
                # Add polarization dimension if missing
                light_phase = light_phase.unsqueeze(-1)
            if light_phase.shape[-1] == 1:
                # Duplicate the phase for both polarizations
                light_phase = light_phase.expand(-1, -1, -1, -1, 2)
        # Add the modulation and wrap the result back into [-pi, pi).
        phase = (light_phase + phase_change + np.pi) % (np.pi*2) - np.pi
        light.set_phaseX(phase[..., 0])
        light.set_phaseY(phase[..., 1])
        light.set_amplitudeX(light.get_amplitudeX() * self.get_amplitude_change()[..., 0])
        light.set_amplitudeY(light.get_amplitudeY() * self.get_amplitude_change()[..., 1])
        return light

    def pad(self, pad_width: tuple, padval: int = 0) -> None:
        """Zero-pad the amplitude and phase profiles spatially.

        Args:
            pad_width (tuple): Padding in pixels as (top, bottom, left, right).
                The first pair pads rows (``dim[2]``) and the second pair pads
                columns (``dim[3]``), matching the tuples built in ``forward``.
            padval (int): Padding value (only 0 supported)
        Raises:
            NotImplementedError: If ``padval`` is not 0.
        Examples:
            >>> slm.pad((16,16,16,16))  # Add 16 pixels padding on all sides
        """
        if padval != 0:
            raise NotImplementedError('only zero padding supported')
        # Cast to plain ints (callers may pass numpy integers).
        top, bottom, left, right = (int(p) for p in pad_width)
        # F.pad pairs apply from the LAST dimension backwards on (B, 1, R, C, 2):
        # polarization axis (none), columns (left/right), rows (top/bottom).
        spatial_pad = (0, 0, left, right, top, bottom)
        self.amplitude_change = torch.nn.functional.pad(self.get_amplitude_change(), spatial_pad)
        self.phase_change = torch.nn.functional.pad(self.get_phase_change(), spatial_pad)
        # Keep the dimension bookkeeping in sync with the padded tensors.
        new_dim = list(self.dim)
        new_dim[2] += top + bottom
        new_dim[3] += left + right
        self.dim = tuple(new_dim)

    def visualize(self, b: int = 0) -> None:
        """Visualize amplitude and phase modulation for both polarizations.

        Args:
            b (int): Batch index to visualize, default 0
        Examples:
            >>> slm.visualize()  # Visualize first batch
            >>> slm.visualize(b=1)  # Visualize second batch
        """
        plt.figure(figsize=(13,8))
        plt.subplot(221)
        plt.imshow(self.get_amplitude_changeX().data.cpu()[b,...].squeeze(), cmap='inferno')
        plt.title('amplitude change X')
        plt.colorbar()
        plt.subplot(222)
        plt.imshow(self.get_phase_changeX().data.cpu()[b,...].squeeze(), cmap='hsv')
        plt.title('phase change X')
        plt.colorbar()
        plt.subplot(223)
        plt.imshow(self.get_amplitude_changeY().data.cpu()[b,...].squeeze(), cmap='inferno')
        plt.title('amplitude change Y')
        plt.colorbar()
        plt.subplot(224)
        plt.imshow(self.get_phase_changeY().data.cpu()[b,...].squeeze(), cmap='hsv')
        plt.title('phase change Y')
        plt.colorbar()
        plt.suptitle(
            f'{self.name}, '
            f'({self.dim[2]},{self.dim[3]}), '
            f'pitch:{self.pitch/1e-6:.2f}[um], '
            f'wvl:{self.wvl/1e-9:.2f}[nm], '
            f'device:{self.device}'
        )
        plt.show()
[docs]
class Aperture(OpticalElement):
    """Aperture optical element for amplitude modulation.

    Implements a square or circular aperture that modulates light amplitude.
    Supports both polarized and non-polarized light.
    """

    def __init__(self, dim: tuple, pitch: float, aperture_diameter: float, aperture_shape: str, wvl: float, device: str = 'cpu', polar: str = 'non'):
        """Create aperture optical element instance.

        Args:
            dim (tuple): Field dimensions (B, 1, R, C) for batch, channels, rows, cols
            pitch (float): Pixel pitch in meters
            aperture_diameter (float): Diameter of aperture in meters
            aperture_shape (str): Shape of aperture ('square' or 'circle')
            wvl (float): Wavelength in meters
            device (str): Device for computation ('cpu', 'cuda:0', etc.)
            polar (str): Polarization mode ('non', 'x', 'y', 'xy')
        Raises:
            NotImplementedError: If ``aperture_shape`` is not 'square' or 'circle'.
        Examples:
            >>> aperture = Aperture(dim=(1,1,1024,1024), pitch=6.4e-6,
            ...                     aperture_diameter=1e-3, aperture_shape='circle',
            ...                     wvl=633e-9)
        """
        super().__init__(dim, pitch, wvl, device=device, name="aperture", polar=polar)
        self.aperture_diameter = aperture_diameter
        self.aperture_shape = aperture_shape
        self.amplitude_change = torch.zeros((self.dim[2], self.dim[3]), device=device)
        if self.aperture_shape == 'square':
            self.set_square()
        elif self.aperture_shape == 'circle':
            self.set_circle()
        else:
            # Must raise: any non-None value returned from __init__ is a TypeError.
            raise NotImplementedError(
                f"unsupported aperture_shape {aperture_shape!r}; expected 'square' or 'circle'"
            )

    def _aperture_grid(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return float32 pixel-index grids centered on the array.

        ``x`` varies along the rows (``dim[2]``) and ``y`` along the columns (``dim[3]``).
        """
        rows, cols = self.dim[2], self.dim[3]
        # (-n)//2 .. n//2 (exclusive) yields exactly n indices for even and odd n.
        x, y = np.mgrid[-rows // 2:rows // 2, -cols // 2:cols // 2].astype(np.float32)
        return x, y

    def _apply_aperture_mask(self, r: np.ndarray, max_val: float) -> None:
        """Store a binary transmission mask: 1 where ``r <= max_val``, 1e-20 elsewhere."""
        r = np.expand_dims(np.expand_dims(r, axis=0), axis=0)  # -> (1, 1, R, C)
        amp = (r <= max_val).astype(np.float32)
        # Blocked pixels get a tiny non-zero amplitude instead of exactly 0
        # (presumably so the field's phase stays defined; TODO confirm).
        amp[amp == 0] = 1e-20
        self.set_field_change(torch.tensor(amp, device=self.device))

    def set_square(self) -> None:
        """Set square aperture amplitude modulation.

        Creates a square mask centered on the optical axis, with side length
        ``self.aperture_diameter``.

        Examples:
            >>> aperture.set_square()
        """
        self.aperture_shape = 'square'
        x, y = self._aperture_grid()
        # Max-norm distance from the center: its level sets are squares.
        r = self.pitch * np.asarray([abs(x), abs(y)]).max(axis=0)
        self._apply_aperture_mask(r, self.aperture_diameter / 2)

    def set_circle(self, cx: float = 0, cy: float = 0, dia: float = None) -> None:
        """Set circular aperture amplitude modulation.

        Creates a circular mask with optional center offset and diameter.

        Args:
            cx (float): Center offset in pixels along the row axis (the ``x``
                grid, which spans ``dim[2]``)
            cy (float): Center offset in pixels along the column axis (the
                ``y`` grid, which spans ``dim[3]``)
            dia (float, optional): Circle diameter in meters; when given it
                also replaces ``self.aperture_diameter``
        Examples:
            >>> aperture.set_circle()  # Centered circle
            >>> aperture.set_circle(cx=10, cy=-10, dia=2e-3)  # Offset circle
        """
        x, y = self._aperture_grid()
        # A sum of squares is never negative, so no clamping is needed before sqrt.
        r = self.pitch * np.sqrt((x - cx) ** 2 + (y - cy) ** 2)
        if dia is not None:
            self.aperture_diameter = dia
        self.aperture_shape = 'circle'
        self._apply_aperture_mask(r, self.aperture_diameter / 2)
[docs]
def quantize(x: Union[torch.Tensor, np.ndarray], levels: int, vmin: float = None, vmax: float = None, include_vmax: bool = True) -> Union[torch.Tensor, np.ndarray]:
    """Quantize a floating point array into discrete levels.

    Args:
        x (torch.Tensor or np.ndarray): Input array to quantize
        levels (int): Number of quantization levels. ``0`` returns ``x`` unchanged.
        vmin (float, optional): Lower bound of the range. Only used when
            ``include_vmax`` is False; defaults to ``x.min()``.
        vmax (float, optional): Upper bound of the range. Only used when
            ``include_vmax`` is False; defaults to ``x.max()``.
        include_vmax (bool): Selects the level spacing.
            True: spacing of (max - min) / levels starting at ``x.min()``. The
                top level is ``x.min() + (levels - 1) * spacing``, so values near
                the maximum fall into the last bin. ``vmin``/``vmax`` are ignored.
            False: spacing of (vmax - vmin) / (levels - 1), so both ``vmin`` and
                ``vmax`` are attainable levels. ``levels == 1`` maps everything
                to ``vmin``.
    Returns:
        torch.Tensor or np.ndarray: Quantized array, a new object of the same
        kind as ``x`` (``x`` itself when ``levels == 0``).
    Raises:
        TypeError: If ``x`` is neither a NumPy array nor a torch tensor.
    Examples:
        >>> x = torch.randn(100)
        >>> x_quant = quantize(x, levels=8)
        >>> x_quant = quantize(x, levels=16, vmin=-1, vmax=1, include_vmax=False)
    """
    if levels == 0:
        return x
    is_numpy = isinstance(x, np.ndarray)
    if not is_numpy and not isinstance(x, torch.Tensor):
        raise TypeError(f"quantize expects a numpy.ndarray or torch.Tensor, got {type(x).__name__}")

    if not include_vmax:
        if vmin is None:
            vmin = x.min()
        if vmax is None:
            vmax = x.max()
        # The tiny epsilon avoids 0/0 when vmin == vmax.
        normalized = (x - vmin) / (vmax - vmin + 1e-16)
        if levels == 1:
            # A single level: everything maps to vmin (avoids dividing by levels - 1 == 0).
            levelized = normalized * 0
        else:
            steps = np.floor(normalized * levels) if is_numpy else (normalized * levels).floor()
            levelized = steps / (levels - 1)
        result = levelized * (vmax - vmin) + vmin
        # np.clip / torch.clamp both return new objects, so x is never modified.
        if is_numpy:
            return np.clip(result, vmin, vmax)
        return torch.clamp(result, min=vmin, max=vmax)

    x_min = x.min()
    space = (x.max() - x_min) / levels
    if space == 0:
        # Constant input: every value already sits on the single level.
        return x.copy() if is_numpy else x.clone()
    top = x_min + space * (levels - 1)
    if is_numpy:
        result = np.floor((x - x_min) / space) * space + x_min
        return np.clip(result, x_min, top)
    result = ((x - x_min) / space).floor() * space + x_min
    return torch.clamp(result, min=x_min, max=top)