Trainers
DRLX provides a base trainer class and specific trainers for different methods. The base trainer class provides basic functionality such as setting up the optimizer, scheduler, and model, and saving and loading checkpoints. The specific trainers extend the base trainer and implement the training process for their respective methods.
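The split between shared setup in a base class and a method-specific train() follows the standard abstract-base-class pattern. A minimal sketch of that pattern, using hypothetical stand-in classes (SketchBaseTrainer, SketchMethodTrainer) rather than DRLX's actual source; the string "optimizer" is a placeholder for a real optimizer object:

```python
from abc import ABC, abstractmethod
from typing import Callable, Iterable, List


class SketchBaseTrainer(ABC):
    """Hypothetical base trainer: shared setup lives here, training is deferred."""

    def __init__(self, config: dict):
        self.config = config
        # Shared setup that every derived trainer inherits.
        self.optimizer = self.setup_optimizer()

    def setup_optimizer(self) -> str:
        # Placeholder: a real trainer would build an optimizer object from config.
        return f"optimizer(lr={self.config['lr']})"

    @abstractmethod
    def train(self, prompts: Iterable[str], reward_fn: Callable[[str], float]) -> None:
        """Each method-specific trainer implements its own training loop."""


class SketchMethodTrainer(SketchBaseTrainer):
    """Hypothetical method-specific trainer: only train() is new."""

    def train(self, prompts, reward_fn):
        # Toy "training": just record the reward for each prompt.
        self.history: List[float] = [float(reward_fn(p)) for p in prompts]


trainer = SketchMethodTrainer({"lr": 1e-4})
trainer.train(["a cat", "a dog"], reward_fn=len)
print(trainer.history)  # [5.0, 5.0]
```

Because train is abstract, instantiating the base class directly raises a TypeError, which forces every concrete trainer to supply its own loop while reusing the shared setup.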
BaseTrainer
- class drlx.trainer.BaseTrainer(config: DRLXConfig)
Bases: object
Base class for any DRLX trainer.
- get_arch(config)
Gets the model class from arch_name in the config file. Currently only LDMUNet is supported.
- load_checkpoint(fp: str, index: int | None = None) → Dict[str, Any]
Basic checkpoint loading for derived trainers to use.
- Parameters:
fp (str) – Path to load checkpoint from
index (Optional[int]) – When provided, uses fp as the root folder and loads the subdirectory whose numerical name is given by index
- Returns:
Dictionary of components and their states
- Return type:
Dict
- save_checkpoint(fp: str, components: Dict[str, Any], index: int | None = None)
Basic checkpoint saving for derived trainers to use.
- Parameters:
fp (str) – Path to save checkpoint to
components (Dict) – Dictionary of all components to save (e.g. model, optimizer, scheduler)
index (Optional[int]) – When provided, uses fp as a root folder and saves the checkpoint in a subdirectory named numerically by index
- setup_optimizer()
Returns an optimizer derived from an instance’s config
- setup_scheduler()
Returns a learning rate scheduler derived from an instance’s config
- abstract train(pipeline: Pipeline, reward_fn: Callable[[Iterable[Image], Iterable[str]], Tensor[Tensor]])
Trains model on a given pipeline using a given reward function.
- Parameters:
pipeline – Data pipeline used for training
reward_fn – Function used to get rewards. Should take a batch of images (either as a sequence of numpy arrays or as a list of images) together with their prompts, and return a tensor of rewards.
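save_checkpoint and load_checkpoint above share an index convention: with an index, fp acts as a root folder and the checkpoint sits in a numerically named subdirectory. A minimal sketch of that layout, assuming one file per component; the helper names are hypothetical and pickle stands in for whatever serializer DRLX actually uses:

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Optional


def _checkpoint_dir(fp: str, index: Optional[int]) -> str:
    # With an index, fp is a root folder and the checkpoint lives in "<fp>/<index>".
    return fp if index is None else os.path.join(fp, str(index))


def save_components(fp: str, components: Dict[str, Any], index: Optional[int] = None) -> str:
    path = _checkpoint_dir(fp, index)
    os.makedirs(path, exist_ok=True)
    # One file per component (model, optimizer, scheduler, ...).
    for name, state in components.items():
        with open(os.path.join(path, f"{name}.pkl"), "wb") as f:
            pickle.dump(state, f)
    return path


def load_components(fp: str, index: Optional[int] = None) -> Dict[str, Any]:
    path = _checkpoint_dir(fp, index)
    states: Dict[str, Any] = {}
    for fname in sorted(os.listdir(path)):
        if fname.endswith(".pkl"):
            with open(os.path.join(path, fname), "rb") as f:
                states[fname[: -len(".pkl")]] = pickle.load(f)
    return states


with tempfile.TemporaryDirectory() as root:
    saved_to = save_components(root, {"model": {"w": [1.0, 2.0]}, "optimizer": {"step": 10}}, index=3)
    restored = load_components(root, index=3)
    print(os.path.basename(saved_to), sorted(restored))  # 3 ['model', 'optimizer']
```

Keeping each index in its own subdirectory lets several checkpoints (for example, one per epoch) coexist under a single root.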
DDPOTrainer
- class drlx.trainer.ddpo_trainer.DDPOTrainer(config: DRLXConfig)
Bases: BaseTrainer
DDPO Accelerated Trainer, initialized from config. During init, sets up the model, optimizer, sampler, and logging.
- Parameters:
config (DRLXConfig) – DRLX config
- extract_pipeline()
Returns the original pipeline with the fine-tuned denoiser plugged in
- Returns:
Diffusers pipeline
- load_checkpoint(fp: str)
Load checkpoint
- Parameters:
fp – File path to checkpoint to load from
- loss(x_t: Tensor[Tensor], log_probs_t: Tensor[Tensor], advantages: Tensor[Tensor], prompts: Iterable[str])
Get loss for training
- Parameters:
x_t (torch.Tensor) – Samples across time steps and across batch
log_probs_t (torch.Tensor) – Log probabilities for each sample prediction
advantages (torch.Tensor) – Advantages associated with each image across the batch
prompts (Iterable[str]) – Prompts used for generation across the batch
- Returns:
loss
- Return type:
torch.Tensor
- sample(prompts: Iterable[str]) → Tuple[Tensor]
Samples final predictions, predictions at each time step, and log probabilities from the sampler
- Parameters:
prompts (Iterable[str]) – Batched prompts to use for sampling
- Returns:
3 Tensors: the final latent predictions, all step predictions during the denoising process, and the log probabilities for each prediction
- Return type:
Tuple[torch.Tensor]
- sample_and_calculate_rewards(prompts: Iterable[str], reward_fn: Callable) → Tuple
Samples a batch of images and calculates the rewards for each image
- Parameters:
prompts (Iterable[str]) – Batch of prompts to sample with
reward_fn (Callable[[np.ndarray, Iterable[str]], Iterable[float]]) – Function to be called on final images and prompts to be used for reward computation
- Returns:
Final images, rewards, all step predictions, log probabilities for predictions
- Return type:
Tuple
- save_checkpoint(fp: str, components=None)
Saves a checkpoint in the main process
- Parameters:
fp – File path to save checkpoint to
- save_pretrained(fp: str)
Saves the model as a pretrained pipeline so that it can be loaded as a pipeline later
- Parameters:
fp – File path to save to
- setup_model()
Set up model from config.
- train(prompt_pipeline, reward_fn)
Trains the model based on config parameters. Needs to be passed a prompt pipeline and reward function.
- Parameters:
prompt_pipeline (PromptPipeline) – Pipeline to draw text prompts from. Should consist only of strings.
reward_fn (Callable[[np.array, Iterable[str]], torch.Tensor]) – Any function that returns a tensor of scalar rewards given an np array of images (uint8) and text prompts (strings).
A reward function may score only the images without looking at the prompts; in that case, simply accept the prompts as a dummy input.
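train only needs reward_fn to map a batch of uint8 images plus prompts to one scalar per image. A minimal sketch of an image-only reward (mean brightness, chosen purely for illustration), assuming each image arrives as a NumPy array of shape (H, W, C) or as a PIL image; to_uint8_batch and brightness_reward are hypothetical names, not part of DRLX:

```python
from typing import Iterable, Sequence

import numpy as np


def to_uint8_batch(images: Sequence) -> np.ndarray:
    # np.asarray accepts NumPy arrays and PIL images alike,
    # so either input form becomes one (N, H, W, C) uint8 array.
    return np.stack([np.asarray(img, dtype=np.uint8) for img in images])


def brightness_reward(images: Sequence, prompts: Iterable[str]) -> np.ndarray:
    # Image-only reward: prompts are accepted as a dummy input and ignored.
    batch = to_uint8_batch(images).astype(np.float32)
    # Mean pixel value per image, scaled to [0, 1]; shape (N,).
    return batch.reshape(batch.shape[0], -1).mean(axis=1) / 255.0
```

Since the documented return type is a tensor, wrapping the result as torch.from_numpy(brightness_reward(images, prompts)) produces a float32 tensor with one reward per image.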
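loss receives per-step samples, their log probabilities, and one advantage per image. The DDPO paper optimizes a PPO-style clipped importance-sampling objective over the denoising steps; the sketch below shows that objective in NumPy. It assumes advantages come from standardizing rewards across the batch (one common choice) and takes already-computed log probabilities as input, omitting the step that recomputes them from x_t with the denoiser. The function names and the clip_range value are illustrative assumptions, not DRLX's implementation:

```python
import numpy as np


def normalize_advantages(rewards: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    # Standardize rewards across the batch: zero mean, unit variance.
    return (rewards - rewards.mean()) / (rewards.std() + eps)


def clipped_surrogate_loss(
    new_log_probs: np.ndarray,  # (B, T) log-probs under the current model
    old_log_probs: np.ndarray,  # (B, T) log-probs recorded at sampling time
    advantages: np.ndarray,     # (B,) one advantage per image
    clip_range: float = 1e-4,   # illustrative value only
) -> float:
    # Importance ratio between the current and the sampling-time policy, per step.
    ratio = np.exp(new_log_probs - old_log_probs)
    # Broadcast each image's advantage over all of its time steps.
    adv = advantages[:, None]
    unclipped = -adv * ratio
    clipped = -adv * np.clip(ratio, 1.0 - clip_range, 1.0 + clip_range)
    # Take the pessimistic (larger) loss, then average over batch and time steps.
    return float(np.maximum(unclipped, clipped).mean())
```

When the policy has not changed yet (new and old log probabilities equal), the ratio is 1 and the loss reduces to the negative mean advantage; clipping only matters once updates move the ratio outside the trust region.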