# Tokenized Data Loading

## Download Only What You Need

Start with a small subset of shards from `karpathy/fineweb-edu-100B-gpt2-token-shards` (each shard is ~200MB) so you can validate the pipeline quickly. Scale up once training is stable and you know the throughput.

## Downloading in Modal

Since Modal functions run in isolated containers, download data inside the function:

```python
def download_tokenized_data():
    """Fetch a small set of FineWeb-Edu GPT-2 token shards.

    Shards that already exist in the target directory are not downloaded
    again.

    Returns:
        The local directory that the shards are downloaded into.
    """
    import os
    from huggingface_hub import hf_hub_download

    data_dir = "/tmp/data/fineweb-edu"
    os.makedirs(data_dir, exist_ok=True)

    # A few shards for training, plus one held out for validation.
    shard_names = (
        "edu_fineweb_train_000001.bin",
        "edu_fineweb_train_000002.bin",
        "edu_fineweb_train_000003.bin",  # validation holdout
    )

    for shard_name in shard_names:
        # Skip shards that are already on disk.
        if os.path.exists(os.path.join(data_dir, shard_name)):
            continue
        print(f"Downloading {shard_name}...")
        hf_hub_download(
            repo_id="karpathy/fineweb-edu-100B-gpt2-token-shards",
            filename=shard_name,
            repo_type="dataset",
            local_dir=data_dir,
        )

    print("Data download complete!")
    return data_dir
```

The dataset is public (no auth required) and ships only training shards; hold out one shard for validation.

## Dataset Files

| File | Purpose |
|------|---------|
| `edu_fineweb_train_000001.bin` | Training shard |
| `edu_fineweb_train_000002.bin` | Training shard |
| `edu_fineweb_train_000003.bin` | Validation holdout shard |

The dataset ships training shards only; the loader below reserves the last shard in sorted filename order (here `edu_fineweb_train_000003.bin`) as validation.

Files typically store GPT-2 token IDs as `uint16` arrays.

## Memory-Mapped Data Loading

Use memory-mapped files for efficient data loading:

```python
import numpy as np
import torch

class TokenizedDataset:
    """Random-access batches over memory-mapped ``uint16`` token shards.

    Shard files are those in ``data_dir`` named ``edu_fineweb_train_*.bin``,
    taken in sorted filename order. The last shard is the validation split;
    all earlier shards form the training split. The shards of a split are
    treated as one continuous token stream.

    Attributes:
        block_size: Number of tokens in each input sequence.
        shards: Paths of the shard files that make up this split.
        data: One read-only ``np.memmap`` per shard.
        lengths: Token count of each shard.
        total_length: Total token count across the split.
    """

    def __init__(self, data_dir, split="train", block_size=1024):
        """Map the shards belonging to ``split``.

        Args:
            data_dir: Directory containing the shard files.
            split: ``"val"`` selects the last shard; any other value selects
                every shard except the last.
            block_size: Sequence length produced by ``get_batch``.

        Raises:
            ValueError: If fewer than two shards are found in ``data_dir``.
        """
        import os

        self.block_size = block_size

        all_shards = sorted(
            os.path.join(data_dir, f)
            for f in os.listdir(data_dir)
            if f.startswith("edu_fineweb_train_") and f.endswith(".bin")
        )

        if len(all_shards) < 2:
            raise ValueError("Need at least 2 shards to create train/val splits")

        self.shards = all_shards[-1:] if split == "val" else all_shards[:-1]

        self.data = [np.memmap(s, dtype=np.uint16, mode="r") for s in self.shards]
        self.lengths = [len(d) for d in self.data]
        self.total_length = sum(self.lengths)

    def _get_tokens(self, global_idx, length):
        """Return ``length`` tokens starting at ``global_idx`` in the stream.

        The read continues into the following shard(s) when the requested
        window crosses a shard boundary. If the window runs past the end of
        the last shard, the result is truncated there.

        Args:
            global_idx: Start offset into the concatenated shards.
            length: Number of tokens requested.

        Returns:
            A freshly allocated ``uint16`` array (not a view of the memmap).

        Raises:
            IndexError: If ``global_idx`` is negative or not less than
                ``total_length``.
        """
        if not 0 <= global_idx < self.total_length:
            raise IndexError("Index out of range")

        pieces = []
        remaining = length
        offset = 0  # global index of the first token of the current shard
        for shard, shard_len in zip(self.data, self.lengths):
            if global_idx < offset + shard_len:
                local_idx = global_idx - offset
                piece = shard[local_idx:local_idx + remaining]
                # Drop the memmap subclass so the result is a plain ndarray.
                pieces.append(np.asarray(piece))
                remaining -= len(piece)
                if remaining <= 0:
                    break
                # The rest of the window starts at the next shard's first token.
                global_idx += len(piece)
            offset += shard_len
        return np.concatenate(pieces)

    def get_batch(self, batch_size, device="cuda"):
        """Sample a random batch of input/target sequences.

        Args:
            batch_size: Number of sequences to sample.
            device: Device the returned tensors are moved to.

        Returns:
            A tuple ``(x, y)`` of ``torch.long`` tensors, each of shape
            ``(batch_size, block_size)``, where ``y`` is ``x`` shifted one
            token ahead.

        Raises:
            ValueError: If the split holds fewer than ``block_size + 1``
                tokens.
        """
        window = self.block_size + 1  # inputs plus one extra token for targets
        # Valid starts are 0 .. total_length - window inclusive; randint's
        # upper bound is exclusive, hence the +1.
        num_starts = self.total_length - window + 1
        if num_starts <= 0:
            raise ValueError(
                f"Need at least {window} tokens for block_size={self.block_size}, "
                f"but the split has only {self.total_length}"
            )
        starts = torch.randint(0, num_starts, (batch_size,))

        x = torch.empty(batch_size, self.block_size, dtype=torch.long)
        y = torch.empty(batch_size, self.block_size, dtype=torch.long)

        for i, start in enumerate(starts.tolist()):
            tokens = self._get_tokens(start, window).astype(np.int64)
            x[i] = torch.from_numpy(tokens[:-1])
            y[i] = torch.from_numpy(tokens[1:])

        return x.to(device), y.to(device)
```

## Usage

```python
# Download the shards (files already on disk are skipped) and build both splits:
# "val" uses the last shard in sorted filename order, "train" uses all the others.
data_dir = download_tokenized_data()
train_dataset = TokenizedDataset(data_dir, split="train", block_size=1024)
val_dataset = TokenizedDataset(data_dir, split="val", block_size=1024)

# Sample a random batch: x and y are (batch_size, block_size) tensors on `device`,
# with y holding the tokens one position after those in x.
x, y = train_dataset.get_batch(batch_size=32, device="cuda")
```

## Memory Efficiency

Memory-mapped files:
- Don't load entire dataset into RAM
- Load data on-demand as needed
- Allow training on datasets larger than available RAM
- Each shard is mapped independently
