Trainers

DRLX provides a base trainer class and method-specific trainers. The base trainer provides common functionality such as setting up the optimizer, scheduler, and model, and saving and loading checkpoints. Each specific trainer extends the base trainer and implements the training process for its method.

BaseTrainer

class drlx.trainer.BaseTrainer(config: DRLXConfig)

Bases: object

Base class for any DRLX trainer

get_arch(config)

Gets the model class from arch_name in the config file. Currently only LDMUNet is supported.

load_checkpoint(fp: str, index: int | None = None) → Dict[str, Any]

Basic checkpoint loading for derived trainers to use.

Parameters:
  • fp (str) – Path to load checkpoint from

  • index (Optional[int]) – When provided, uses fp as root and loads subdirectory with numerical name given by index

Returns:

Dictionary of components and their states

Return type:

Dict

save_checkpoint(fp: str, components: Dict[str, Any], index: int | None = None)

Basic checkpoint saving for any derived trainer to use

Parameters:
  • fp (str) – Path to save checkpoint to

  • components (Dict) – Dictionary of all components to save (i.e. model, optimizer, scheduler, etc.)

  • index (Optional[int]) – When provided, uses fp as a root folder and puts checkpoint under a subdirectory that is named numerically with index
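The effect of index on the checkpoint location can be sketched as follows. This is only an illustration of the documented behaviour, not DRLX's implementation: resolve_checkpoint_dir is a hypothetical helper, and the example paths are made up.

```python
import os
from typing import Optional


def resolve_checkpoint_dir(fp: str, index: Optional[int] = None) -> str:
    """Hypothetical helper (not part of DRLX) mirroring the documented rule.

    Without an index, fp itself is the checkpoint location. With an index,
    fp is treated as a root folder and the checkpoint lives in a
    subdirectory named after the number.
    """
    if index is None:
        return fp
    return os.path.join(fp, str(index))


# Example paths are illustrative only.
print(resolve_checkpoint_dir("checkpoints/run1"))
print(resolve_checkpoint_dir("checkpoints/run1", 3))
```

Under this reading, passing the same fp and index to save_checkpoint and load_checkpoint would address the same subdirectory.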

setup_optimizer()

Returns an optimizer derived from an instance’s config

setup_scheduler()

Returns a learning rate scheduler derived from an instance’s config

abstract train(pipeline: Pipeline, reward_fn: Callable[[Iterable[Image], Iterable[str]], Tensor[Tensor]])

Trains model on a given pipeline using a given reward function.

Parameters:
  • pipeline – Data pipeline used for training

  • reward_fn – Function used to get rewards. Should take a batch of images (either as a sequence of numpy arrays, or as a list of images) together with the corresponding prompts, and return a tensor of rewards
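Because train is abstract, a concrete trainer has to override it before it can be instantiated. The sketch below shows that pattern with stand-in classes built on Python's abc module. SketchBaseTrainer, EchoTrainer, and constant_reward are hypothetical names used only for illustration; they do not import or reproduce drlx.trainer.BaseTrainer, and the "images" are placeholder strings.

```python
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List


class SketchBaseTrainer(ABC):
    """Stand-in base class for illustration; NOT drlx.trainer.BaseTrainer."""

    def __init__(self, config: dict):
        self.config = config

    @abstractmethod
    def train(
        self,
        pipeline: Iterable[str],
        reward_fn: Callable[[List[str], List[str]], List[float]],
    ) -> None:
        """Subclasses implement the method-specific training process."""


class EchoTrainer(SketchBaseTrainer):
    """Toy subclass: records rewards instead of updating a model."""

    def train(self, pipeline, reward_fn):
        self.history = []
        for prompt in pipeline:
            # Placeholder standing in for a generated image.
            images = [f"<image for {prompt}>"]
            self.history.append(reward_fn(images, [prompt]))


def constant_reward(images, prompts):
    # Ignores its inputs and gives every image the same reward.
    return [1.0 for _ in images]


trainer = EchoTrainer(config={})
trainer.train(["a cat", "a dog"], constant_reward)
print(trainer.history)  # [[1.0], [1.0]]
```

Trying to instantiate SketchBaseTrainer directly raises TypeError, which is the usual way an abstract train method enforces that each derived trainer supplies its own training process.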

DDPOTrainer

class drlx.trainer.ddpo_trainer.DDPOTrainer(config: DRLXConfig)

Bases: BaseTrainer

DDPO Accelerated Trainer initialization from config. During init, sets up the model, optimizer, sampler, and logging.

Parameters:

config (DRLXConfig) – DRLX config

extract_pipeline()

Return original pipeline with finetuned denoiser plugged in

Returns:

Diffusers pipeline

load_checkpoint(fp: str)

Load checkpoint

Parameters:

fp – File path to checkpoint to load from

loss(x_t: Tensor[Tensor], log_probs_t: Tensor[Tensor], advantages: Tensor[Tensor], prompts: Iterable[str])

Get loss for training

Parameters:
  • x_t (torch.Tensor) – Samples across time steps and across batch

  • log_probs_t (torch.Tensor) – Log probabilities for each sample prediction

  • advantages (torch.Tensor) – Advantages associated with each image across the batch

  • prompts (Iterable[str]) – Prompts used for generation across the batch

Returns:

loss

Return type:

torch.Tensor

sample(prompts: Iterable[str]) → Tuple[Tensor]

Sample predictions, predictions at time steps and log probabilities from sampler

Parameters:

prompts (Iterable[str]) – Batched prompts to use for sampling

Returns:

Three tensors: the final latent predictions, all step predictions from the denoising process, and the log probabilities for each prediction

Return type:

Tuple[torch.Tensor]

sample_and_calculate_rewards(prompts: Iterable[str], reward_fn: Callable) → Tuple

Samples a batch of images and calculates the rewards for each image

Parameters:
  • prompts (Iterable[str]) – Batch of prompts to sample with

  • reward_fn (Callable[[np.ndarray, Iterable[str]], Iterable[float]]) – Function to be called on final images and prompts to be used for reward computation

Returns:

Final images, rewards, all step predictions, log probabilities for predictions

Return type:

Tuple

save_checkpoint(fp: str, components=None)

Save checkpoint in main process

Parameters:

fp – File path to save checkpoint to

save_pretrained(fp: str)

Save model into pretrained pipeline so it can be loaded in pipeline later

Parameters:

fp – File path to save to

setup_model()

Set up model from config.

train(prompt_pipeline, reward_fn)

Trains the model based on config parameters. Needs to be passed a prompt pipeline and reward function.

Parameters:
  • prompt_pipeline (PromptPipeline) – Pipeline to draw text prompts from. Should be composed of just strings.

  • reward_fn (Callable[[np.array, Iterable[str]], torch.Tensor]) – Any function that returns a tensor of scalar rewards given a NumPy array of images (uint8) and text prompts (strings). A reward function that scores images without looking at the prompts is fine; simply accept prompts as an unused argument.
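A reward function with this shape might look like the sketch below. It is a toy example, not a reward shipped with DRLX: brightness_reward is a hypothetical name, and the (batch, height, width, channels) layout of the image array is an assumption made for the illustration.

```python
from typing import Iterable

import numpy as np
import torch


def brightness_reward(images: np.ndarray, prompts: Iterable[str]) -> torch.Tensor:
    """Toy reward: mean pixel intensity of each image, scaled to [0, 1].

    Assumes `images` is a uint8 array shaped (batch, height, width, channels).
    `prompts` is accepted only to match the expected signature and is unused.
    """
    del prompts  # dummy input, intentionally ignored
    images = np.asarray(images)
    # Flatten each image and average its pixels, giving one scalar per image.
    per_image = images.reshape(images.shape[0], -1).mean(axis=1) / 255.0
    return torch.as_tensor(per_image, dtype=torch.float32)


# Two synthetic images: one all black, one all white.
batch = np.zeros((2, 8, 8, 3), dtype=np.uint8)
batch[1] = 255
rewards = brightness_reward(batch, ["a cat", "a dog"])
print(rewards)
```

A function like this would be passed as the second argument to train, alongside the prompt pipeline.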