# Autoregressive Transformer trained on hausa_datamix
This is an autoregressive decoder-only transformer model trained on the hausa_datamix dataset using JAX and Flax NNX.
## Model Details
- Model Type: Autoregressive Decoder-only Transformer
- Framework: JAX + Flax NNX
- Dataset: hausa_datamix
- Parameters: ~83.9M
- Precision: Mixed (FP32 parameters, BF16 computation)
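The precision entry above (FP32 parameters, BF16 computation) usually means the master weights stay in float32 and are cast to bfloat16 only for the forward pass. A minimal JAX sketch of that pattern, where `apply_fn` is a placeholder for the model's forward function and not part of this repository's actual API:

```python
import jax
import jax.numpy as jnp


def forward_in_bf16(params, token_ids, apply_fn):
    """Run the forward pass in bfloat16 while keeping master params in float32.

    `apply_fn` is a stand-in for the model's forward function; it is an
    assumption for illustration, not this repository's real API.
    """
    # Cast the float32 parameter pytree to bfloat16 for computation only.
    bf16_params = jax.tree_util.tree_map(
        lambda p: p.astype(jnp.bfloat16) if jnp.issubdtype(p.dtype, jnp.floating) else p,
        params,
    )
    logits = apply_fn(bf16_params, token_ids)
    # Cast logits back to float32 so the loss/softmax runs in full precision.
    return logits.astype(jnp.float32)
```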
## Architecture
- Hidden Size: 512
- Number of Layers: 8
- Attention Heads: 8
- Intermediate Size: 2048
- Max Position Embeddings: 256
- Vocab Size: 49152
- Rotary Position Embeddings: True
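As a rough sanity check, the ~83.9M parameter figure can be reconstructed from these hyperparameters if one assumes untied input/output embeddings and a gated (three-matrix) feed-forward block; both are assumptions, not confirmed by the config:

```python
# Back-of-the-envelope parameter count (bias and norm terms omitted).
vocab, hidden, layers, ffn = 49152, 512, 8, 2048

embeddings = vocab * hidden        # input embedding table
lm_head    = vocab * hidden        # assumed untied output projection
attn       = 4 * hidden * hidden   # Q, K, V, and output projections
mlp        = 3 * hidden * ffn      # assumed gated FFN (up, gate, down)

total = embeddings + lm_head + layers * (attn + mlp)
print(f"{total / 1e6:.1f}M")       # ≈ 83.9M, matching the reported size
```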
## Training Details
- Training Steps: 3,120
- Batch Size: 32
- Gradient Accumulation: 4
- Learning Rate: 0.0003
- Training Duration: 0.35 hours
- Final Eval Loss: 0.9854
- Final Eval Perplexity: 2.82
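The optimizer setup is not included in this card. The sketch below shows one common way to wire up the listed learning rate, warmup ratio, and gradient accumulation with optax; AdamW, cosine decay, and gradient clipping are assumptions, not details confirmed by the training code.

```python
import optax

total_steps = 3_120                      # training steps reported above
warmup_steps = int(0.1 * total_steps)    # warmup_ratio: 0.1

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=3e-4,                     # learning_rate: 0.0003
    warmup_steps=warmup_steps,
    decay_steps=total_steps,
)

# AdamW with global-norm clipping is a common default; both are assumptions here.
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule),
)

# Accumulate gradients over 4 micro-batches before each parameter update.
optimizer = optax.MultiSteps(optimizer, every_k_schedule=4)
```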
## Usage

```python
# This model was trained with JAX/Flax and requires the custom transformer
# implementation to load and use. See the repository for implementation details.
from transformers import AutoTokenizer
import jax.numpy as jnp

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("thiomajid/hausa_lm")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Example text generation (requires custom model loading)
prompt = "Once upon a time, there was a little"
input_ids = jnp.array(tokenizer(prompt)["input_ids"])[None, :]

# ... (model loading and generation code)
```
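The custom model class is not reproduced here. As a rough illustration of how generation could proceed once it is loaded, the loop below greedily samples one token at a time from a callable `model` that maps token ids to logits; the `model` object and its signature are assumptions, not the repository's actual interface.

```python
import jax.numpy as jnp


def greedy_generate(model, input_ids, max_new_tokens=32, eos_id=None):
    """Minimal greedy decoding loop; `model(ids)` is assumed to return
    logits of shape (batch, seq_len, vocab_size)."""
    ids = input_ids
    for _ in range(max_new_tokens):
        logits = model(ids)
        # Pick the highest-probability next token for each sequence.
        next_id = jnp.argmax(logits[:, -1, :], axis=-1)[:, None]
        ids = jnp.concatenate([ids, next_id], axis=-1)
        if eos_id is not None and int(next_id[0, 0]) == eos_id:
            break
    return ids
```

The resulting ids would then be turned back into text with `tokenizer.batch_decode`.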
## Training Configuration

```yaml
model:
  hidden_size: 512
  num_layers: 8
  num_attention_heads: 8
  intermediate_size: 2048
  max_position_embeddings: 256

training:
  learning_rate: 0.0003
  batch_size: 32
  epochs: 10
  warmup_ratio: 0.1
```
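If this configuration is stored as a YAML file in the repository, it can be read back with PyYAML; the filename below is a placeholder, not a file confirmed to exist here.

```python
import yaml

# "train_config.yaml" is a placeholder name for illustration only.
with open("train_config.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["model"]["hidden_size"])        # 512
print(cfg["training"]["learning_rate"])   # 0.0003
```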
## Files

- `config.json`: Model configuration
- `train_history.json`: Training metrics and duration
- `tokenizer/`: hausa_lm tokenizer files
- `model_checkpoint/`: Best model checkpoint
- `tensorboard_logs/`: Training logs for TensorBoard
## License
MIT License - see LICENSE file for details.