Recurrent Neural Network with Attention
In this assignment, I implemented a Recurrent Neural Network (RNN) with a scaled dot-product self-attention mechanism to perform next-token prediction on the TinyStories dataset. This project deepened my understanding of sequential modeling and attention-based architectures.
Model Definition:
- Architecture: RNN with a single hidden layer, scaled dot-product self-attention, and a linear language modeling head.
- Key Components:
  - RNN Cell: Maintains temporal context through hidden states.
  - Self-Attention Layer: Highlights relevant tokens in a sequence via a query-key-value mechanism.
  - Language Modeling Head: Outputs token probabilities from the attention-enhanced RNN output (a minimal model sketch follows this list).
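The assignment's exact layer sizes and wiring are not reproduced here, but a minimal PyTorch sketch of this kind of model could look like the following. The class and dimension names (AttentiveRNNLM, embed_dim, hidden_dim) are hypothetical, and the causal mask is an assumption I add because the model is trained for next-token prediction.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentiveRNNLM(nn.Module):
    """Minimal RNN language model with scaled dot-product self-attention (illustrative sketch)."""

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)             # tokens -> dense vectors
        self.rnn = nn.RNN(embed_dim, hidden_dim, batch_first=True)   # single recurrent layer
        # Q/K/V projections for self-attention over the RNN outputs
        self.q_proj = nn.Linear(hidden_dim, hidden_dim)
        self.k_proj = nn.Linear(hidden_dim, hidden_dim)
        self.v_proj = nn.Linear(hidden_dim, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size)             # logits over the vocabulary

    def forward(self, tokens):                      # tokens: (batch, seq_len)
        x = self.embed(tokens)                      # (batch, seq_len, embed_dim)
        h, _ = self.rnn(x)                          # (batch, seq_len, hidden_dim)
        q, k, v = self.q_proj(h), self.k_proj(h), self.v_proj(h)
        scores = q @ k.transpose(-2, -1) / q.size(-1) ** 0.5   # scaled dot products
        # Causal mask (assumption): each position attends only to itself and earlier tokens
        mask = torch.triu(torch.ones_like(scores, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float("-inf"))
        attn = F.softmax(scores, dim=-1) @ v        # softmax-weighted aggregation of values
        return self.lm_head(attn)                   # (batch, seq_len, vocab_size)
```

Reusing hidden_dim for the Q/K/V projections keeps the sketch compact; the actual assignment may have used separate attention dimensions.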
Tasks Accomplished:
- Forward Propagation: Built a computation graph combining token embeddings, RNN hidden states, attention outputs, and a linear projection layer.
- Training Loop: Used stochastic gradient descent to minimize cross-entropy loss on next-token predictions, with targets shifted one position ahead of the inputs (see the sketch after this list).
- Evaluation: Performed validation at regular intervals to monitor generalization and convergence.
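As a rough illustration of this loop, here is a sketch assuming the model sketch above, an SGD optimizer, and batches of token-ID tensors of shape (batch_size, seq_len); the helper names train_step and validate are hypothetical.

```python
import torch
import torch.nn.functional as F

def train_step(model, optimizer, batch):
    """One optimization step of next-token prediction on a (batch_size, seq_len) batch of token IDs."""
    inputs, targets = batch[:, :-1], batch[:, 1:]          # targets are the inputs shifted by one
    logits = model(inputs)                                 # (batch_size, seq_len - 1, vocab_size)
    loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def validate(model, val_batches):
    """Average next-token cross-entropy over held-out batches, run at regular intervals."""
    model.eval()
    total, n = 0.0, 0
    for batch in val_batches:
        inputs, targets = batch[:, :-1], batch[:, 1:]
        logits = model(inputs)
        total += F.cross_entropy(logits.reshape(-1, logits.size(-1)), targets.reshape(-1)).item()
        n += 1
    model.train()
    return total / n

# Usage (hypothetical): optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```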
Implementation Details:
- Embedding Layer: Used `nn.Embedding` to represent tokens as dense vectors.
- Hyperparameter Tuning: Tuned embedding size, hidden dimensions, and attention settings.
- Attention Module: Computed scaled dot-product attention via Q/K/V projections and softmax-weighted aggregation.
- Output Generation: Implemented both greedy decoding and temperature-based sampling for sequence generation (sketched below).
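A minimal sketch of the two decoding modes, assuming the model interface from the earlier sketch; generate and its arguments are hypothetical names. At temperature 0 the sampler falls back to greedy argmax decoding.

```python
import torch
import torch.nn.functional as F

@torch.no_grad()
def generate(model, prompt_ids, max_new_tokens, temperature=0.8):
    """Autoregressively extend prompt_ids of shape (1, seq_len); temperature 0 means greedy decoding."""
    model.eval()
    ids = prompt_ids
    for _ in range(max_new_tokens):
        logits = model(ids)[:, -1, :]                        # logits for the last position only
        if temperature == 0:
            next_id = logits.argmax(dim=-1, keepdim=True)    # greedy: pick the most likely token
        else:
            probs = F.softmax(logits / temperature, dim=-1)  # higher temperature -> flatter distribution
            next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=-1)              # append the new token and continue
    return ids
```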
Empirical Evaluations:
- Tested effects of embedding/hidden dimensions on model accuracy.
- Studied performance vs. batch size and training sequence count.
- Explored sampling temperatures (0, 0.3, 0.8) for text diversity.
Programming Techniques:
- Vectorization: Leveraged PyTorch matrix ops for efficient forward/backward passes.
- Modular Design: Separated RNN, attention, and decoder logic for readability and reuse.
- Debugging Tools: Verified correctness using unit tests and mini datasets.
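As an example of the kind of sanity check this refers to, a small shape test against the model sketch above could look like the following; the test itself is illustrative, not taken from the assignment.

```python
import torch

def test_forward_shapes():
    """Sanity-check the illustrative AttentiveRNNLM on a tiny random batch."""
    model = AttentiveRNNLM(vocab_size=50, embed_dim=16, hidden_dim=32)
    tokens = torch.randint(0, 50, (4, 10))       # batch of 4 sequences, length 10
    logits = model(tokens)
    assert logits.shape == (4, 10, 50)           # one vocabulary distribution per position
    assert torch.isfinite(logits).all()          # no NaNs/Infs from the masked softmax

test_forward_shapes()
```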
Outputs and Results:
- Logged per-epoch training and validation losses for trend analysis.
- Generated coherent sample texts via sampling-based decoding strategies.
This project enhanced my understanding of deep learning for language modeling, including recurrent computation, attention mechanisms, and token-level prediction. It also improved my fluency with PyTorch and experimental design.
Key Topics:
- Recurrent Neural Networks
- Self-Attention
- Language Modeling
- Cross-Entropy Loss
- Sequence Modeling
- PyTorch
- Token Embedding
- Text Generation