DRLX Example
This example demonstrates how to use DRLX to train a model with a custom prompt pipeline and reward model. The prompt pipeline will repeatedly provide the same prompt, “Photo of a mad scientist panda”, and the reward model will reward images for having high contrast.
Custom Prompt Pipeline
First, we define a custom prompt pipeline that returns the same prompt, “Photo of a mad scientist panda”, for every index.
from drlx.pipeline import PromptPipeline

class MadScientistPandaPrompts(PromptPipeline):
    """
    Custom prompt pipeline that only gives a single phrase "Photo of a mad scientist panda" over and over.
    """
    def __getitem__(self, index):
        return "Photo of a mad scientist panda"

    def __len__(self):
        return 100000  # arbitrary
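To see what such a pipeline hands out without installing DRLX, here is a minimal stand-in sketch. It assumes the pipeline is consumed like an indexable dataset (via `__getitem__` and `__len__`, as in the class above); `ConstantPrompts` is a hypothetical name used only for illustration and is not part of DRLX.

```python
class ConstantPrompts:
    """Hypothetical stand-in for a prompt pipeline: same prompt at every index."""

    def __init__(self, prompt, length=100000):
        self.prompt = prompt
        self.length = length  # arbitrary, mirrors the example above

    def __getitem__(self, index):
        # The index is ignored on purpose: every request yields the same prompt.
        return self.prompt

    def __len__(self):
        return self.length


prompts = ConstantPrompts("Photo of a mad scientist panda")

# Pull a small "batch" by indexing, the way a data loader would.
batch = [prompts[i] for i in range(4)]
print(len(prompts), batch)
```

Because the index is ignored, every batch contains the same string, so any difference in reward comes from the sampled images rather than from the prompts.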
Custom Reward
Next, we define a custom reward model that rewards images for having high contrast. The contrast is calculated as the standard deviation of the pixel intensities.
from drlx.reward_modelling import RewardModel
import numpy as np
import torch

class HighContrastReward(RewardModel):
    """
    Rewards high contrast in the image.
    """
    def forward(self, images, prompts):
        # If the input is a list of PIL Images, convert to numpy array
        if isinstance(images, list):
            images = np.array([np.array(img) for img in images])

        # Calculate the standard deviation of the pixel intensities for each image
        contrast = images.std(axis=(1, 2, 3))  # N
        return torch.from_numpy(contrast)
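The contrast measure can be checked on toy data using NumPy alone. The sketch below is not part of DRLX: `contrast_scores` is a hypothetical helper that repeats the standard-deviation step from the reward above, and the two 8x8 images are made up for illustration.

```python
import numpy as np


def contrast_scores(images):
    """Per-image contrast: standard deviation over height, width, and channels."""
    images = np.asarray(images, dtype=np.float64)  # expected shape (N, H, W, C)
    return images.std(axis=(1, 2, 3))


# Toy batch of two 8x8 RGB images.
flat = np.full((8, 8, 3), 128, dtype=np.uint8)  # uniform gray, no contrast

checker = np.zeros((8, 8, 3), dtype=np.uint8)   # black/white checkerboard
checker[::2, ::2] = 255    # even rows, even columns
checker[1::2, 1::2] = 255  # odd rows, odd columns

scores = contrast_scores(np.stack([flat, checker]))
print(scores)
```

The uniform image scores 0, while the checkerboard, whose values are half 0 and half 255, scores 127.5, the largest standard deviation possible for 8-bit pixel values. A reward of this form therefore pushes the model toward images with strong light and dark regions.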
Training Setup
Now, we set up the training process. We use the MadScientistPandaPrompts as the prompt pipeline and the HighContrastReward as the reward model.
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# Build the prompt pipeline and load the training config
pipe = MadScientistPandaPrompts()
config = DRLXConfig.load_yaml("configs/ddpo_sd.yml")

# Train with the custom prompt pipeline and reward model
trainer = DDPOTrainer(config)
trainer.train(pipe, HighContrastReward())
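To run training through Hugging Face Accelerate (for example, across multiple GPUs), launch your training script as a module, substituting its name for [script]: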
accelerate launch -m [script]
Loading the Model and Performing Inference
After training, we can load the model and perform inference with it using a default sampler.
# Load the trainer from a checkpoint if you want to resume training
# By default, the trainer saves the output and the checkpoint in separate folders named after run_name
checkpoint_path = "checkpoints/run_name"
output_path = "output/run_name"
trainer.load_checkpoint(checkpoint_path)
# Otherwise, you can just use a pretrained pipeline
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(output_path, local_files_only=True)
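The snippet above loads the pipeline but stops short of sampling. A minimal sketch of the sampling step follows, assuming the output folder holds a pipeline in the standard diffusers format. `generate_image` and the file name `panda.png` are hypothetical names chosen for illustration.

```python
def generate_image(model_dir, prompt, out_file="panda.png"):
    """Load a saved Stable Diffusion pipeline from disk and save one sample."""
    # Imported inside the function so that defining this helper
    # does not itself require diffusers to be installed.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained(model_dir, local_files_only=True)

    # Calling the pipeline with a prompt returns an output object whose
    # `.images` attribute is a list of PIL images.
    image = pipe(prompt).images[0]
    image.save(out_file)
    return image


# Example call (assumes training has already written a pipeline to this folder):
# generate_image("output/run_name", "Photo of a mad scientist panda")
```

Sampling on CPU can be slow. If a GPU is available, moving the pipeline with `pipe.to("cuda")` before calling it is the usual approach.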
To actually run this code or make tweaks, please see the notebooks or scripts under the examples folder.