Sampling
DRLX provides multiple samplers because different training methods often require their own sampling procedure. A default sampler is also included for inference.
Sampler
- class drlx.sampling.Sampler(config: SamplerConfig = SamplerConfig(guidance_scale=5.0, guidance_rescale=None, num_inference_steps=50, eta=1, postprocess=False, img_size=512))
Bases: object
Generic class for sampling generations using a denoiser. Assumes LDMUnet.
- cfg_rescale(pred: Tensor[Tensor])
Applies classifier-free guidance to the prediction and rescales it if cfg_rescaling is enabled
- Parameters:
pred – Batched prediction, assumed to be repeated so that the first half consists of unconditioned (empty token) predictions and the second half of conditioned predictions
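The batch layout described above can be illustrated with a short NumPy sketch. It is not the DRLX implementation: the function name `cfg_rescale_sketch` is hypothetical, and the rescaling step (matching the per-sample standard deviation of the conditioned prediction, then blending) is an assumption about what "rescales" means here.

```python
import numpy as np

def cfg_rescale_sketch(pred, guidance_scale=5.0, guidance_rescale=None):
    """Hypothetical stand-in for Sampler.cfg_rescale, not the DRLX source."""
    # First half of the batch: unconditioned; second half: conditioned.
    uncond, cond = np.split(pred, 2, axis=0)
    # Classifier-free guidance: move from uncond towards cond by the scale.
    guided = uncond + guidance_scale * (cond - uncond)
    if guidance_rescale is None:
        return guided
    # Assumed rescaling: match the per-sample std of the conditioned
    # prediction, then blend the rescaled and plain guided predictions.
    axes = tuple(range(1, guided.ndim))
    std_cond = cond.std(axis=axes, keepdims=True)
    std_guided = guided.std(axis=axes, keepdims=True)
    rescaled = guided * (std_cond / std_guided)
    return guidance_rescale * rescaled + (1.0 - guidance_rescale) * guided
```

Note that the output batch is half the size of the input batch, since each unconditioned/conditioned pair collapses into one guided prediction.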
- sample(prompts: Iterable[str], denoiser, device=None, show_progress: bool = False, accelerator=None)
Samples latents given some prompts and a denoiser
- Parameters:
prompts – Text prompts for image generation (to condition denoiser)
denoiser – Model to use for denoising
device – Device on which to perform model inference
show_progress – Whether to display a progress bar for the sampling steps
accelerator – Accelerator object for accelerated training (optional)
- Returns:
Latents, unless the postprocess flag is set to true in the config, in which case VAE-decoded latents (i.e. images) are returned
DDPOSampler
- class drlx.sampling.DDPOSampler(config: SamplerConfig = SamplerConfig(guidance_scale=5.0, guidance_rescale=None, num_inference_steps=50, eta=1, postprocess=False, img_size=512))
Bases: Sampler
- compute_loss(prompts, denoiser, device, show_progress: bool = False, advantages=None, old_preds=None, old_log_probs=None, method_config: DDPOConfig | None = None, accelerator=None)
Computes the loss for the DDPO sampling process. This function is used to train the denoiser model.
- Parameters:
prompts – Text prompts to condition the denoiser
denoiser – Denoising model
device – Device to perform model inference on
show_progress – Whether to display a progress bar for the sampling steps
advantages – Normalized advantages obtained from reward computation
old_preds – Previous predictions from past model
old_log_probs – Log probabilities of predictions from past model
method_config – Configuration for the DDPO method
accelerator – Accelerator object for accelerated training (optional)
- Returns:
Total loss computed over the sampling process
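For orientation, DDPO is commonly trained with a PPO-style clipped importance-sampling objective that combines the advantages with the ratio between new and old log probabilities. The sketch below shows that objective for a flat list of per-step values. It is an assumption about the general technique, not the body of compute_loss; the name `clipped_policy_loss` and the `clip_ratio` parameter are hypothetical.

```python
import math

def clipped_policy_loss(log_probs, old_log_probs, advantages, clip_ratio=0.2):
    """Mean PPO-style clipped loss; clip_ratio is an illustrative default."""
    losses = []
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        # Importance ratio between the current and the past model.
        ratio = math.exp(lp - old_lp)
        # Keep the ratio within [1 - clip_ratio, 1 + clip_ratio].
        clipped = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
        # Pessimistic (larger) of the unclipped and clipped losses.
        losses.append(max(-adv * ratio, -adv * clipped))
    return sum(losses) / len(losses)
```

Clipping limits how far a single update can move the denoiser away from the model that produced old_log_probs.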
- sample(prompts, denoiser, device, show_progress: bool = False, accelerator=None) → Iterable[Tensor]
DDPO sampling is analogous to playing a game in an RL environment. Given a denoiser and prompts, this function returns the latents together with the log probabilities of the predictions and ALL predictions (i.e. one per timestep).
- Parameters:
prompts – Text prompts to condition denoiser
denoiser – Denoising model
device – Device to do inference on
show_progress – Whether to display a progress bar for the sampling steps
accelerator – Accelerator object for accelerated training (optional)
- Returns:
Triple of the final denoised latents, all model predictions, and the log probabilities of each prediction
- step_and_logprobs(scheduler, pred: Tensor[Tensor], t: float, latents: Tensor[Tensor], old_pred: Tensor[Tensor] | None = None)
Steps backwards using the scheduler. Treats the prediction as an action sampled from a normal distribution and returns the average log probability of that prediction. Can also be used to find the probability of the current model giving some other prediction (old_pred).
- Parameters:
scheduler – Scheduler being used for diffusion process
pred – Denoiser prediction with CFG and scaling accounted for
t – Timestep in diffusion process
latents – Latent vector given as input to denoiser
old_pred – Alternate prediction. If given, computes log probability of current model predicting alternative output.
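The "average log probability" of an action under a normal distribution can be sketched in plain Python as follows. This is only an illustration of the quantity, assuming independent dimensions with a shared standard deviation; the function name `mean_gaussian_log_prob` is hypothetical, and how the scheduler derives the mean and standard deviation at timestep t is not shown.

```python
import math

def mean_gaussian_log_prob(sample, mean, std):
    """Average log-density of `sample` under independent N(mean_i, std**2)."""
    # Log of the normalising constant, shared by every dimension.
    log_norm = math.log(std) + 0.5 * math.log(2.0 * math.pi)
    terms = [
        -((x - m) ** 2) / (2.0 * std ** 2) - log_norm
        for x, m in zip(sample, mean)
    ]
    # Average rather than sum, so the value does not grow with dimension.
    return sum(terms) / len(terms)
```

Passing an alternative sample while keeping the same mean and standard deviation gives the log probability of the current model producing that alternative, which is the role old_pred plays above.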