Sampling

DRLX provides multiple samplers, since different training methods often require a specific sampling procedure. A default sampler is also included for inference.

Sampler

class drlx.sampling.Sampler(config: SamplerConfig = SamplerConfig(guidance_scale=5.0, guidance_rescale=None, num_inference_steps=50, eta=1, postprocess=False, img_size=512))

Bases: object

Generic class for sampling generations using a denoiser. Assumes the denoiser is an LDMUnet.

cfg_rescale(pred: Tensor[Tensor])

Applies classifier-free guidance (CFG) to the prediction and rescales the result if CFG rescaling is enabled.

Parameters:

pred – Assumed to be a batched, repeated prediction: the first half holds the unconditioned (empty token) predictions and the second half holds the conditioned predictions
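As a rough illustration of the guidance step described above, the following pure-Python sketch splits a batch into its unconditioned and conditioned halves, applies the usual classifier-free guidance formula, and optionally rescales the result toward the standard deviation of the conditioned prediction. The function name cfg_combine, the flat-list representation of each prediction, and the exact rescaling formula are assumptions made for this example; the DRLX implementation operates on tensors and may differ in detail.

```python
from statistics import pstdev


def cfg_combine(pred, guidance_scale=5.0, guidance_rescale=None):
    """Hypothetical helper, not the DRLX implementation.

    pred: list of per-sample predictions (each a flat list of floats).
    The first half is unconditioned, the second half is conditioned.
    """
    half = len(pred) // 2
    uncond, cond = pred[:half], pred[half:]

    out = []
    for u, c in zip(uncond, cond):
        # Classifier-free guidance: move from the unconditioned prediction
        # toward (and past) the conditioned one.
        guided = [ui + guidance_scale * (ci - ui) for ui, ci in zip(u, c)]

        if guidance_rescale is not None:
            # Assumed rescaling: match the standard deviation of the
            # conditioned prediction, then blend with the unrescaled result.
            # Requires a non-constant guided prediction (non-zero std).
            factor = pstdev(c) / pstdev(guided)
            guided = [
                guidance_rescale * g * factor + (1.0 - guidance_rescale) * g
                for g in guided
            ]
        out.append(guided)
    return out
```

With guidance_rescale left as None, the function reduces to plain classifier-free guidance; a value of 1.0 fully matches the standard deviation of the conditioned prediction.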

sample(prompts: Iterable[str], denoiser, device=None, show_progress: bool = False, accelerator=None)

Samples latents given some prompts and a denoiser

Parameters:
  • prompts – Text prompts for image generation (to condition denoiser)

  • denoiser – Model to use for denoising

  • device – Device on which to perform model inference

  • show_progress – Whether to display a progress bar for the sampling steps

  • accelerator – Accelerator object for accelerated training (optional)

Returns:

Latents, unless the postprocess flag is set to true in the config, in which case the VAE-decoded latents (i.e. images) are returned

DDPOSampler

class drlx.sampling.DDPOSampler(config: SamplerConfig = SamplerConfig(guidance_scale=5.0, guidance_rescale=None, num_inference_steps=50, eta=1, postprocess=False, img_size=512))

Bases: Sampler

compute_loss(prompts, denoiser, device, show_progress: bool = False, advantages=None, old_preds=None, old_log_probs=None, method_config: DDPOConfig | None = None, accelerator=None)

Computes the loss for the DDPO sampling process. This function is used to train the denoiser model.

Parameters:
  • prompts – Text prompts to condition the denoiser

  • denoiser – Denoising model

  • device – Device to perform model inference on

  • show_progress – Whether to display a progress bar for the sampling steps

  • advantages – Normalized advantages obtained from reward computation

  • old_preds – Previous predictions from past model

  • old_log_probs – Log probabilities of predictions from past model

  • method_config – Configuration for the DDPO method

  • accelerator – Accelerator object for accelerated training (optional)

Returns:

Total loss computed over the sampling process
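To make the roles of advantages and old_log_probs more concrete, here is a minimal sketch of a PPO-style clipped policy loss, which is a common choice for DDPO-type training. It is an assumption that compute_loss uses this form; the function name clipped_policy_loss and the clip_ratio parameter are hypothetical, and the sketch averages over a flat list of scalars rather than accumulating over timesteps.

```python
import math


def clipped_policy_loss(log_probs, old_log_probs, advantages, clip_ratio=0.2):
    """Hypothetical PPO-style clipped loss over scalar log probabilities."""
    losses = []
    for lp, old_lp, adv in zip(log_probs, old_log_probs, advantages):
        # Importance ratio between the current model and the past model.
        ratio = math.exp(lp - old_lp)
        # Keep the ratio within [1 - clip_ratio, 1 + clip_ratio].
        clipped = min(max(ratio, 1.0 - clip_ratio), 1.0 + clip_ratio)
        # Take the more pessimistic (larger) of the two candidate losses.
        losses.append(max(-adv * ratio, -adv * clipped))
    return sum(losses) / len(losses)
```

When the current and old log probabilities agree, the ratio is 1 and the loss is simply the negated mean advantage; clipping only matters once the current model has drifted from the one that produced the samples.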

sample(prompts, denoiser, device, show_progress: bool = False, accelerator=None) → Iterable[Tensor]

DDPO sampling is analogous to playing a game in an RL environment. This function samples given a denoiser and prompts, but in addition to the latents it also returns the log probabilities of the predictions as well as ALL predictions (i.e. one at each timestep)

Parameters:
  • prompts – Text prompts to condition denoiser

  • denoiser – Denoising model

  • device – Device to do inference on

  • show_progress – Whether to display a progress bar for the sampling steps

  • accelerator – Accelerator object for accelerated training (optional)

Returns:

Triple of the final denoised latents, all model predictions, and the log probabilities of each prediction

step_and_logprobs(scheduler, pred: Tensor[Tensor], t: float, latents: Tensor[Tensor], old_pred: Tensor[Tensor] | None = None)

Steps backwards using the scheduler. Treats the prediction as an action sampled from a normal distribution and returns the average log probability of that prediction. Can also be used to find the probability of the current model giving some other prediction (old_pred)

Parameters:
  • scheduler – Scheduler being used for diffusion process

  • pred – Denoiser prediction with CFG and scaling accounted for

  • t – Timestep in diffusion process

  • latents – Latent vector given as input to denoiser

  • old_pred – Alternate prediction. If given, computes log probability of current model predicting alternative output.
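The log-probability part of this step can be sketched as follows, assuming the scheduler has already produced the mean and standard deviation of the Gaussian for the previous-step latents. The function name step_and_logprob, the mean, std and alt_sample parameters, and the flat-list representation are all hypothetical; in particular, how DRLX maps old_pred onto the alternative sample is an assumption here.

```python
import math
import random


def step_and_logprob(mean, std, alt_sample=None, rng=None):
    """Hypothetical sketch: sample from N(mean, std^2) and score a sample.

    mean: list of floats (per-element mean of the next latents).
    std: positive float shared by all elements.
    alt_sample: optional alternative outcome to score instead of the draw.
    """
    rng = rng or random.Random()
    # Stochastic step: draw the next latents around the predicted mean.
    sample = [m + std * rng.gauss(0.0, 1.0) for m in mean]

    # Score either the fresh draw or the supplied alternative.
    target = sample if alt_sample is None else alt_sample

    # Per-element Gaussian log density, averaged over all elements.
    log_norm = math.log(std) + 0.5 * math.log(2.0 * math.pi)
    log_probs = [
        -((x - m) ** 2) / (2.0 * std**2) - log_norm
        for x, m in zip(target, mean)
    ]
    return sample, sum(log_probs) / len(log_probs)
```

Passing alt_sample lets the same function answer how likely the current model is to have produced an outcome recorded earlier, which is the quantity needed to form an importance ratio against a past model.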