AdityaBDhruva committed on
Commit
95c54f8
·
verified ·
1 Parent(s): 55d87b2

Upload Tranformer_LLM.ipynb

Files changed (1)
  1. Tranformer_LLM.ipynb +1762 -0
Tranformer_LLM.ipynb ADDED
@@ -0,0 +1,1762 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 21,
6
+ "id": "initial_id",
7
+ "metadata": {
8
+ "ExecuteTime": {
9
+ "end_time": "2025-04-20T10:30:19.511335Z",
10
+ "start_time": "2025-04-20T10:30:14.130243Z"
11
+ },
12
+ "collapsed": true,
13
+ "id": "initial_id"
14
+ },
15
+ "outputs": [],
16
+ "source": [
17
+ "import numpy as np\n",
18
+ "import pandas as pd\n",
19
+ "import torch\n",
20
+ "import torch.nn as nn\n",
21
+ "import math\n",
22
+ "#import tensorflow as tf"
23
+ ]
24
+ },
25
+ {
26
+ "cell_type": "code",
27
+ "execution_count": 22,
28
+ "id": "420a4dfdadcdee66",
29
+ "metadata": {
30
+ "ExecuteTime": {
31
+ "end_time": "2025-04-20T10:30:21.755678Z",
32
+ "start_time": "2025-04-20T10:30:21.729677Z"
33
+ },
34
+ "colab": {
35
+ "base_uri": "https://localhost:8080/"
36
+ },
37
+ "id": "420a4dfdadcdee66",
38
+ "outputId": "a0132552-6de3-4c64-c3ab-73cdf858dbc0"
39
+ },
40
+ "outputs": [
41
+ {
42
+ "name": "stdout",
43
+ "output_type": "stream",
44
+ "text": [
45
+ "55955\n",
46
+ "India, officially the Republic of India,[j][21] is a country in South Asia. It is the seventh-larges\n"
47
+ ]
48
+ }
49
+ ],
50
+ "source": [
51
+ "with open(\"C:/Users/adity/Projects_of_Aditya/Working/India, officially the Republic of I.txt\",'r',encoding='utf-8') as f:\n",
52
+ " raw_text=f.read()\n",
53
+ "print(len(raw_text))\n",
54
+ "print(raw_text[:100])"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 23,
60
+ "id": "YJ4KwDtekrSy",
61
+ "metadata": {
62
+ "id": "YJ4KwDtekrSy"
63
+ },
64
+ "outputs": [],
65
+ "source": [
66
+ "train_ratio = 0.9\n",
67
+ "train_size = int(train_ratio * len(raw_text))\n",
68
+ "train_text = raw_text[:train_size]\n",
69
+ "val_text = raw_text[train_size:]"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": 24,
75
+ "id": "ebcdc51c",
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "class BinarizeFunction(torch.autograd.Function):\n",
80
+ " @staticmethod\n",
81
+ " def forward(ctx, input):\n",
82
+ " ctx.save_for_backward(input)\n",
83
+ " return torch.sign(input)\n",
84
+ " @staticmethod\n",
85
+ " def backward(ctx, grad_output):\n",
86
+ " input, = ctx.saved_tensors\n",
87
+ " mask=(input.abs()<=1).float()\n",
88
+ " grad_input = grad_output * mask\n",
89
+ " return grad_input"
90
+ ]
91
+ },
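+ {
+ "cell_type": "markdown",
+ "id": "added_ste_demo_md",
+ "metadata": {},
+ "source": [
+ "A minimal sketch (added for illustration) of what `BinarizeFunction` does: the forward pass returns `sign(input)`, while the backward pass is a straight-through estimator that only lets gradients through where `|input| <= 1`. The tensor values below are arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_ste_demo",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative check of the straight-through estimator (arbitrary values).\n",
+ "w = torch.tensor([-2.0, -0.5, 0.3, 1.5], requires_grad=True)\n",
+ "y = BinarizeFunction.apply(w) # forward: sign -> [-1., -1., 1., 1.]\n",
+ "y.sum().backward() # upstream gradient of ones\n",
+ "print(y)\n",
+ "print(w.grad) # [0., 1., 1., 0.]: gradient blocked where |w| > 1"
+ ]
+ },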
92
+ {
93
+ "cell_type": "code",
94
+ "execution_count": 25,
95
+ "id": "6dd4cfd0",
96
+ "metadata": {},
97
+ "outputs": [],
98
+ "source": [
99
+ "class QuantizedLinear(nn.Module):\n",
100
+ " def __init__(self, in_features, out_features, bias=True):\n",
101
+ " super(QuantizedLinear, self).__init__()\n",
102
+ " self.in_features = in_features\n",
103
+ " self.out_features = out_features\n",
104
+ " self.weight = nn.Parameter(torch.Tensor(out_features, in_features))\n",
105
+ " if bias:\n",
106
+ " self.bias = nn.Parameter(torch.Tensor(out_features))\n",
107
+ " else:\n",
108
+ " self.register_parameter('bias', None)\n",
109
+ " self.reset_parameters()\n",
110
+ "\n",
111
+ " def reset_parameters(self):\n",
112
+ " nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))\n",
113
+ " if self.bias is not None:\n",
114
+ " fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)\n",
115
+ " bound = 1 / math.sqrt(fan_in)\n",
116
+ " nn.init.uniform_(self.bias, -bound, bound)\n",
117
+ " def forward(self, input):\n",
118
+ " weight = BinarizeFunction.apply(self.weight)\n",
119
+ " if self.bias is not None:\n",
120
+ " return torch.nn.functional.linear(input, weight, self.bias)\n",
121
+ " else:\n",
122
+ " return torch.nn.functional.linear(input, weight)\n",
123
+ " def extra_repr(self):\n",
124
+ " return 'in_features={}, out_features={}, bias={}'.format(\n",
125
+ " self.in_features, self.out_features, self.bias is not None\n",
126
+ " )"
127
+ ]
128
+ },
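+ {
+ "cell_type": "markdown",
+ "id": "added_qlinear_demo_md",
+ "metadata": {},
+ "source": [
+ "A quick usage sketch (added): `QuantizedLinear` keeps full-precision latent weights but binarizes them with `BinarizeFunction` on every forward pass, so the layer effectively computes with weights in {-1, +1}. The sizes below are arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_qlinear_demo",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "qlin_demo = QuantizedLinear(8, 4) # arbitrary in/out features\n",
+ "x_demo = torch.randn(2, 8) # (batch, in_features)\n",
+ "print(qlin_demo(x_demo).shape) # torch.Size([2, 4])\n",
+ "print(torch.sign(qlin_demo.weight).unique()) # tensor([-1., 1.]); the latent weights stay full precision"
+ ]
+ },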
129
+ {
130
+ "cell_type": "code",
131
+ "execution_count": 26,
132
+ "id": "dd29070035dafb99",
133
+ "metadata": {
134
+ "ExecuteTime": {
135
+ "end_time": "2025-04-20T10:30:25.020010Z",
136
+ "start_time": "2025-04-20T10:30:24.959908Z"
137
+ },
138
+ "id": "dd29070035dafb99"
139
+ },
140
+ "outputs": [],
141
+ "source": [
142
+ "from torch.utils.data import Dataset, DataLoader\n",
143
+ "import tiktoken\n",
144
+ "\n",
145
+ "class GPTTokenizerDataset(Dataset):\n",
146
+ " def __init__(self, txt, tokenizer, max_length, stride):\n",
147
+ " self.tokenizer = tokenizer\n",
148
+ " self.input_ids = []\n",
149
+ " self.target_ids = []\n",
150
+ " token_ids = self.tokenizer.encode(txt)\n",
151
+ "\n",
152
+ " for i in range(0, len(token_ids) - max_length, stride):\n",
153
+ " input_chunk = token_ids[i:i + max_length]\n",
154
+ " target_chunk = token_ids[i + 1:i + max_length+1]\n",
155
+ " self.input_ids.append(torch.tensor(input_chunk))\n",
156
+ " self.target_ids.append(torch.tensor(target_chunk))\n",
157
+ " def __len__(self):\n",
158
+ " return len(self.input_ids)\n",
159
+ " def __getitem__(self, idx):\n",
160
+ " return self.input_ids[idx], self.target_ids[idx]\n",
161
+ "def create_dataloader_v1(txt, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True):\n",
162
+ " tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
163
+ " dataset = GPTTokenizerDataset(txt, tokenizer, max_length, stride)\n",
164
+ " dataloader = DataLoader(\n",
165
+ " dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last\n",
166
+ " )\n",
167
+ " return dataloader"
168
+ ]
169
+ },
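+ {
+ "cell_type": "markdown",
+ "id": "added_loader_demo_md",
+ "metadata": {},
+ "source": [
+ "A small usage sketch (added) of `create_dataloader_v1`: each batch is an (input, target) pair where the target ids are the input ids shifted one position ahead. The parameters below are illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_loader_demo",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "demo_loader = create_dataloader_v1(raw_text[:2000], batch_size=2, max_length=8, stride=8)\n",
+ "xb, yb = next(iter(demo_loader))\n",
+ "print(xb.shape, yb.shape) # torch.Size([2, 8]) for both\n",
+ "print(torch.equal(xb[0, 1:], yb[0, :-1])) # True: targets are inputs shifted by one token"
+ ]
+ },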
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 27,
173
+ "id": "40a9c2660445b78c",
174
+ "metadata": {
175
+ "ExecuteTime": {
176
+ "end_time": "2025-04-20T10:30:52.552634Z",
177
+ "start_time": "2025-04-20T10:30:52.545337Z"
178
+ },
179
+ "id": "40a9c2660445b78c"
180
+ },
181
+ "outputs": [],
182
+ "source": [
183
+ "def generate_text(model,idx,max_new_tokens,context_size,temperature=0.4,top_k=3):\n",
184
+ " for _ in range(max_new_tokens):\n",
185
+ " idx_cond=idx[:,-context_size:]\n",
186
+ " with torch.no_grad():\n",
187
+ " logits=model(idx_cond)\n",
188
+ " logits=logits[:,-1,:]\n",
189
+ " if top_k is not None:\n",
190
+ " top_logits,_=torch.topk(logits,top_k)\n",
191
+ " min_val=top_logits[:,-1]\n",
192
+ " logits=torch.where(logits<min_val,torch.tensor(float('-inf')).to(logits.device),logits)\n",
193
+ " if temperature>0.0:\n",
194
+ " logits=logits/temperature\n",
195
+ " probs=torch.softmax(logits,dim=-1)\n",
196
+ " idx_next=torch.multinomial(probs,num_samples=1)\n",
197
+ " else:\n",
198
+ " idx_next=torch.argmax(logits,dim=-1,keepdim=True)\n",
199
+ " idx=torch.cat((idx,idx_next),dim=1)\n",
200
+ " return idx"
201
+ ]
202
+ },
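+ {
+ "cell_type": "markdown",
+ "id": "added_topk_demo_md",
+ "metadata": {},
+ "source": [
+ "A worked toy example (added) of the sampling logic inside `generate_text`: keep the k largest logits, mask the rest to -inf, scale by the temperature, and sample from the softmax. The logits are arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_topk_demo",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "demo_logits = torch.tensor([[1.0, 3.0, 0.5, 2.0]]) # (batch=1, vocab=4)\n",
+ "top_logits, _ = torch.topk(demo_logits, 3)\n",
+ "min_val = top_logits[:, -1].unsqueeze(-1) # smallest kept logit, here 1.0\n",
+ "filtered = torch.where(demo_logits < min_val, torch.tensor(float('-inf')), demo_logits)\n",
+ "probs = torch.softmax(filtered / 0.4, dim=-1) # temperature < 1 sharpens the distribution\n",
+ "print(probs) # index 2 is masked out and gets probability 0"
+ ]
+ },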
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 28,
206
+ "id": "22a98021f476cc4d",
207
+ "metadata": {
208
+ "ExecuteTime": {
209
+ "end_time": "2025-04-20T10:30:56.399874Z",
210
+ "start_time": "2025-04-20T10:30:55.660994Z"
211
+ },
212
+ "id": "22a98021f476cc4d"
213
+ },
214
+ "outputs": [],
215
+ "source": [
216
+ "tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
217
+ "def text_to_token_ids(text,tokenizer):\n",
218
+ " encoded=tokenizer.encode(text,allowed_special={'<|endoftext|>'})\n",
219
+ " encoded_tensor=torch.tensor(encoded).unsqueeze(0)\n",
220
+ " return encoded_tensor\n",
221
+ "def token_ids_to_text(token_ids,tokenizer):\n",
222
+ " flat=token_ids.squeeze(0)\n",
223
+ " return tokenizer.decode(flat.tolist())"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "markdown",
228
+ "id": "c34f6594f2501fd3",
229
+ "metadata": {
230
+ "id": "c34f6594f2501fd3"
231
+ },
232
+ "source": [
233
+ "Coding up the Attention model:- Here we would be creating a class of the causal attention and instantiating multiple times for the multihead attention model."
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "markdown",
238
+ "id": "779103be54de3305",
239
+ "metadata": {
240
+ "id": "779103be54de3305"
241
+ },
242
+ "source": [
243
+ "Now for example if we set the number of heads we want is 10, then what exactly happens:-\n",
244
+ "--> we obtain a tensor with ten sets of context vector matrices.\n",
245
+ "--> In each context vector matrix the rows represent the context vectors corresponding to the tokens, and the columns corresponding to the embedding dimension specified via d_out.\n",
246
+ "--> Final embedding dimension is 10 x 10."
247
+ ]
248
+ },
249
+ {
250
+ "cell_type": "markdown",
251
+ "id": "55a1ded1a5143e4b",
252
+ "metadata": {
253
+ "id": "55a1ded1a5143e4b"
254
+ },
255
+ "source": [
256
+ "IMPLEMENTING THE PARALLEL METHOD OF IMPLEMENTATION."
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "code",
261
+ "execution_count": 29,
262
+ "id": "9ffdb4830dd6536c",
263
+ "metadata": {
264
+ "ExecuteTime": {
265
+ "end_time": "2025-04-20T10:31:00.004231Z",
266
+ "start_time": "2025-04-20T10:30:59.989116Z"
267
+ },
268
+ "id": "9ffdb4830dd6536c"
269
+ },
270
+ "outputs": [],
271
+ "source": [
272
+ "class MultiHeadAttention(nn.Module):\n",
273
+ " def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):\n",
274
+ " super().__init__()\n",
275
+ " assert d_out % num_heads == 0, \"d_out must be divisible by num_heads\"\n",
276
+ " self.d_out = d_out\n",
277
+ " self.num_heads = num_heads\n",
278
+ " self.head_dim = d_out // num_heads\n",
279
+ " self.W_query = QuantizedLinear(d_in, d_out, bias=qkv_bias)\n",
280
+ " self.W_key = QuantizedLinear(d_in, d_out, bias=qkv_bias)\n",
281
+ " self.W_value = QuantizedLinear(d_in, d_out, bias=qkv_bias)\n",
282
+ " self.out_proj = QuantizedLinear(d_out, d_out)\n",
283
+ " self.dropout = nn.Dropout(dropout)\n",
284
+ " self.register_buffer(\n",
285
+ " 'mask',\n",
286
+ " torch.triu(torch.ones(context_length, context_length), diagonal=1)\n",
287
+ " )\n",
288
+ " def forward(self, x):\n",
289
+ " b, num_tokens, d_in = x.shape\n",
290
+ " keys = self.W_key(x)\n",
291
+ " queries = self.W_query(x)\n",
292
+ " values = self.W_value(x)\n",
293
+ " keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)\n",
294
+ " values = values.view(b, num_tokens, self.num_heads, self.head_dim)\n",
295
+ " queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)\n",
296
+ " keys = keys.transpose(1, 2)\n",
297
+ " queries = queries.transpose(1, 2)\n",
298
+ " values = values.transpose(1, 2)\n",
299
+ " attn_scores = queries @ keys.transpose(2, 3)\n",
300
+ " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n",
301
+ " attn_scores.masked_fill_(mask_bool, -torch.inf)\n",
302
+ " attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)\n",
303
+ " attn_weights = self.dropout(attn_weights)\n",
304
+ " context_vec = (attn_weights @ values).transpose(1, 2)\n",
305
+ " context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)\n",
306
+ " context_vec = self.out_proj(context_vec)\n",
307
+ " return context_vec"
308
+ ]
309
+ },
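+ {
+ "cell_type": "markdown",
+ "id": "added_mha_check_md",
+ "metadata": {},
+ "source": [
+ "A shape sanity check (added) for `MultiHeadAttention`: `d_out` is split across `num_heads`, each head attends in dimension `d_out // num_heads`, and the heads are concatenated back to `d_out`. The sizes are illustrative."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_mha_check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mha_demo = MultiHeadAttention(d_in=512, d_out=512, context_length=16, dropout=0.0, num_heads=16)\n",
+ "x_demo = torch.randn(2, 16, 512) # (batch, tokens, d_in)\n",
+ "print(mha_demo(x_demo).shape) # torch.Size([2, 16, 512]); per-head dimension is 512 // 16 = 32"
+ ]
+ },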
310
+ {
311
+ "cell_type": "code",
312
+ "execution_count": 30,
313
+ "id": "a361c4d3",
314
+ "metadata": {},
315
+ "outputs": [
316
+ {
317
+ "name": "stdout",
318
+ "output_type": "stream",
319
+ "text": [
320
+ "Vocab size: 100277\n"
321
+ ]
322
+ }
323
+ ],
324
+ "source": [
325
+ "config_tokenizer=tiktoken.get_encoding(\"cl100k_base\")\n",
326
+ "actual_vocab_size=config_tokenizer.n_vocab\n",
327
+ "print(\"Vocab size:\", actual_vocab_size)"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 31,
333
+ "id": "4f7ad555c6c06399",
334
+ "metadata": {
335
+ "ExecuteTime": {
336
+ "end_time": "2025-04-20T10:31:03.321536Z",
337
+ "start_time": "2025-04-20T10:31:03.313914Z"
338
+ },
339
+ "id": "4f7ad555c6c06399"
340
+ },
341
+ "outputs": [],
342
+ "source": [
343
+ "#Defining the parameters\n",
344
+ "GPT_CONFIG={\n",
345
+ " 'vocab_size':actual_vocab_size,\n",
346
+ " 'context_length':256, # Change it to 1024 or greater if you have gpu\n",
347
+ " 'embedding_dim':512,\n",
348
+ " 'num_heads':16,\n",
349
+ " 'n_layers':12,\n",
350
+ " 'dropout':0.1,\n",
351
+ " 'qkv_bias':False #Whether to include a bias layer in the linear layers of the multi head attention for query,key and value computations.\n",
352
+ "}"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "markdown",
357
+ "id": "47e51a02ecec92d5",
358
+ "metadata": {
359
+ "id": "47e51a02ecec92d5"
360
+ },
361
+ "source": [
362
+ "Coding up the placeholder architecture, it is like the mothership from where all the robots will branch out"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": 32,
368
+ "id": "4bb79e5ab1baf62a",
369
+ "metadata": {
370
+ "ExecuteTime": {
371
+ "end_time": "2025-04-20T10:31:06.415202Z",
372
+ "start_time": "2025-04-20T10:31:06.403427Z"
373
+ },
374
+ "id": "4bb79e5ab1baf62a"
375
+ },
376
+ "outputs": [],
377
+ "source": [
378
+ "class GPT_Model(nn.Module):\n",
379
+ " def __init__(self, cfg):\n",
380
+ " #The __init__ constructor of this GPTModel class initializes the token and positional embedding layers using the configurations passed in via a Python dictionary, cfg.\n",
381
+ " super().__init__()\n",
382
+ " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"embedding_dim\"])\n",
383
+ " self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"embedding_dim\"])\n",
384
+ " self.drop_emb = nn.Dropout(cfg[\"dropout\"])\n",
385
+ " self.trf_blocks = nn.Sequential(\n",
386
+ " *[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])]\n",
387
+ " )\n",
388
+ " self.final_norm = LayerNormalization(cfg[\"embedding_dim\"])\n",
389
+ " self.out_head = QuantizedLinear(cfg[\"embedding_dim\"], cfg[\"vocab_size\"], bias=False)\n",
390
+ " def forward(self,in_idx):\n",
391
+ " batch_size,seq_len=in_idx.shape\n",
392
+ " in_idx = torch.clamp(in_idx, 0, self.tok_emb.num_embeddings - 1) #This was initially commented out\n",
393
+ " token_embeddings=self.tok_emb(in_idx)\n",
394
+ " positions = torch.arange(seq_len, device=in_idx.device).unsqueeze(0) #this is the extra added line\n",
395
+ " positional_embeddings=self.pos_emb(positions)\n",
396
+ " x=token_embeddings+positional_embeddings\n",
397
+ " x=self.drop_emb(x)\n",
398
+ " x=self.trf_blocks(x)\n",
399
+ " x=self.final_norm(x)\n",
400
+ " logits=self.out_head(x)\n",
401
+ " return logits"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": 33,
407
+ "id": "72748550",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "class LayerNormalization(nn.Module):\n",
412
+ " def __init__(self, emb_dim):\n",
413
+ " super().__init__()\n",
414
+ " self.eps = 1e-5\n",
415
+ " self.scale = nn.Parameter(torch.ones(emb_dim))\n",
416
+ " self.shift = nn.Parameter(torch.zeros(emb_dim))\n",
417
+ " def forward(self,x):\n",
418
+ " mean= x.mean(-1, keepdim=True)\n",
419
+ " variance = x.var(-1, keepdim=True)\n",
420
+ " norm_x=(x-mean)/(torch.sqrt(variance+self.eps))\n",
421
+ " return self.scale*norm_x + self.shift"
422
+ ]
423
+ },
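+ {
+ "cell_type": "markdown",
+ "id": "added_layernorm_check_md",
+ "metadata": {},
+ "source": [
+ "A quick agreement check (added): with the biased variance, this hand-rolled `LayerNormalization` should match PyTorch's built-in `nn.LayerNorm` up to numerical precision. The shapes are arbitrary."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_layernorm_check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ln_demo = LayerNormalization(512)\n",
+ "ln_ref = nn.LayerNorm(512, eps=1e-5)\n",
+ "x_demo = torch.randn(2, 4, 512)\n",
+ "print(torch.allclose(ln_demo(x_demo), ln_ref(x_demo), atol=1e-5))"
+ ]
+ },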
424
+ {
425
+ "cell_type": "code",
426
+ "execution_count": 34,
427
+ "id": "b81d6de9cdc325eb",
428
+ "metadata": {
429
+ "ExecuteTime": {
430
+ "end_time": "2025-04-20T10:31:09.094024Z",
431
+ "start_time": "2025-04-20T10:31:09.082533Z"
432
+ },
433
+ "id": "b81d6de9cdc325eb"
434
+ },
435
+ "outputs": [],
436
+ "source": [
437
+ "class TransformerBlock(nn.Module):\n",
438
+ " def __init__(self,config):\n",
439
+ " super().__init__()\n",
440
+ " self.att=MultiHeadAttention(\n",
441
+ " d_in=config[\"embedding_dim\"],\n",
442
+ " d_out=config[\"embedding_dim\"],\n",
443
+ " context_length=config['context_length'],\n",
444
+ " dropout=config['dropout'],\n",
445
+ " num_heads=config['num_heads'],\n",
446
+ " qkv_bias=config['qkv_bias']\n",
447
+ " )\n",
448
+ " self.ff=FeedForward(config)\n",
449
+ " self.norm1=LayerNormalization(config[\"embedding_dim\"])\n",
450
+ " self.norm2=LayerNormalization(config[\"embedding_dim\"])\n",
451
+ " self.drop_resid=nn.Dropout(config['dropout'])\n",
452
+ " def forward(self,x):\n",
453
+ " shortcut=x\n",
454
+ " x=self.norm1(x)\n",
455
+ " x=self.att(x)\n",
456
+ " x=self.drop_resid(x)\n",
457
+ " x=x+shortcut\n",
458
+ " shortcut=x\n",
459
+ " x=self.norm2(x)\n",
460
+ " x=self.ff(x)\n",
461
+ " x=self.drop_resid(x)\n",
462
+ " x=x+shortcut\n",
463
+ " return x"
464
+ ]
465
+ },
466
+ {
467
+ "cell_type": "markdown",
468
+ "id": "ee7086fdb0d258aa",
469
+ "metadata": {
470
+ "id": "ee7086fdb0d258aa"
471
+ },
472
+ "source": [
473
+ "We will use swish activation function."
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 35,
479
+ "id": "aafae17704f79949",
480
+ "metadata": {
481
+ "ExecuteTime": {
482
+ "end_time": "2025-04-20T10:31:14.198107Z",
483
+ "start_time": "2025-04-20T10:31:14.183061Z"
484
+ },
485
+ "id": "aafae17704f79949"
486
+ },
487
+ "outputs": [],
488
+ "source": [
489
+ "class Swish(nn.Module):\n",
490
+ " def __init__(self):\n",
491
+ " super(Swish, self).__init__()\n",
492
+ " def forward(self, x):\n",
493
+ " return x * torch.sigmoid(x)"
494
+ ]
495
+ },
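+ {
+ "cell_type": "markdown",
+ "id": "added_swish_check_md",
+ "metadata": {},
+ "source": [
+ "Note (added): Swish(x) = x * sigmoid(x) is the same function that PyTorch ships as `nn.SiLU`; a one-line equivalence check:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_swish_check",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x_demo = torch.linspace(-3, 3, 7)\n",
+ "print(torch.allclose(Swish()(x_demo), nn.SiLU()(x_demo))) # True"
+ ]
+ },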
496
+ {
497
+ "cell_type": "code",
498
+ "execution_count": 36,
499
+ "id": "4b3a9eeaf0282a32",
500
+ "metadata": {
501
+ "ExecuteTime": {
502
+ "end_time": "2025-04-20T10:31:16.572707Z",
503
+ "start_time": "2025-04-20T10:31:16.567278Z"
504
+ },
505
+ "id": "4b3a9eeaf0282a32"
506
+ },
507
+ "outputs": [],
508
+ "source": [
509
+ "class FeedForward(nn.Module):\n",
510
+ " def __init__(self, config):\n",
511
+ " super().__init__()\n",
512
+ " self.layers=nn.Sequential(\n",
513
+ " nn.Linear(config[\"embedding_dim\"], 4*config[\"embedding_dim\"]),\n",
514
+ " Swish(),\n",
515
+ " nn.Linear(4*config[\"embedding_dim\"], config[\"embedding_dim\"]),\n",
516
+ " )\n",
517
+ " def forward(self, x):\n",
518
+ " return self.layers(x)"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": 37,
524
+ "id": "3888c877e7bb59fa",
525
+ "metadata": {
526
+ "ExecuteTime": {
527
+ "end_time": "2025-04-20T10:31:37.956131Z",
528
+ "start_time": "2025-04-20T10:31:37.943199Z"
529
+ },
530
+ "id": "3888c877e7bb59fa"
531
+ },
532
+ "outputs": [],
533
+ "source": [
534
+ "class DeepNeuralNetwork(nn.Module):\n",
535
+ " def __init__(self, layer_sizes,use_shortcut):\n",
536
+ " super().__init__()\n",
537
+ " self.layers=nn.ModuleList([\n",
538
+ " #We would be implementing 10 layers\n",
539
+ " nn.Sequential(nn.Linear(layer_sizes[0], layer_sizes[1])),\n",
540
+ " nn.Sequential(nn.Linear(layer_sizes[1], layer_sizes[2])),\n",
541
+ " nn.Sequential(nn.Linear(layer_sizes[2], layer_sizes[3])),\n",
542
+ " nn.Sequential(nn.Linear(layer_sizes[3], layer_sizes[4])),\n",
543
+ " nn.Sequential(nn.Linear(layer_sizes[4], layer_sizes[5])),\n",
544
+ " nn.Sequential(nn.Linear(layer_sizes[5], layer_sizes[6])),\n",
545
+ " nn.Sequential(nn.Linear(layer_sizes[6], layer_sizes[7])),\n",
546
+ " nn.Sequential(nn.Linear(layer_sizes[7], layer_sizes[8])),\n",
547
+ " nn.Sequential(nn.Linear(layer_sizes[8], layer_sizes[9])),\n",
548
+ " nn.Sequential(nn.Linear(layer_sizes[9], layer_sizes[10])),\n",
549
+ " ])\n",
550
+ " def forward(self,x):\n",
551
+ " for layer in self.layers:\n",
552
+ " #Computing the output of the current layer\n",
553
+ " layer_output=layer(x)\n",
554
+ " #Check if shortcut can be applied\n",
555
+ " if self.use_shortcut and x.shape==layer_output.shape:\n",
556
+ " x=x+layer_output\n",
557
+ " else:\n",
558
+ " x=layer_output\n",
559
+ " return x\n",
560
+ "def print_gradients(model,x):\n",
561
+ " #First would be the forward pass\n",
562
+ " output = model(x)\n",
563
+ " target=torch.tensor([0,])\n",
564
+ " #Loss calculation\n",
565
+ " loss=nn.MSELoss()\n",
566
+ " loss=loss(output,target)\n",
567
+ " loss.backward()\n",
568
+ " for name, param in model.named_parameters():\n",
569
+ " if 'weight' in name:\n",
570
+ " print(f\"{name} grad: {param.grad}\")"
571
+ ]
572
+ },
573
+ {
574
+ "cell_type": "markdown",
575
+ "id": "78ab409a0177825",
576
+ "metadata": {
577
+ "id": "78ab409a0177825"
578
+ },
579
+ "source": [
580
+ "Now let us initialise"
581
+ ]
582
+ },
583
+ {
584
+ "cell_type": "code",
585
+ "execution_count": 38,
586
+ "id": "6710dda1f52d8b41",
587
+ "metadata": {
588
+ "ExecuteTime": {
589
+ "end_time": "2025-04-20T10:31:41.037621Z",
590
+ "start_time": "2025-04-20T10:31:40.974254Z"
591
+ },
592
+ "colab": {
593
+ "base_uri": "https://localhost:8080/"
594
+ },
595
+ "id": "6710dda1f52d8b41",
596
+ "outputId": "c2753e89-89dc-4c5b-c086-53132aded738"
597
+ },
598
+ "outputs": [
599
+ {
600
+ "name": "stdout",
601
+ "output_type": "stream",
602
+ "text": [
603
+ "tensor([[36, 24, 61, 0, 41, 81, 18, 26, 93, 88],\n",
604
+ " [26, 96, 17, 74, 20, 82, 52, 43, 96, 70]])\n"
605
+ ]
606
+ }
607
+ ],
608
+ "source": [
609
+ "batch_size = 2 # Number of samples in the batch\n",
610
+ "sequence_length = 10 # Length of each sequence\n",
611
+ "vocab_size = 100 # Size of the vocabulary\n",
612
+ "batch = torch.randint(0, vocab_size, (batch_size, sequence_length))\n",
613
+ "print(batch)"
614
+ ]
615
+ },
616
+ {
617
+ "cell_type": "code",
618
+ "execution_count": 39,
619
+ "id": "b376992b9eb9a68c",
620
+ "metadata": {
621
+ "ExecuteTime": {
622
+ "end_time": "2025-04-20T10:31:44.349704Z",
623
+ "start_time": "2025-04-20T10:31:43.391715Z"
624
+ },
625
+ "colab": {
626
+ "base_uri": "https://localhost:8080/"
627
+ },
628
+ "id": "b376992b9eb9a68c",
629
+ "outputId": "f67dc607-f218-4c20-848d-47212f38b749"
630
+ },
631
+ "outputs": [
632
+ {
633
+ "name": "stdout",
634
+ "output_type": "stream",
635
+ "text": [
636
+ "Input batch:\n",
637
+ " tensor([[36, 24, 61, 0, 41, 81, 18, 26, 93, 88],\n",
638
+ " [26, 96, 17, 74, 20, 82, 52, 43, 96, 70]])\n",
639
+ "Output batch:\n",
640
+ " torch.Size([2, 10, 100277])\n",
641
+ "tensor([[[ 1.6182e+01, -1.6015e+01, -9.4095e+00, ..., 3.0794e-03,\n",
642
+ " 2.9054e+01, 1.6988e+01],\n",
643
+ " [ 5.2240e+00, 2.7572e+01, -6.9735e+00, ..., -8.0013e+00,\n",
644
+ " -4.0101e-01, 2.8758e+01],\n",
645
+ " [ 6.6475e+00, -1.1150e+01, 7.9781e+00, ..., -2.5136e+01,\n",
646
+ " 7.3388e+00, 9.9231e+00],\n",
647
+ " ...,\n",
648
+ " [-4.3846e+00, -1.7154e+01, 1.0174e+01, ..., -4.6591e+00,\n",
649
+ " -8.3947e+00, 1.1043e+01],\n",
650
+ " [ 3.5968e+01, -2.7967e+00, -2.8498e+01, ..., -2.2024e+00,\n",
651
+ " -1.1003e+01, -2.4883e-02],\n",
652
+ " [ 1.9451e+01, -3.6966e+01, 7.5978e+00, ..., 9.3602e+00,\n",
653
+ " 8.6090e+00, -2.6628e+00]],\n",
654
+ "\n",
655
+ " [[-2.8687e+01, 1.6627e+01, -1.4998e+01, ..., -1.7184e+01,\n",
656
+ " 2.0726e+01, 8.0321e+00],\n",
657
+ " [-4.0979e+01, 6.5536e-01, 4.1383e+00, ..., -1.2853e+01,\n",
658
+ " -1.7279e+01, -1.3240e+01],\n",
659
+ " [-1.9607e+01, 2.3471e+00, 7.2976e+00, ..., 4.8977e-01,\n",
660
+ " -1.7134e+01, 3.4321e+00],\n",
661
+ " ...,\n",
662
+ " [-1.1025e+01, -2.4218e+00, 2.6663e+01, ..., 1.4770e+00,\n",
663
+ " -4.0925e+01, 5.0661e-01],\n",
664
+ " [-3.4426e+01, -2.2701e+00, 2.6099e+01, ..., -1.2846e+01,\n",
665
+ " -2.4183e+01, -4.9127e+01],\n",
666
+ " [ 1.6595e+00, -1.6062e+00, 1.8436e+01, ..., 3.3674e+01,\n",
667
+ " -3.5222e+01, -2.4692e+01]]], grad_fn=<UnsafeViewBackward0>)\n"
668
+ ]
669
+ }
670
+ ],
671
+ "source": [
672
+ "torch.manual_seed(123)\n",
673
+ "model=GPT_Model(GPT_CONFIG)\n",
674
+ "out=model(batch)\n",
675
+ "print(\"Input batch:\\n\",batch)\n",
676
+ "print(\"Output batch:\\n\",out.shape)\n",
677
+ "print(out)"
678
+ ]
679
+ },
680
+ {
681
+ "cell_type": "markdown",
682
+ "id": "32204ab3e2917ca1",
683
+ "metadata": {
684
+ "id": "32204ab3e2917ca1"
685
+ },
686
+ "source": [
687
+ "Displaying the number of parameters for the GPT model"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "execution_count": 40,
693
+ "id": "bfd0d944c222bfbf",
694
+ "metadata": {
695
+ "ExecuteTime": {
696
+ "end_time": "2025-04-20T10:31:49.707504Z",
697
+ "start_time": "2025-04-20T10:31:49.699751Z"
698
+ },
699
+ "colab": {
700
+ "base_uri": "https://localhost:8080/"
701
+ },
702
+ "id": "bfd0d944c222bfbf",
703
+ "outputId": "bbad64e6-f379-475e-80d5-6d3fe5e79824"
704
+ },
705
+ "outputs": [
706
+ {
707
+ "name": "stdout",
708
+ "output_type": "stream",
709
+ "text": [
710
+ "Total number of parameters: 140625920\n",
711
+ "Token embedding layer shape: torch.Size([100277, 512])\n",
712
+ "Output layer shape: torch.Size([100277, 512])\n"
713
+ ]
714
+ }
715
+ ],
716
+ "source": [
717
+ "total_parameters=sum(p.numel() for p in model.parameters())\n",
718
+ "print(f\"Total number of parameters: {total_parameters}\")\n",
719
+ "print(\"Token embedding layer shape:\", model.tok_emb.weight.shape)\n",
720
+ "print(\"Output layer shape:\", model.out_head.weight.shape)"
721
+ ]
722
+ },
723
+ {
724
+ "cell_type": "markdown",
725
+ "id": "c2b39710a7897efb",
726
+ "metadata": {
727
+ "id": "c2b39710a7897efb"
728
+ },
729
+ "source": [
730
+ "Number of trainable parameters in the model"
731
+ ]
732
+ },
733
+ {
734
+ "cell_type": "code",
735
+ "execution_count": 41,
736
+ "id": "e047e3f5d5b4e540",
737
+ "metadata": {
738
+ "ExecuteTime": {
739
+ "end_time": "2025-04-20T10:31:53.034490Z",
740
+ "start_time": "2025-04-20T10:31:53.027104Z"
741
+ },
742
+ "colab": {
743
+ "base_uri": "https://localhost:8080/"
744
+ },
745
+ "id": "e047e3f5d5b4e540",
746
+ "outputId": "b1793806-df53-4cf2-a09d-52e8485bb35f"
747
+ },
748
+ "outputs": [
749
+ {
750
+ "name": "stdout",
751
+ "output_type": "stream",
752
+ "text": [
753
+ "Number of trainable parameters considering weight tying: 89284096\n"
754
+ ]
755
+ }
756
+ ],
757
+ "source": [
758
+ "total_params_gpt2 = total_parameters - sum(p.numel() for p in model.out_head.parameters())\n",
759
+ "print(f\"Number of trainable parameters considering weight tying: {total_params_gpt2}\")"
760
+ ]
761
+ },
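+ {
+ "cell_type": "markdown",
+ "id": "added_tying_arithmetic_md",
+ "metadata": {},
+ "source": [
+ "Checking the arithmetic (added): the output head is a 100277 x 512 matrix, i.e. 51,341,824 parameters, and 140,625,920 - 51,341,824 = 89,284,096, which matches the printed count. With weight tying, this matrix would share storage with the token-embedding matrix of the same shape."
+ ]
+ },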
762
+ {
763
+ "cell_type": "code",
764
+ "execution_count": 42,
765
+ "id": "f611c62fb559142f",
766
+ "metadata": {
767
+ "ExecuteTime": {
768
+ "end_time": "2025-04-20T10:31:57.287950Z",
769
+ "start_time": "2025-04-20T10:31:57.279346Z"
770
+ },
771
+ "colab": {
772
+ "base_uri": "https://localhost:8080/"
773
+ },
774
+ "id": "f611c62fb559142f",
775
+ "outputId": "24b7ef8b-df10-40a3-b192-46a8d32cf3e3"
776
+ },
777
+ "outputs": [
778
+ {
779
+ "name": "stdout",
780
+ "output_type": "stream",
781
+ "text": [
782
+ "Total size of the model : 536.45 MB\n"
783
+ ]
784
+ }
785
+ ],
786
+ "source": [
787
+ "total_size_in_bytes=total_parameters*4\n",
788
+ "\n",
789
+ "total_size_of_the_model_in_MB=total_size_in_bytes/(1024*1024)\n",
790
+ "print(f\"Total size of the model : {total_size_of_the_model_in_MB:.2f} MB\")"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "markdown",
795
+ "id": "645fa9c01a21b0e3",
796
+ "metadata": {
797
+ "id": "645fa9c01a21b0e3"
798
+ },
799
+ "source": [
800
+ "Total size of the model : 341.55 MB\n",
801
+ "Number of trainable parameters considering weight tying: 63935488\n"
802
+ ]
803
+ },
804
+ {
805
+ "cell_type": "markdown",
806
+ "id": "e32325eb6463fa21",
807
+ "metadata": {
808
+ "id": "e32325eb6463fa21"
809
+ },
810
+ "source": [
811
+ "The next step is to now decode these tensors to proper text. Which would be coding up in the subsequent steps"
812
+ ]
813
+ },
814
+ {
815
+ "cell_type": "code",
816
+ "execution_count": 43,
817
+ "id": "af8f873de4b1ea1f",
818
+ "metadata": {
819
+ "ExecuteTime": {
820
+ "end_time": "2025-04-20T10:36:18.521800Z",
821
+ "start_time": "2025-04-20T10:36:18.507080Z"
822
+ },
823
+ "colab": {
824
+ "base_uri": "https://localhost:8080/"
825
+ },
826
+ "id": "af8f873de4b1ea1f",
827
+ "outputId": "8761b2e0-af06-4027-fc7b-b09c306d69cf"
828
+ },
829
+ "outputs": [
830
+ {
831
+ "name": "stdout",
832
+ "output_type": "stream",
833
+ "text": [
834
+ "[9906, 11, 358, 1097, 2467, 488, 64, 13]\n"
835
+ ]
836
+ }
837
+ ],
838
+ "source": [
839
+ "#Let us try out the decoding procedure\n",
840
+ "start_context=\"Hello, I am Aditya.\"\n",
841
+ "tokenizer = tiktoken.get_encoding(\"cl100k_base\")\n",
842
+ "encoded=tokenizer.encode(start_context)\n",
843
+ "print(encoded)"
844
+ ]
845
+ },
846
+ {
847
+ "cell_type": "code",
848
+ "execution_count": 44,
849
+ "id": "baf2d02c627a5911",
850
+ "metadata": {
851
+ "ExecuteTime": {
852
+ "end_time": "2025-04-20T10:32:31.432690Z",
853
+ "start_time": "2025-04-20T10:32:31.416839Z"
854
+ },
855
+ "colab": {
856
+ "base_uri": "https://localhost:8080/"
857
+ },
858
+ "id": "baf2d02c627a5911",
859
+ "outputId": "b6a59155-048a-49e4-c1b5-683dbbad8f0a"
860
+ },
861
+ "outputs": [
862
+ {
863
+ "data": {
864
+ "text/plain": [
865
+ "GPT_Model(\n",
866
+ " (tok_emb): Embedding(100277, 512)\n",
867
+ " (pos_emb): Embedding(256, 512)\n",
868
+ " (drop_emb): Dropout(p=0.1, inplace=False)\n",
869
+ " (trf_blocks): Sequential(\n",
870
+ " (0): TransformerBlock(\n",
871
+ " (att): MultiHeadAttention(\n",
872
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
873
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
874
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
875
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
876
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
877
+ " )\n",
878
+ " (ff): FeedForward(\n",
879
+ " (layers): Sequential(\n",
880
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
881
+ " (1): Swish()\n",
882
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
883
+ " )\n",
884
+ " )\n",
885
+ " (norm1): LayerNormalization()\n",
886
+ " (norm2): LayerNormalization()\n",
887
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
888
+ " )\n",
889
+ " (1): TransformerBlock(\n",
890
+ " (att): MultiHeadAttention(\n",
891
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
892
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
893
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
894
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
895
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
896
+ " )\n",
897
+ " (ff): FeedForward(\n",
898
+ " (layers): Sequential(\n",
899
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
900
+ " (1): Swish()\n",
901
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
902
+ " )\n",
903
+ " )\n",
904
+ " (norm1): LayerNormalization()\n",
905
+ " (norm2): LayerNormalization()\n",
906
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
907
+ " )\n",
908
+ " (2): TransformerBlock(\n",
909
+ " (att): MultiHeadAttention(\n",
910
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
911
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
912
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
913
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
914
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
915
+ " )\n",
916
+ " (ff): FeedForward(\n",
917
+ " (layers): Sequential(\n",
918
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
919
+ " (1): Swish()\n",
920
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
921
+ " )\n",
922
+ " )\n",
923
+ " (norm1): LayerNormalization()\n",
924
+ " (norm2): LayerNormalization()\n",
925
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
926
+ " )\n",
927
+ " (3): TransformerBlock(\n",
928
+ " (att): MultiHeadAttention(\n",
929
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
930
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
931
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
932
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
933
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
934
+ " )\n",
935
+ " (ff): FeedForward(\n",
936
+ " (layers): Sequential(\n",
937
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
938
+ " (1): Swish()\n",
939
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
940
+ " )\n",
941
+ " )\n",
942
+ " (norm1): LayerNormalization()\n",
943
+ " (norm2): LayerNormalization()\n",
944
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
945
+ " )\n",
946
+ " (4): TransformerBlock(\n",
947
+ " (att): MultiHeadAttention(\n",
948
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
949
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
950
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
951
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
952
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
953
+ " )\n",
954
+ " (ff): FeedForward(\n",
955
+ " (layers): Sequential(\n",
956
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
957
+ " (1): Swish()\n",
958
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
959
+ " )\n",
960
+ " )\n",
961
+ " (norm1): LayerNormalization()\n",
962
+ " (norm2): LayerNormalization()\n",
963
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
964
+ " )\n",
965
+ " (5): TransformerBlock(\n",
966
+ " (att): MultiHeadAttention(\n",
967
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
968
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
969
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
970
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
971
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
972
+ " )\n",
973
+ " (ff): FeedForward(\n",
974
+ " (layers): Sequential(\n",
975
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
976
+ " (1): Swish()\n",
977
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
978
+ " )\n",
979
+ " )\n",
980
+ " (norm1): LayerNormalization()\n",
981
+ " (norm2): LayerNormalization()\n",
982
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
983
+ " )\n",
984
+ " (6): TransformerBlock(\n",
985
+ " (att): MultiHeadAttention(\n",
986
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
987
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
988
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
989
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
990
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
991
+ " )\n",
992
+ " (ff): FeedForward(\n",
993
+ " (layers): Sequential(\n",
994
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
995
+ " (1): Swish()\n",
996
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
997
+ " )\n",
998
+ " )\n",
999
+ " (norm1): LayerNormalization()\n",
1000
+ " (norm2): LayerNormalization()\n",
1001
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
1002
+ " )\n",
1003
+ " (7): TransformerBlock(\n",
1004
+ " (att): MultiHeadAttention(\n",
1005
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1006
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1007
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1008
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
1009
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1010
+ " )\n",
1011
+ " (ff): FeedForward(\n",
1012
+ " (layers): Sequential(\n",
1013
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
1014
+ " (1): Swish()\n",
1015
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
1016
+ " )\n",
1017
+ " )\n",
1018
+ " (norm1): LayerNormalization()\n",
1019
+ " (norm2): LayerNormalization()\n",
1020
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
1021
+ " )\n",
1022
+ " (8): TransformerBlock(\n",
1023
+ " (att): MultiHeadAttention(\n",
1024
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1025
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1026
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1027
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
1028
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1029
+ " )\n",
1030
+ " (ff): FeedForward(\n",
1031
+ " (layers): Sequential(\n",
1032
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
1033
+ " (1): Swish()\n",
1034
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
1035
+ " )\n",
1036
+ " )\n",
1037
+ " (norm1): LayerNormalization()\n",
1038
+ " (norm2): LayerNormalization()\n",
1039
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
1040
+ " )\n",
1041
+ " (9): TransformerBlock(\n",
1042
+ " (att): MultiHeadAttention(\n",
1043
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1044
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1045
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1046
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
1047
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1048
+ " )\n",
1049
+ " (ff): FeedForward(\n",
1050
+ " (layers): Sequential(\n",
1051
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
1052
+ " (1): Swish()\n",
1053
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
1054
+ " )\n",
1055
+ " )\n",
1056
+ " (norm1): LayerNormalization()\n",
1057
+ " (norm2): LayerNormalization()\n",
1058
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
1059
+ " )\n",
1060
+ " (10): TransformerBlock(\n",
1061
+ " (att): MultiHeadAttention(\n",
1062
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1063
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1064
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1065
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
1066
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1067
+ " )\n",
1068
+ " (ff): FeedForward(\n",
1069
+ " (layers): Sequential(\n",
1070
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
1071
+ " (1): Swish()\n",
1072
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
1073
+ " )\n",
1074
+ " )\n",
1075
+ " (norm1): LayerNormalization()\n",
1076
+ " (norm2): LayerNormalization()\n",
1077
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
1078
+ " )\n",
1079
+ " (11): TransformerBlock(\n",
1080
+ " (att): MultiHeadAttention(\n",
1081
+ " (W_query): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1082
+ " (W_key): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1083
+ " (W_value): QuantizedLinear(in_features=512, out_features=512, bias=False)\n",
1084
+ " (out_proj): QuantizedLinear(in_features=512, out_features=512, bias=True)\n",
1085
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1086
+ " )\n",
1087
+ " (ff): FeedForward(\n",
1088
+ " (layers): Sequential(\n",
1089
+ " (0): Linear(in_features=512, out_features=2048, bias=True)\n",
1090
+ " (1): Swish()\n",
1091
+ " (2): Linear(in_features=2048, out_features=512, bias=True)\n",
1092
+ " )\n",
1093
+ " )\n",
1094
+ " (norm1): LayerNormalization()\n",
1095
+ " (norm2): LayerNormalization()\n",
1096
+ " (drop_resid): Dropout(p=0.1, inplace=False)\n",
1097
+ " )\n",
1098
+ " )\n",
1099
+ " (final_norm): LayerNormalization()\n",
1100
+ " (out_head): QuantizedLinear(in_features=512, out_features=100277, bias=False)\n",
1101
+ ")"
1102
+ ]
1103
+ },
1104
+ "execution_count": 44,
1105
+ "metadata": {},
1106
+ "output_type": "execute_result"
1107
+ }
1108
+ ],
1109
+ "source": [
1110
+ "model.eval()"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "execution_count": 45,
1116
+ "id": "8e6a5e5afc3272d6",
1117
+ "metadata": {
1118
+ "ExecuteTime": {
1119
+ "end_time": "2025-04-20T10:36:21.766425Z",
1120
+ "start_time": "2025-04-20T10:36:21.340642Z"
1121
+ },
1122
+ "colab": {
1123
+ "base_uri": "https://localhost:8080/"
1124
+ },
1125
+ "id": "8e6a5e5afc3272d6",
1126
+ "outputId": "4b2dcdff-161f-47c8-cca4-e84a9e117e2f"
1127
+ },
1128
+ "outputs": [
1129
+ {
1130
+ "name": "stdout",
1131
+ "output_type": "stream",
1132
+ "text": [
1133
+ "Output:\n",
1134
+ " tensor([[ 9906, 11, 358, 1097, 2467, 488, 64, 13, 48400, 85624,\n",
1135
+ " 1993, 61732, 73414, 87133]])\n"
1136
+ ]
1137
+ }
1138
+ ],
1139
+ "source": [
1140
+ "model.eval()\n",
1141
+ "out=generate_text(model=model,idx=torch.tensor(encoded).unsqueeze(0),max_new_tokens=6,context_size=GPT_CONFIG[\"context_length\"])\n",
1142
+ "print(\"Output:\\n\",out)"
1143
+ ]
1144
+ },
1145
+ {
1146
+ "cell_type": "code",
1147
+ "execution_count": 46,
1148
+ "id": "1ffca81eb2e208dd",
1149
+ "metadata": {
1150
+ "ExecuteTime": {
1151
+ "end_time": "2025-04-20T10:36:31.970156Z",
1152
+ "start_time": "2025-04-20T10:36:30.980631Z"
1153
+ },
1154
+ "colab": {
1155
+ "base_uri": "https://localhost:8080/"
1156
+ },
1157
+ "id": "1ffca81eb2e208dd",
1158
+ "outputId": "5d1b6fe6-0368-46c9-ead1-7cc1a3174322"
1159
+ },
1160
+ "outputs": [
1161
+ {
1162
+ "name": "stdout",
1163
+ "output_type": "stream",
1164
+ "text": [
1165
+ "Output text:\n",
1166
+ " Hello, I am Aditya I want to become a CEO one day of my own company steadily;/*\tmodel collateral字符 Lois Middletonarios_DECL loophole\n"
1167
+ ]
1168
+ }
1169
+ ],
1170
+ "source": [
1171
+ "start_context=\"Hello, I am Aditya I want to become a CEO one day of my own company\"\n",
1172
+ "token_ids=generate_text(model=model,idx=text_to_token_ids(start_context,tokenizer),max_new_tokens=10,context_size=GPT_CONFIG[\"context_length\"])\n",
1173
+ "print(\"Output text:\\n\",token_ids_to_text(token_ids,tokenizer))"
1174
+ ]
1175
+ },
1176
+ {
1177
+ "cell_type": "code",
1178
+ "execution_count": 47,
1179
+ "id": "yxZH4QzR-ydZ",
1180
+ "metadata": {
1181
+ "colab": {
1182
+ "base_uri": "https://localhost:8080/"
1183
+ },
1184
+ "id": "yxZH4QzR-ydZ",
1185
+ "outputId": "d46883fa-15f6-44e9-d69f-797a3af7a8c4"
1186
+ },
1187
+ "outputs": [
1188
+ {
1189
+ "name": "stdout",
1190
+ "output_type": "stream",
1191
+ "text": [
1192
+ "torch.Size([1, 14, 100277])\n"
1193
+ ]
1194
+ }
1195
+ ],
1196
+ "source": [
1197
+ "inputs=torch.tensor([[ 9906, 11, 358, 1097, 2467, 488, 64, 13, 41867, 40540,\n",
1198
+ " 15145, 30876, 46468, 30001]]) # Remove extra comma and parenthesis to make it a tensor\n",
1199
+ "with torch.no_grad():\n",
1200
+ " logits=model(inputs)\n",
1201
+ "probas=torch.softmax(logits,dim=-1)\n",
1202
+ "print(probas.shape)"
1203
+ ]
1204
+ },
1205
+ {
1206
+ "cell_type": "code",
1207
+ "execution_count": 48,
1208
+ "id": "MTItfymWGhRZ",
1209
+ "metadata": {
1210
+ "id": "MTItfymWGhRZ"
1211
+ },
1212
+ "outputs": [],
1213
+ "source": [
1214
+ "torch.manual_seed(123)\n",
1215
+ "train_loader=create_dataloader_v1(train_text,batch_size=4,max_length=GPT_CONFIG[\"context_length\"],\n",
1216
+ " stride=GPT_CONFIG['context_length'],\n",
1217
+ " drop_last=True,\n",
1218
+ " shuffle=True\n",
1219
+ " )\n",
1220
+ "val_loader=create_dataloader_v1(val_text,batch_size=4,max_length=GPT_CONFIG[\"context_length\"],\n",
1221
+ " stride=GPT_CONFIG['context_length'],\n",
1222
+ " drop_last=True,\n",
1223
+ " shuffle=True\n",
1224
+ " )"
1225
+ ]
1226
+ },
1227
+ {
1228
+ "cell_type": "code",
1229
+ "execution_count": 49,
1230
+ "id": "e853b287",
1231
+ "metadata": {},
1232
+ "outputs": [
1233
+ {
1234
+ "name": "stdout",
1235
+ "output_type": "stream",
1236
+ "text": [
1237
+ "Train loader:\n",
1238
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1239
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1240
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1241
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1242
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1243
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1244
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1245
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1246
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1247
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1248
+ "torch.Size([4, 256]) torch.Size([4, 256])\n",
1249
+ "\n",
1250
+ " Validation Loader:\n",
1251
+ "torch.Size([4, 256]) torch.Size([4, 256])\n"
1252
+ ]
1253
+ }
1254
+ ],
1255
+ "source": [
1256
+ "print(\"Train loader:\")\n",
1257
+ "for x,y in train_loader:\n",
1258
+ " print(x.shape,y.shape)\n",
1259
+ "print(\"\\n Validation Loader:\")\n",
1260
+ "for x,y in val_loader:\n",
1261
+ " print(x.shape,y.shape)\n",
1262
+ "# The output implies that the model has 18 training set batches with 2 samples and 256 tokens each"
1263
+ ]
1264
+ },
1265
+ {
1266
+ "cell_type": "code",
1267
+ "execution_count": 50,
1268
+ "id": "Df2uwuFnmOp3",
1269
+ "metadata": {
1270
+ "id": "Df2uwuFnmOp3"
1271
+ },
1272
+ "outputs": [],
1273
+ "source": [
1274
+ "def calculation_of_loss(input_batch,target_batch,model,device):\n",
1275
+ " input_batch,target_batch=input_batch.to(device),target_batch.to(device)\n",
1276
+ " logits=model(input_batch)\n",
1277
+ " loss=torch.nn.functional.cross_entropy(logits.flatten(0,1),target_batch.flatten())\n",
1278
+ " return loss"
1279
+ ]
1280
+ },
1281
+ {
1282
+ "cell_type": "code",
1283
+ "execution_count": 51,
1284
+ "id": "hdoiK6MLcrYV",
1285
+ "metadata": {
1286
+ "id": "hdoiK6MLcrYV"
1287
+ },
1288
+ "outputs": [],
1289
+ "source": [
1290
+ "def loss_loader(data_loader, model, device, num_batches=4):\n",
1291
+ " total_loss = 0 \n",
1292
+ " for i, (input_batch, target_batch) in enumerate(data_loader):\n",
1293
+ " if i < num_batches:\n",
1294
+ " loss = calculation_of_loss(input_batch, target_batch, model, device)\n",
1295
+ " total_loss += loss.item()\n",
1296
+ " else:\n",
1297
+ " break\n",
1298
+ " return total_loss / num_batches"
1299
+ ]
1300
+ },
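+ {
+ "cell_type": "markdown",
+ "id": "added_perplexity_md",
+ "metadata": {},
+ "source": [
+ "An added aside: `loss_loader` returns an average cross-entropy in nats, so exponentiating it gives perplexity, a more interpretable gauge of the language model. A minimal sketch:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_perplexity_sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "demo_loss = loss_loader(val_loader, model, device='cpu', num_batches=1)\n",
+ "print(f\"val loss {demo_loss:.4f} -> perplexity {math.exp(demo_loss):.1f}\") # math was imported in the first cell"
+ ]
+ },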
1301
+ {
1302
+ "cell_type": "code",
1303
+ "execution_count": 52,
1304
+ "id": "x89QUR65ePEs",
1305
+ "metadata": {
1306
+ "colab": {
1307
+ "base_uri": "https://localhost:8080/",
1308
+ "height": 383
1309
+ },
1310
+ "id": "x89QUR65ePEs",
1311
+ "outputId": "7b4bc307-b3fb-45b7-d067-724481f7bbce"
1312
+ },
1313
+ "outputs": [
1314
+ {
1315
+ "name": "stdout",
1316
+ "output_type": "stream",
1317
+ "text": [
1318
+ "Train loss: 98.4413\n",
1319
+ "Validation loss: 24.3542\n"
1320
+ ]
1321
+ }
1322
+ ],
1323
+ "source": [
1324
+ "device='cpu'\n",
1325
+ "model.to(device)\n",
1326
+ "train_loss = loss_loader(train_loader, model, device='cpu',num_batches=4)\n",
1327
+ "val_loss=loss_loader(val_loader,model,device='cpu',num_batches=4)\n",
1328
+ "print(f\"Train loss: {train_loss:.4f}\")\n",
1329
+ "print(f\"Validation loss: {val_loss:.4f}\")"
1330
+ ]
1331
+ },
1332
+ {
1333
+ "cell_type": "code",
1334
+ "execution_count": 53,
1335
+ "id": "4aa447fc",
1336
+ "metadata": {},
1337
+ "outputs": [
1338
+ {
1339
+ "name": "stdout",
1340
+ "output_type": "stream",
1341
+ "text": [
1342
+ "11\n",
1343
+ "1\n"
1344
+ ]
1345
+ }
1346
+ ],
1347
+ "source": [
1348
+ "print(len(train_loader))\n",
1349
+ "print(len(val_loader))"
1350
+ ]
1351
+ },
1352
+ {
1353
+ "cell_type": "code",
1354
+ "execution_count": 54,
1355
+ "id": "a0020a0e",
1356
+ "metadata": {},
1357
+ "outputs": [],
1358
+ "source": [
1359
+ "def train_the_model(model,train_loader,val_loader,epochs=1,learning_rate=3e-4):\n",
1360
+ " optimizer=torch.optim.AdamW(model.parameters(),lr=learning_rate)\n",
1361
+ " for epoch in range(epochs):\n",
1362
+ " model.train()\n",
1363
+ " for i,(input_batch,target_batch) in enumerate(train_loader):\n",
1364
+ " input_batch,target_batch=input_batch.to(device),target_batch.to(device)\n",
1365
+ " optimizer.zero_grad()\n",
1366
+ " logits=model(input_batch)\n",
1367
+ " loss=torch.nn.functional.cross_entropy(logits.flatten(0,1),target_batch.flatten())\n",
1368
+ " loss.backward()\n",
1369
+ " optimizer.step()\n",
1370
+ " if i%100==0:\n",
1371
+ " print(f\"Epoch {epoch+1}/{epochs}, Batch {i}/{len(train_loader)}, Loss: {loss.item():.4f}\")\n",
1372
+ " model.eval()\n",
1373
+ " train_loss = loss_loader(train_loader, model, device='cpu',num_batches=4)\n",
1374
+ " val_loss = loss_loader(val_loader, model, device='cpu',num_batches=4)\n",
1375
+ " print(f\"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}\")\n",
1376
+ " return train_loss, val_loss"
1377
+ ]
1378
+ },
1379
+ {
1380
+ "cell_type": "code",
1381
+ "execution_count": 55,
1382
+ "id": "b8407429",
1383
+ "metadata": {},
1384
+ "outputs": [],
1385
+ "source": [
1386
+ "def evaluate_model(model,train_loader, val_loader, device='cpu', num_batches=4):\n",
1387
+ " model.eval()\n",
1388
+ " with torch.no_grad():\n",
1389
+ " train_loss = loss_loader(train_loader, model, device=device, num_batches=num_batches)\n",
1390
+ " val_loss = loss_loader(val_loader, model, device=device, num_batches=num_batches)\n",
1391
+ " model.train()\n",
1392
+ " print(f\"Train Loss: {train_loss:.4f}\")\n",
1393
+ " print(f\"Validation Loss: {val_loss:.4f}\")\n",
1394
+ " return train_loss, val_loss"
1395
+ ]
1396
+ },
1397
+ {
1398
+ "cell_type": "code",
1399
+ "execution_count": 56,
1400
+ "id": "96d3965f",
1401
+ "metadata": {},
1402
+ "outputs": [
1403
+ {
1404
+ "name": "stdout",
1405
+ "output_type": "stream",
1406
+ "text": [
1407
+ "Epoch 1/10, Batch 0/11, Loss: 98.6930\n",
1408
+ "Epoch 1/10, Train Loss: 94.4102, Validation Loss: 23.4683\n"
1409
+ ]
1410
+ }
1411
+ ],
1412
+ "source": [
1413
+ "torch.manual_seed(123)\n",
1414
+ "model=GPT_Model(GPT_CONFIG)\n",
1415
+ "model.to(device)\n",
1416
+ "train_loss, val_loss = train_the_model(model, train_loader, val_loader, epochs=10, learning_rate=3e-4)"
1417
+ ]
1418
+ },
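+ {
+ "cell_type": "markdown",
+ "id": "added_checkpoint_md",
+ "metadata": {},
+ "source": [
+ "A small addition worth considering after training (sketch, not part of the original loop): saving a checkpoint so the weights can be reloaded without retraining. The file name 'gpt_model.pth' is hypothetical."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "added_checkpoint_sketch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Hypothetical checkpoint round-trip; the file name is arbitrary.\n",
+ "torch.save(model.state_dict(), 'gpt_model.pth')\n",
+ "reloaded = GPT_Model(GPT_CONFIG)\n",
+ "reloaded.load_state_dict(torch.load('gpt_model.pth'))\n",
+ "reloaded.eval()"
+ ]
+ },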
1419
+ {
1420
+ "cell_type": "code",
1421
+ "execution_count": 57,
1422
+ "id": "fac91e1d",
1423
+ "metadata": {},
1424
+ "outputs": [
1425
+ {
1426
+ "name": "stdout",
1427
+ "output_type": "stream",
1428
+ "text": [
1429
+ "Output text:\n",
1430
+ " Hi\traise pitched že beh Difference_rg Commons licens\tsh taped LSUesco microseconds haberhandleRequest\n",
1431
+ "Output text:\n",
1432
+ " Can you talk in english-authored Alert 값을 together Arlington Pert DatePicker CitProductName/mswonerrassouth995 considerably\n",
1433
+ "Output text:\n",
1434
+ " Yup little bit less chinese\tll amongst Companies_Details_Details_Details_Details(diistribute sampano PUasingbowerazzo\n"
1435
+ ]
1436
+ },
1437
+ {
1438
+ "ename": "RuntimeError",
1439
+ "evalue": "Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)",
1440
+ "output_type": "error",
1441
+ "traceback": [
1442
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
1443
+ "\u001b[31mRuntimeError\u001b[39m Traceback (most recent call last)",
1444
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[57]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;28;01mTrue\u001b[39;00m:\n\u001b[32m 2\u001b[39m start_context=\u001b[38;5;28minput\u001b[39m()\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m token_ids=\u001b[43mgenerate_text\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m=\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43midx\u001b[49m\u001b[43m=\u001b[49m\u001b[43mtext_to_token_ids\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstart_context\u001b[49m\u001b[43m,\u001b[49m\u001b[43mtokenizer\u001b[49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43mmax_new_tokens\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m15\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mcontext_size\u001b[49m\u001b[43m=\u001b[49m\u001b[43mGPT_CONFIG\u001b[49m\u001b[43m[\u001b[49m\u001b[33;43m\"\u001b[39;49m\u001b[33;43mcontext_length\u001b[39;49m\u001b[33;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m,\u001b[49m\u001b[43mtemperature\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m0.4\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43mtop_k\u001b[49m\u001b[43m=\u001b[49m\u001b[32;43m3\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[32m 4\u001b[39m \u001b[38;5;28mprint\u001b[39m(\u001b[33m\"\u001b[39m\u001b[33mOutput text:\u001b[39m\u001b[38;5;130;01m\\n\u001b[39;00m\u001b[33m\"\u001b[39m,token_ids_to_text(token_ids,tokenizer))\n",
1445
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[27]\u001b[39m\u001b[32m, line 5\u001b[39m, in \u001b[36mgenerate_text\u001b[39m\u001b[34m(model, idx, max_new_tokens, context_size, temperature, top_k)\u001b[39m\n\u001b[32m 3\u001b[39m idx_cond=idx[:,-context_size:]\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mwith\u001b[39;00m torch.no_grad():\n\u001b[32m----> \u001b[39m\u001b[32m5\u001b[39m logits=\u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[43midx_cond\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 6\u001b[39m logits=logits[:,-\u001b[32m1\u001b[39m,:]\n\u001b[32m 7\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m top_k \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n",
1446
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Roaming\\Python\\Python313\\site-packages\\torch\\nn\\modules\\module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1749\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Roaming\\Python\\Python313\\site-packages\\torch\\nn\\modules\\module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1760\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1761\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[32]\u001b[39m\u001b[32m, line 16\u001b[39m, in \u001b[36mGPT_Model.forward\u001b[39m\u001b[34m(self, in_idx)\u001b[39m\n\u001b[32m 14\u001b[39m batch_size,seq_len=in_idx.shape\n\u001b[32m 15\u001b[39m in_idx = torch.clamp(in_idx, \u001b[32m0\u001b[39m, \u001b[38;5;28mself\u001b[39m.tok_emb.num_embeddings - \u001b[32m1\u001b[39m) \u001b[38;5;66;03m#This was initially commented out\u001b[39;00m\n\u001b[32m---> \u001b[39m\u001b[32m16\u001b[39m token_embeddings=\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mtok_emb\u001b[49m\u001b[43m(\u001b[49m\u001b[43min_idx\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 17\u001b[39m positions = torch.arange(seq_len, device=in_idx.device).unsqueeze(\u001b[32m0\u001b[39m) \u001b[38;5;66;03m#this is the extra added line\u001b[39;00m\n\u001b[32m 18\u001b[39m positional_embeddings=\u001b[38;5;28mself\u001b[39m.pos_emb(positions)\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Roaming\\Python\\Python313\\site-packages\\torch\\nn\\modules\\module.py:1751\u001b[39m, in \u001b[36mModule._wrapped_call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1749\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m._compiled_call_impl(*args, **kwargs) \u001b[38;5;66;03m# type: ignore[misc]\u001b[39;00m\n\u001b[32m 1750\u001b[39m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[32m-> \u001b[39m\u001b[32m1751\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43m_call_impl\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Roaming\\Python\\Python313\\site-packages\\torch\\nn\\modules\\module.py:1762\u001b[39m, in \u001b[36mModule._call_impl\u001b[39m\u001b[34m(self, *args, **kwargs)\u001b[39m\n\u001b[32m 1757\u001b[39m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[32m 1758\u001b[39m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[32m 1759\u001b[39m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m._backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m._forward_pre_hooks\n\u001b[32m 1760\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_backward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[32m 1761\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[32m-> \u001b[39m\u001b[32m1762\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[43m*\u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m*\u001b[49m\u001b[43m*\u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[32m 1764\u001b[39m result = \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[32m 1765\u001b[39m called_always_called_hooks = \u001b[38;5;28mset\u001b[39m()\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Roaming\\Python\\Python313\\site-packages\\torch\\nn\\modules\\sparse.py:190\u001b[39m, in \u001b[36mEmbedding.forward\u001b[39m\u001b[34m(self, input)\u001b[39m\n\u001b[32m 189\u001b[39m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34mforward\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;28minput\u001b[39m: Tensor) -> Tensor:\n\u001b[32m--> \u001b[39m\u001b[32m190\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mF\u001b[49m\u001b[43m.\u001b[49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\n\u001b[32m 191\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[32m 192\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 193\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 194\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mmax_norm\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 195\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mnorm_type\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 196\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 197\u001b[39m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m.\u001b[49m\u001b[43msparse\u001b[49m\u001b[43m,\u001b[49m\n\u001b[32m 198\u001b[39m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[36mFile \u001b[39m\u001b[32m~\\AppData\\Roaming\\Python\\Python313\\site-packages\\torch\\nn\\functional.py:2551\u001b[39m, in \u001b[36membedding\u001b[39m\u001b[34m(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)\u001b[39m\n\u001b[32m 2545\u001b[39m \u001b[38;5;66;03m# Note [embedding_renorm set_grad_enabled]\u001b[39;00m\n\u001b[32m 2546\u001b[39m \u001b[38;5;66;03m# XXX: equivalent to\u001b[39;00m\n\u001b[32m 2547\u001b[39m \u001b[38;5;66;03m# with torch.no_grad():\u001b[39;00m\n\u001b[32m 2548\u001b[39m \u001b[38;5;66;03m# torch.embedding_renorm_\u001b[39;00m\n\u001b[32m 2549\u001b[39m \u001b[38;5;66;03m# remove once script supports set_grad_enabled\u001b[39;00m\n\u001b[32m 2550\u001b[39m _no_grad_embedding_renorm_(weight, \u001b[38;5;28minput\u001b[39m, max_norm, norm_type)\n\u001b[32m-> \u001b[39m\u001b[32m2551\u001b[39m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mtorch\u001b[49m\u001b[43m.\u001b[49m\u001b[43membedding\u001b[49m\u001b[43m(\u001b[49m\u001b[43mweight\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mpadding_idx\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mscale_grad_by_freq\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msparse\u001b[49m\u001b[43m)\u001b[49m\n",
+ "\u001b[31mRuntimeError\u001b[39m: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead (while checking arguments for embedding)"
+ ]
+ }
+ ],
+ "source": [
+ "while True:\n",
+ "    start_context=input()\n",
+ "    token_ids=generate_text(model=model,idx=text_to_token_ids(start_context,tokenizer),max_new_tokens=15,context_size=GPT_CONFIG[\"context_length\"],temperature=0.4,top_k=3)\n",
+ "    print(\"Output text:\\n\",token_ids_to_text(token_ids,tokenizer))"
+ ]
+ },
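+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "a1b2c3d4",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A minimal sketch of one likely fix for the RuntimeError above, not code from\n",
+ "# the original notebook: nn.Embedding accepts only Long/Int indices, so cast\n",
+ "# the token ids to torch.long before they reach the model. It assumes a\n",
+ "# tiktoken-style tokenizer whose encode() returns a list of ints; the helper\n",
+ "# name text_to_token_ids_long is hypothetical.\n",
+ "import torch\n",
+ "\n",
+ "def text_to_token_ids_long(text, tokenizer):\n",
+ "    encoded = tokenizer.encode(text)  # list of token ids (ints)\n",
+ "    return torch.tensor(encoded, dtype=torch.long).unsqueeze(0)  # shape (1, seq_len)\n",
+ "\n",
+ "# token_ids = generate_text(model=model, idx=text_to_token_ids_long(start_context, tokenizer),\n",
+ "#                           max_new_tokens=15, context_size=GPT_CONFIG[\"context_length\"],\n",
+ "#                           temperature=0.4, top_k=3)"
+ ]
+ },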
+ {
+ "cell_type": "code",
+ "execution_count": 61,
+ "id": "19ea61ce",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#NOTE: creating a fresh AdamW here saves an empty optimizer state; reuse the\n",
+ "#optimizer from the training loop if you want its state in the checkpoint\n",
+ "optimizer=torch.optim.AdamW(model.parameters(),lr=3e-4)\n",
+ "torch.save({\"model weights and biases\":model.state_dict(),\n",
+ "            \"optimizer_weights\":optimizer.state_dict(),},\n",
+ "           \"GPT_model.pth\")"
+ ]
+ },
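+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e5f6a7b8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# A hedged sketch of resuming training from the checkpoint saved above; it\n",
+ "# assumes only the dict keys used in torch.save: \"model weights and biases\"\n",
+ "# and \"optimizer_weights\". Restoring the optimizer state brings back AdamW's\n",
+ "# running moment estimates, so training continues where it stopped.\n",
+ "#checkpoint = torch.load(\"GPT_model.pth\")\n",
+ "#model.load_state_dict(checkpoint[\"model weights and biases\"])\n",
+ "#optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)\n",
+ "#optimizer.load_state_dict(checkpoint[\"optimizer_weights\"])\n",
+ "#model.train()"
+ ]
+ },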
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "d88da3c8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "GPT_Model(\n",
+ "  (tok_emb): Embedding(100277, 512)\n",
+ "  (pos_emb): Embedding(256, 512)\n",
+ "  (drop_emb): Dropout(p=0.1, inplace=False)\n",
+ "  (trf_blocks): Sequential(\n",
+ "    (0): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (1): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (2): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (3): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (4): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (5): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (6): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (7): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (8): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (9): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (10): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "    (11): TransformerBlock(\n",
+ "      (att): MultiHeadAttention(\n",
+ "        (W_query): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_key): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (W_value): Linear(in_features=512, out_features=512, bias=False)\n",
+ "        (out_proj): Linear(in_features=512, out_features=512, bias=True)\n",
+ "        (dropout): Dropout(p=0.1, inplace=False)\n",
+ "      )\n",
+ "      (ff): FeedForward(\n",
+ "        (layers): Sequential(\n",
+ "          (0): Linear(in_features=512, out_features=2048, bias=True)\n",
+ "          (1): Swish()\n",
+ "          (2): Linear(in_features=2048, out_features=512, bias=True)\n",
+ "        )\n",
+ "      )\n",
+ "      (norm1): LayerNormalization()\n",
+ "      (norm2): LayerNormalization()\n",
+ "      (drop_resid): Dropout(p=0.1, inplace=False)\n",
+ "    )\n",
+ "  )\n",
+ "  (final_norm): LayerNormalization()\n",
+ "  (out_head): Linear(in_features=512, out_features=100277, bias=False)\n",
+ ")"
+ ]
+ },
+ "execution_count": 44,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
1731
+ "#Load the weights using the following code\n",
1732
+ "#model = GPT_Model(GPT_CONFIG)\n",
1733
+ "#model.load_state_dict(torch.load(\"GPT_model.pth\"))\n",
1734
+ "#model.eval()"
1735
+ ]
1736
+ }
1737
+ ],
+ "metadata": {
+ "colab": {
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+ }