Pipeline

The pipeline module in DRLX handles data preparation for RL training. It provides a base class, Pipeline, and three subclasses: PromptPipeline, PickAPicPrompts and ImagenetAnimalPrompts.

Pipeline

class drlx.pipeline.Pipeline(prep_fn: Callable | None = None)

Bases: Dataset

Pipeline for data during RL training. Subclasses should define a dataset by implementing the __getitem__ and __len__ methods.

Parameters:

prep_fn (Callable) – Function that will be called on an iterable of data elements from the pipeline. It is not always required; by default it is the identity function.
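As a rough illustration of the contract described above, the following sketch shows a plain-Python class that follows the same map-style dataset protocol (__getitem__ and __len__) and falls back to an identity prep function. ToyPromptPipeline is a hypothetical name and does not subclass the real drlx.pipeline.Pipeline, so the example runs without DRLX or PyTorch installed.

```python
from typing import Callable, List, Optional


class ToyPromptPipeline:
    """Hypothetical stand-in for a Pipeline subclass (not part of DRLX).

    It implements the map-style dataset protocol (__getitem__ and __len__),
    which is what a PyTorch DataLoader needs to index the data.
    """

    def __init__(self, prompts: List[str], prep_fn: Optional[Callable] = None):
        self.prompts = list(prompts)
        # Fall back to an identity function when no prep_fn is given,
        # mirroring the documented default.
        self.prep = prep_fn if prep_fn is not None else (lambda x: x)

    def __getitem__(self, index: int) -> str:
        return self.prompts[index]

    def __len__(self) -> int:
        return len(self.prompts)


pipe = ToyPromptPipeline(["a cat", "a dog", "a fox"])
print(len(pipe), pipe[1])             # 3 a dog
print(pipe.prep([pipe[0], pipe[2]]))  # ['a cat', 'a fox']
```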

create_loader(**kwargs) → DataLoader

Create a DataLoader over this pipeline. Assumes __getitem__ and __len__ are implemented.

Parameters:

kwargs – Keyword arguments for the created PyTorch DataLoader

Returns:

DataLoader over the dataset within the pipeline

Return type:

DataLoader
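Because create_loader forwards its keyword arguments to a PyTorch DataLoader, arguments such as batch_size, drop_last and collate_fn control how elements are grouped into batches. The sketch below is a simplified, sequential stand-in for that batching behaviour (roughly what a DataLoader does with shuffle=False and no worker processes); simple_loader is a hypothetical helper, not part of DRLX or PyTorch.

```python
from typing import Any, Callable, Iterator, List, Sequence


def simple_loader(dataset: Sequence, batch_size: int = 1,
                  collate_fn: Callable[[List[Any]], Any] = list,
                  drop_last: bool = False) -> Iterator[Any]:
    """Hypothetical, sequential stand-in for a DataLoader (no shuffling,
    worker processes or memory pinning)."""
    batch: List[Any] = []
    for i in range(len(dataset)):
        batch.append(dataset[i])
        if len(batch) == batch_size:
            yield collate_fn(batch)
            batch = []
    # Emit the final partial batch unless drop_last is set.
    if batch and not drop_last:
        yield collate_fn(batch)


prompts = ["a cat", "a dog", "a fox", "an owl", "a bee"]
print(list(simple_loader(prompts, batch_size=2)))
# [['a cat', 'a dog'], ['a fox', 'an owl'], ['a bee']]
print(list(simple_loader(prompts, batch_size=2, drop_last=True)))
# [['a cat', 'a dog'], ['a fox', 'an owl']]
```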

create_train_loader(**kwargs) → DataLoader

Create the loader for training data. By default this simply calls create_loader (i.e. it assumes there is no train/validation split).

abstract create_val_loader(**kwargs) → DataLoader

Create the validation loader. This method is abstract, so subclasses are expected to implement it.
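The relationship between the three loader methods can be sketched with Python's abc module. LoaderBase and NoSplit are hypothetical classes, and create_loader here returns its keyword arguments instead of a real DataLoader so the delegation is easy to see; this is an illustration under those assumptions, not the DRLX implementation.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict


class LoaderBase(ABC):
    """Hypothetical base class (not DRLX code) mirroring the documented methods."""

    def create_loader(self, **kwargs: Any) -> Dict[str, Any]:
        # Stand-in for DataLoader(self, **kwargs): return the kwargs so the
        # delegation below is visible without PyTorch.
        return dict(kwargs)

    def create_train_loader(self, **kwargs: Any) -> Dict[str, Any]:
        # Default: assume no train/validation split and reuse create_loader.
        return self.create_loader(**kwargs)

    @abstractmethod
    def create_val_loader(self, **kwargs: Any) -> Dict[str, Any]:
        """Subclasses must provide their own validation loader."""


class NoSplit(LoaderBase):
    def create_val_loader(self, **kwargs: Any) -> Dict[str, Any]:
        # Hypothetical choice: validate on the same data as training.
        return self.create_loader(**kwargs)


print(NoSplit().create_train_loader(batch_size=4))  # {'batch_size': 4}
```

Instantiating LoaderBase directly raises TypeError because create_val_loader is abstract, which is how the abc module enforces that a concrete subclass supplies a validation loader.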

classmethod make_default_collate(prep: Callable)

Creates a default collate function for the dataloader that assumes dataset elements are tuples of images and strings.
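A collate function of this kind might look roughly like the sketch below, which splits a batch of (image, string) tuples into a list of images and a list of strings and passes both to prep. The two-argument prep signature and the name make_image_text_collate are assumptions for illustration; the actual DRLX collate may differ.

```python
from typing import Any, Callable, List, Sequence, Tuple


def make_image_text_collate(prep: Callable[[List[Any], List[str]], Any]):
    """Hypothetical sketch (not DRLX code) of a collate for (image, text) pairs.

    Assumes `prep` takes the list of images and the list of strings for a batch.
    """
    def collate(batch: Sequence[Tuple[Any, str]]):
        images = [img for img, _ in batch]
        texts = [txt for _, txt in batch]
        return prep(images, texts)
    return collate


# Strings stand in for image objects to keep the example dependency-free.
collate = make_image_text_collate(lambda imgs, txts: (len(imgs), txts))
print(collate([("img0", "a cat"), ("img1", "a dog")]))  # (2, ['a cat', 'a dog'])
```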

PromptPipeline

class drlx.pipeline.PromptPipeline(prep_fn: Callable | None = None)

Bases: Pipeline

Base class for a pipeline that provides text prompts only.

classmethod make_default_collate(prep: Callable)

Default collate for a prompt pipeline which assumes the dataset elements are simply strings.
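For prompt-only pipelines the collate step is simpler. The hypothetical make_prompt_collate below assumes prep receives the batch as a list of strings; it is a sketch of the idea, not the DRLX source.

```python
from typing import Any, Callable, List, Sequence


def make_prompt_collate(prep: Callable[[List[str]], Any]):
    """Hypothetical sketch (not DRLX code): batch elements are plain strings,
    so the collate just hands the list of prompts to `prep`."""
    def collate(batch: Sequence[str]):
        return prep(list(batch))
    return collate


# A toy prep function that strips and lower-cases each prompt.
collate = make_prompt_collate(lambda ps: [p.strip().lower() for p in ps])
print(collate(["  A Cat ", "A DOG"]))  # ['a cat', 'a dog']
```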

PickAPicPrompts

ImagenetAnimalPrompts

class drlx.pipeline.imagenet_animal_prompts.ImagenetAnimalPrompts(prefix='A picture of a ', postfix=', 4k unreal engine', num=10000, *args, **kwargs)

Bases: PromptPipeline

Pipeline of prompts consisting of animals from ImageNet, as used in the original DDPO paper.
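The constructor arguments suggest that each prompt wraps an animal name with prefix and postfix, so with the defaults a prompt reads like "A picture of a red fox, 4k unreal engine". The sketch below assumes num prompts are produced by sampling animal names uniformly with replacement; the function name make_animal_prompts, the seed parameter and the three-animal list are hypothetical, and the real class may sample differently.

```python
import random
from typing import List, Optional, Sequence


def make_animal_prompts(animals: Sequence[str], prefix: str = "A picture of a ",
                        postfix: str = ", 4k unreal engine", num: int = 10000,
                        seed: Optional[int] = None) -> List[str]:
    """Hypothetical sketch: build `num` prompts as prefix + animal + postfix.

    Assumes animals are drawn uniformly at random with replacement; the
    actual sampling in ImagenetAnimalPrompts may differ.
    """
    rng = random.Random(seed)
    return [prefix + rng.choice(animals) + postfix for _ in range(num)]


# A tiny illustrative subset; the real pipeline draws on ImageNet animal classes.
prompts = make_animal_prompts(["goldfish", "tabby cat", "red fox"], num=3, seed=0)
for p in prompts:
    print(p)
```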