Pipeline
The pipeline module in DRLX handles data preparation for training RL models. It includes a base class, Pipeline, and three subclasses: PromptPipeline, PickAPicPrompts, and ImagenetAnimalPrompts.
Pipeline
- class drlx.pipeline.Pipeline(prep_fn: Callable | None = None)
Bases:
Dataset
Pipeline for data during RL training. Subclasses should define a dataset with __getitem__ and __len__ methods.
- Parameters:
prep_fn (Callable) – Function that will be called on an iterable of data elements from the pipeline. Optional; defaults to an identity function.
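The subclass contract described above can be sketched in plain Python. This is a minimal stand-in, not DRLX code: `ToyPipeline` and `ListPrompts` are hypothetical names, and the dataloader machinery of the real class is left out so the sketch has no dependencies.

```python
from typing import Callable, List, Optional


class ToyPipeline:
    """Hypothetical stand-in for the base class; not the DRLX implementation."""

    def __init__(self, prep_fn: Optional[Callable] = None):
        # With no prep_fn given, fall back to an identity function.
        self.prep_fn = prep_fn if prep_fn is not None else (lambda x: x)


class ListPrompts(ToyPipeline):
    """Hypothetical subclass backed by an in-memory list of prompts."""

    def __init__(self, prompts: List[str], prep_fn: Optional[Callable] = None):
        super().__init__(prep_fn)
        self.prompts = list(prompts)

    def __getitem__(self, index: int) -> str:
        # Return a single data element by position.
        return self.prompts[index]

    def __len__(self) -> int:
        # Report how many elements the dataset holds.
        return len(self.prompts)


# prep_fn receives an iterable of elements, here a list of prompt strings.
pipe = ListPrompts(["a cat", "a dog"], prep_fn=lambda batch: [p.upper() for p in batch])
batch = [pipe[i] for i in range(len(pipe))]
prepared = pipe.prep_fn(batch)
```

Because the subclass only supplies `__getitem__` and `__len__`, any loader that indexes the dataset by position can consume it without further changes.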
- create_loader(**kwargs) → DataLoader
Create dataloader over self. Assumes __getitem__ and __len__ are implemented.
- Parameters:
kwargs – Keyword arguments passed to the created PyTorch DataLoader
- Returns:
Dataloader for dataset within pipeline
- Return type:
DataLoader
- create_train_loader(**kwargs) → DataLoader
Create loader for training data. Default behaviour is to just call create_loader (i.e. assumes there is no split)
- abstract create_val_loader(**kwargs) → DataLoader
Create validation loader.
- classmethod make_default_collate(prep: Callable)
Creates a default collate function for the dataloader that assumes dataset elements are tuples of images and strings.
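A collate function for elements that are (image, string) tuples might look roughly like the sketch below. The name `make_collate` is hypothetical, images are represented by placeholder strings, and the way `prep` is invoked (once, with the list of images and the list of strings) is an assumption rather than the documented behaviour of make_default_collate.

```python
from typing import Any, Callable, List, Sequence, Tuple


def make_collate(prep: Callable[[List[Any], List[str]], Any]) -> Callable:
    """Hypothetical factory: build a collate function around `prep`."""

    def collate(batch: Sequence[Tuple[Any, str]]) -> Any:
        # Split the batch of (image, text) tuples into two parallel lists.
        images = [item[0] for item in batch]
        texts = [item[1] for item in batch]
        # Assumption: prep takes both lists and returns the final batch.
        return prep(images, texts)

    return collate


# Placeholder strings stand in for real image objects.
batch = [("img0", "a cat"), ("img1", "a dog")]
collate = make_collate(lambda images, texts: {"images": images, "texts": texts})
out = collate(batch)
```

A function of this shape could be passed as the `collate_fn` keyword argument when a PyTorch DataLoader is constructed.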
PromptPipeline
PickAPicPrompts
ImagenetAnimalPrompts
- class drlx.pipeline.imagenet_animal_prompts.ImagenetAnimalPrompts(prefix='A picture of a ', postfix=', 4k unreal engine', num=10000, *args, **kwargs)
Bases:
PromptPipeline
Pipeline of prompts consisting of animals from ImageNet, as used in the original DDPO paper.
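The role of the prefix and postfix arguments can be illustrated with a small sketch. It assumes each prompt is formed by joining prefix, an animal name, and postfix, and that num is the number of prompts produced; neither is confirmed here. The animal list and the names `build_prompt` and `sample_prompts` are hypothetical.

```python
import random
from typing import List

# Hypothetical subset of animal names; the real class draws on ImageNet classes.
ANIMALS = ["tabby cat", "golden retriever", "red fox"]


def build_prompt(animal: str,
                 prefix: str = "A picture of a ",
                 postfix: str = ", 4k unreal engine") -> str:
    # Assumption: a prompt is prefix + animal name + postfix.
    return f"{prefix}{animal}{postfix}"


def sample_prompts(num: int, seed: int = 0) -> List[str]:
    # Assumption: num is the number of prompts to generate.
    rng = random.Random(seed)
    return [build_prompt(rng.choice(ANIMALS)) for _ in range(num)]


example = build_prompt("red fox")
prompts = sample_prompts(5)
```

With the default arguments, `build_prompt("red fox")` yields "A picture of a red fox, 4k unreal engine".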