GenAI - Lab-File - Darab Khan 22SCSE1480055
Submitted to: Shiksha Singh
Submitted by: Aryan Kumar (22SCSE1180135)
INDEX
Project-1
Description: To implement a CNN for classification on the COVID / Non-COVID CT image dataset from Kaggle.
Code:
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
In [ ]:
In [ ]:
import os
import cv2
import shutil
import tensorflow as tf
from glob import glob
# Helper libraries
import matplotlib.pyplot as plt
import math
%matplotlib inline
print(tf.__version__)
2.17.0
In [ ]:
data_root='/content/Data'
path_positive_cases = os.path.join('/content/Data/Covid')
path_negative_cases = os.path.join('/content/Data/NoCovid')
In [ ]:
positive_images_ls = glob(os.path.join(path_positive_cases,"*.png"))
negative_images_ls = glob(os.path.join(path_negative_cases,"*.png"))
negative_images_ls.extend(glob(os.path.join(path_negative_cases,"*.jpg")))
In [ ]:
In [ ]:
total_positive_covid = len(positive_images_ls)
total_negative_covid = len(negative_images_ls)
print("Total Positive Cases Covid19 images: {}".format(total_positive_covid))
print("Total Negative Cases Covid19 images: {}".format(total_negative_covid))
image_positive = cv2.imread(positive_images_ls[1])
image_negative = cv2.imread(negative_images_ls[5])
f = plt.figure(figsize=(8, 8))
f.add_subplot(1, 2, 1)
plt.imshow(image_negative)
f.add_subplot(1,2, 2)
plt.imshow(image_positive)
Out[ ]:
<matplotlib.image.AxesImage at 0x79568114c250>
In [ ]:
In [ ]:
print(cases['class'], num_to_select)
CT_COVID 125
CT_NonCOVID 122
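The cells that build the per-class cases records and select the held-out test files are not included in this export. The sketch below is one way they could look, assuming a random per-class sample is copied into test/; the names covid_cases, noncovid_cases and test_ratio are illustrative and not from the original notebook. It defines the image_test_files list used by the copy loop in the next cell.
import random

test_ratio = 0.25  # assumed split ratio; the notebook only shows the resulting counts (125 / 122)
covid_cases = {'class': 'CT_COVID', 'images': positive_images_ls}
noncovid_cases = {'class': 'CT_NonCOVID', 'images': negative_images_ls}

image_test_files = []
for cases in [covid_cases, noncovid_cases]:
    os.makedirs('train/' + cases['class'], exist_ok=True)
    os.makedirs('test/' + cases['class'], exist_ok=True)
    num_to_select = int(len(cases['images']) * test_ratio)
    for images in random.sample(cases['images'], num_to_select):
        image_test_files.append(images.split('/')[-1])  # file names to exclude from the training copy
        shutil.copy2(images, 'test/' + cases['class'])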
In [ ]:
for images in cases['images']:
    if images.split('/')[-1] not in image_test_files:  # exclude test files from shutil.copy
        shutil.copy2(images, 'train/' + cases['class'])
In [ ]:
total_train_covid = len(os.listdir('/content/train/CT_COVID'))
total_train_noncovid = len(os.listdir('/content/train/CT_NonCOVID'))
total_test_covid = len(os.listdir('/content/test/CT_COVID'))
total_test_noncovid = len(os.listdir('/content/test/CT_NonCOVID'))
batch_size = 128
epochs = 15
IMG_HEIGHT = 150
IMG_WIDTH = 150
In [ ]:
In [ ]:
train_dir = os.path.join('/content/train')
test_dir = os.path.join('/content/test')
In [ ]:
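The generator-definition cells above are empty in this export; a minimal sketch follows, assuming plain pixel rescaling with no augmentation, which is all the flow_from_directory calls below require:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

train_image_generator = ImageDataGenerator(rescale=1./255)  # rescale pixel values to [0, 1]
test_image_generator = ImageDataGenerator(rescale=1./255)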
test_data_gen = test_image_generator.flow_from_directory(batch_size=batch_size,
                                                          directory=test_dir,
                                                          target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                          class_mode='binary')
In [ ]:
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                            directory=train_dir,
                                                            shuffle=True,
                                                            target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                            class_mode='binary')
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

model = Sequential([
    Conv2D(16, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH, 3)),
    MaxPooling2D(),
    Conv2D(32, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Conv2D(64, 3, padding='same', activation='relu'),
    MaxPooling2D(),
    Flatten(),
    Dense(512, activation='relu'),
    Dense(1)
])
/usr/local/lib/python3.10/dist-packages/keras/src/layers/convolutional/base_conv.py:107: UserWarning: Do
not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an
`Input(shape)` object as the first layer in the model instead.
super().__init__(activity_regularizer=activity_regularizer, **kwargs)
In [ ]:
model.compile(optimizer='adam',
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
In [ ]:
model.summary()
Model: "sequential"

Layer (type)                     Output Shape            Param #
conv2d (Conv2D)                  (None, 150, 150, 16)    448
max_pooling2d (MaxPooling2D)     (None, 75, 75, 16)      0
conv2d_1 (Conv2D)                (None, 75, 75, 32)      4,640
max_pooling2d_1 (MaxPooling2D)   (None, 37, 37, 32)      0
conv2d_2 (Conv2D)                (None, 37, 37, 64)      18,496
max_pooling2d_2 (MaxPooling2D)   (None, 18, 18, 64)      0
flatten (Flatten)                (None, 20736)           0
dense (Dense)                    (None, 512)              10,617,344
dense_1 (Dense)                  (None, 1)                513

Total params: 10,641,441 (40.59 MB)
Trainable params: 10,641,441 (40.59 MB)
Non-trainable params: 0 (0.00 B)
In [ ]:
Collecting keras.preprocessing
Downloading Keras_Preprocessing-1.1.2-py2.py3-none-any.whl.metadata (1.9 kB)
Requirement already satisfied: numpy>=1.9.1 in /usr/local/lib/python3.10/dist-packages (from
keras.preprocessing) (1.26.4)
Requirement already satisfied: six>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from
keras.preprocessing) (1.16.0)
Downloading Keras_Preprocessing-1.1.2-py2.py3-none-any.whl (42 kB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
42.6/42.6 kB 2.7 MB/s eta 0:00:00
Installing collected packages: keras.preprocessing
Successfully installed keras.preprocessing-1.1.2
In [ ]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator  # correct import path for ImageDataGenerator

def data_generator(generator):
    try:
        for data_batch, labels_batch in generator:
            yield data_batch, labels_batch
    except Exception as e:
        print(f"Error loading image: {e}")

# Combined image counts used to derive the number of steps per epoch
total_train = total_train_covid + total_train_noncovid
total_test = total_test_covid + total_test_noncovid

history = model.fit(
    data_generator(train_data_gen),  # wrap the generator with error handling
    steps_per_epoch=total_train // batch_size,
    epochs=epochs,
    validation_data=data_generator(test_data_gen),  # wrap the validation generator as well
    validation_steps=total_test // batch_size
)
Epoch 1/15
8/17 ━━━━━━━━━━━━━━━━━━━━ 9s 1s/step - accuracy: 0.4848 - loss: 2.5313
Error loading image: Truncated File Read
/usr/lib/python3.10/contextlib.py:153: UserWarning: Your input ran out of data; interrupting training. Make
sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to
use the `.repeat()` function when building your dataset.
self.gen.throw(typ, value, traceback)
17/17 ━━━━━━━━━━━━━━━━━━━━ 32s 721ms/step - accuracy: 0.4898 - loss: 2.1836 - val_accuracy: 0.4922 - val_loss: 0.6919
Epoch 2/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 16s 913ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5210 - val_loss: 0.6932
Epoch 3/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 0s 29ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5391 - val_loss: 0.6946
Epoch 4/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 48ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4706 - val_loss: 0.6903
Epoch 5/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 56ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5000 - val_loss: 0.6921
Epoch 6/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 46ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5126 - val_loss: 0.6929
Epoch 7/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 54ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5469 - val_loss: 0.6952
Epoch 8/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 45ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4622 - val_loss: 0.6896
Epoch 9/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 53ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5234 - val_loss: 0.6937
Epoch 10/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 48ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4874 - val_loss: 0.6913
Epoch 11/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 56ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5078 - val_loss: 0.6926
Epoch 12/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 49ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5042 - val_loss: 0.6924
Epoch 13/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 56ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4609 - val_loss: 0.6895
Epoch 14/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 50ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.5546 - val_loss: 0.6958
Epoch 15/15
17/17 ━━━━━━━━━━━━━━━━━━━━ 1s 57ms/step - accuracy: 0.0000e+00 - loss: 0.0000e+00 - val_accuracy: 0.4844 - val_loss: 0.6911
In [ ]:
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss=history.history['loss']
val_loss=history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
Project-2
Description: To implement the Transformer architecture for LLMs.
Code:
In [1]:
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib)
(0.12.1)
Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib)
(4.54.1)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib)
(1.4.7)
Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib)
(24.1)
Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib)
(10.4.0)
Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib)
(3.2.0)
Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from
matplotlib) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas) (2024.2)
Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas)
(2024.2)
Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7-
>matplotlib) (1.16.0)
Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2-
>torch) (3.0.1)
Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy-
>torch) (1.3.0)
Downloading tiktoken-0.8.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.2 MB)
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2
MB 14.5 MB/s eta 0:00:00
Installing collected packages: tiktoken
Successfully installed tiktoken-0.8.0
In [2]:
import os
import requests
import pandas as pd
import matplotlib.pyplot as plt
import math
import tiktoken
import torch
import torch.nn as nn
In [3]:
# Hyperparameters
batch_size = 4 # How many batches per training step
context_length = 16 # Length of the token chunk each batch
d_model = 64 # The vector size of the token embeddings
num_layers = 8 # Number of transformer blocks
num_heads = 4 # Number of heads in Multi-head attention
learning_rate = 1e-3 # 0.001
dropout = 0.1 # Dropout rate
max_iters = 5000 # Total of training iterations
eval_interval = 50 # How often to evaluate the model
eval_iters = 20 # How many iterations to average the loss over when evaluating the model
device = 'cuda' if torch.cuda.is_available() else 'cpu' # Instead of using the CPU, we'll use the GPU if it's available.
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)
Out[3]:
<torch._C.Generator at 0x7a8d42e65ad0>
In [4]:
text = f.read()
In [5]:
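The tokenization cell is not reproduced in this export; a minimal sketch follows, assuming the tiktoken BPE tokenizer imported above and a 90/10 train/validation split (the names tokenized_text, max_token_value, train_data and val_data are illustrative):
# Encode the raw text into token ids and split into train/validation tensors
encoding = tiktoken.get_encoding("cl100k_base")
tokenized_text = torch.tensor(encoding.encode(text), dtype=torch.long, device=device)
max_token_value = tokenized_text.max().item() + 1  # upper bound on token ids, used to size the embedding table

split_idx = int(len(tokenized_text) * 0.9)  # assumed 90/10 split
train_data = tokenized_text[:split_idx]
val_data = tokenized_text[split_idx:]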
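The batch-sampling and embedding-table cells are also missing; under the same assumptions, each batch is batch_size random windows of context_length tokens, with the targets shifted one position to the right:
# Sample a random batch of (input, target) token windows from the training data
idxs = torch.randint(low=0, high=len(train_data) - context_length - 1, size=(batch_size,))
x_batch = torch.stack([train_data[i:i + context_length] for i in idxs]).to(device)
y_batch = torch.stack([train_data[i + 1:i + context_length + 1] for i in idxs]).to(device)

# Token-embedding lookup table: one d_model-dimensional vector per token id
token_embedding_lookup_table = nn.Embedding(max_token_value, d_model).to(device)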
# Get X and Y embedding
x = token_embedding_lookup_table(x_batch.data)
y = token_embedding_lookup_table(y_batch.data)
In [9]:
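The position-encoding cell is not shown; below is a sketch of the standard sinusoidal encoding from "Attention Is All You Need", added to the token embeddings to produce the input_embedding_x used in the following cells (input_embedding_y is treated the same way):
# Pre-compute the sinusoidal position-encoding table of shape [context_length, d_model]
position_encoding_lookup_table = torch.zeros(context_length, d_model, device=device)
position = torch.arange(0, context_length, dtype=torch.float, device=device).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float, device=device) * (-math.log(10000.0) / d_model))
position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)

# Add the (broadcast) position encoding to the token embeddings
input_embedding_x = x + position_encoding_lookup_table
input_embedding_y = y + position_encoding_lookup_table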
In [10]:
X = input_embedding_x
x_plot = input_embedding_x[0].detach().cpu().numpy()
print("Final Input Embedding of x: \n", pd.DataFrame(x_plot))
7 8 9 ... 54 55 56 57 \
0 1.460907 0.438821 0.568639 ... 1.452847 1.467532 1.984371 0.791459
1 0.610698 -0.504885 1.616139 ... 0.448665 -2.071404 0.672218 2.796822
2 -1.933866 -1.671411 1.150713 ... -0.686295 -0.096573 -2.178207 0.451573
3 -0.635150 0.029269 1.307228 ... -0.607196 2.235353 -0.330132 0.043395
4 -0.318029 -0.214993 0.665824 ... -0.545532 -0.345191 0.142483 1.507527
5 -1.104829 0.616293 0.783596 ... 1.080138 -0.161667 -1.531240 1.557224
6 0.084035 1.446481 0.074499 ... 0.716758 1.740306 -1.434396 0.254191
7 -1.715202 -0.558720 -1.187362 ... 0.746224 0.008394 -0.101539 -0.490031
8 -3.191279 -0.665482 -1.169012 ... -0.079880 3.122454 -2.058005 1.065758
9 0.723305 -1.070696 -1.178378 ... -0.280317 1.782857 -1.054796 1.835647
10 -1.642119 0.801727 -1.764521 ... -0.340848 3.639359 -0.644349 3.063182
11 0.825862 -0.736930 0.177177 ... -0.280090 1.006504 1.357095 1.059691
12 1.224846 -0.545266 -1.121737 ... -1.307938 2.227912 -1.178179 1.575556
13 -1.562514 0.118980 0.325631 ... 0.397413 0.479963 0.217271 2.055336
14 1.099753 -1.613777 -1.243054 ... 0.359625 1.258582 1.067245 1.663585
15 0.268369 -0.826254 -0.871297 ... 1.625291 0.643504 -1.032252 1.950458
58 59 60 61 62 63
0 0.456751 0.789545 -2.045897 1.099113 0.166035 3.139895
1 0.437217 2.284778 -1.708361 -0.152342 0.699953 0.343699
2 -0.500310 1.891739 0.073773 0.102590 -0.138560 0.840906
3 1.379010 1.228062 -0.214476 0.090425 -0.173942 -0.830836
4 0.027854 -0.008453 -1.213767 -0.094705 0.886935 1.103025
5 -0.348979 1.005130 -0.339587 0.136550 -0.597760 1.012034
6 -0.240991 0.542596 0.039506 2.375268 0.122078 1.598531
7 -0.253113 1.570766 -0.642553 0.589649 -0.934843 -1.100760
8 0.787592 2.288829 0.152384 1.582828 -0.278190 1.500612
9 -2.389836 2.359888 -0.895393 2.478489 1.150637 2.515316
10 1.745668 1.311434 -0.703046 0.981743 0.178330 0.443150
11 1.175118 1.527424 -0.844068 1.202420 -0.927549 0.908416
12 0.974123 1.321630 -0.581522 -1.142979 -0.997342 2.858498
13 0.841140 0.250887 1.338835 1.003394 1.427358 1.072308
14 0.892081 1.761903 0.728198 1.667248 -1.340948 0.224300
15 0.948557 2.855112 -0.750874 2.738219 0.865055 2.720319
# Define Query, Key, Value weight matrices
Wq = nn.Linear(d_model, d_model)
Wk = nn.Linear(d_model, d_model)
Wv = nn.Linear(d_model, d_model)
In [13]:
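The cells that project the embeddings and split them into heads are not shown; a sketch of the usual multi-head reshape follows, yielding the Q, K, V of shape [batch_size, num_heads, context_length, d_model // num_heads] used in the attention-score cell below:
# Project the input embeddings into query/key/value space
Q = Wq(X)  # [batch_size, context_length, d_model]
K = Wk(X)
V = Wv(X)

# Split the last dimension into num_heads heads and move the head axis forward
head_dim = d_model // num_heads
Q = Q.view(batch_size, context_length, num_heads, head_dim).transpose(1, 2)
K = K.view(batch_size, context_length, num_heads, head_dim).transpose(1, 2)
V = V.view(batch_size, context_length, num_heads, head_dim).transpose(1, 2)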
In [15]:
attention_score = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_model // num_heads)  # [batch_size, num_heads, context_length, context_length] = [4, 4, 16, 16]
print(pd.DataFrame(attention_score[0][0].detach().cpu().numpy()))
0 1 2 3 4 5 6 \
0 0.021498 0.759700 1.114897 0.861203 0.996177 0.087225 0.020397
1 -0.186354 0.423521 0.930679 0.397063 0.689583 -0.023029 -1.295206
2 -0.116987 -0.521399 -0.102028 -0.048670 -0.094187 0.211903 0.785083
3 0.677378 0.094782 0.187630 1.087047 0.387743 0.641712 1.120173
4 0.631352 -0.280073 -0.932875 -0.581271 -0.206257 0.127698 -0.102794
5 -0.327146 0.068143 0.148191 -0.391794 0.132223 0.017728 0.268415
6 -0.444339 0.258554 0.455449 -0.077464 0.134794 -0.053327 0.101352
7 -0.245799 0.580802 0.932249 0.135226 0.357541 -0.006203 0.136920
8 -0.303093 0.018890 0.015307 0.294602 0.369652 -0.219553 -0.192504
9 -0.136668 0.443592 0.470723 -0.184153 0.221327 -0.058326 -0.335395
10 0.694354 0.946938 0.422884 0.789917 0.990797 -0.172340 0.704143
11 0.352159 0.977470 1.040248 1.317057 1.359986 0.120361 0.799884
12 -0.238808 0.104332 0.052890 -0.593303 -0.794684 -0.462218 0.311895
13 -0.517919 0.067904 0.636224 0.144879 0.018605 0.213238 0.280446
14 -0.085574 -0.059928 0.476232 -0.008881 -0.143431 0.318674 0.256556
15 -0.184219 -0.247370 0.608031 0.304992 0.352951 0.453829 0.285533
7 8 9 10 11 12 13 \
0 0.353933 -0.263620 -0.481791 -1.015287 -0.027799 -0.770300 1.028327
1 0.035366 -0.226877 -1.094293 -0.077335 0.149487 -1.291372 -0.246792
2 0.184757 -0.163499 0.059236 0.515760 0.117891 0.133373 0.033434
3 0.259390 -0.339914 0.120134 -0.394649 -0.246437 -0.579486 0.994677
4 -0.442630 0.287699 -0.108314 0.540231 0.316796 -0.293156 -0.427342
5 0.222270 0.331641 -0.015495 -0.366638 0.014746 -0.224710 0.378336
6 0.162352 0.333664 0.069486 -0.037877 -0.015095 -0.301473 0.237264
7 0.361665 0.171941 -0.240028 -0.460423 -0.134655 0.122417 0.401689
8 -0.606256 -0.548382 -0.053605 -0.832274 -0.339201 -0.297512 0.289678
9 0.530910 0.210424 0.113577 -0.280486 0.202264 0.089873 0.381280
10 -0.240142 -0.371990 -0.299596 -0.832197 -0.359611 -0.637496 0.765209
11 0.680050 0.324915 -0.482536 -0.499427 0.151246 -0.254691 1.035710
12 0.081620 -0.338989 -0.082385 -0.711630 -0.425134 0.018855 0.009447
13 0.551979 0.049975 -0.170539 0.008342 0.121066 0.532295 0.402487
14 0.872133 0.463599 0.522785 0.704850 0.512636 0.430275 0.164187
15 0.616212 0.059369 0.419577 0.663161 0.387546 0.333471 0.249560
14 15
0 0.241509 0.527406
1 -0.591006 0.052637
2 0.032123 0.521154
3 -0.021960 0.356922
4 -0.393042 -0.474839
5 0.081683 0.197241
6 0.111261 0.401928
7 0.278085 0.003031
8 0.017494 0.183737
9 0.191927 0.236189
10 0.152228 0.264614
11 0.668065 0.268745
12 -0.214531 -0.060118
13 0.094902 0.129446
14 0.546852 0.545444
15 0.546221 0.853163
In [16]:
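The masking cell itself is missing; here is a sketch of the causal (look-ahead) mask that produces the -inf entries in the output below, so that each position can only attend to itself and earlier positions:
# Apply a causal mask: future positions are set to -inf before the softmax
mask = torch.triu(torch.ones(context_length, context_length, dtype=torch.bool, device=attention_score.device), diagonal=1)
attention_score = attention_score.masked_fill(mask, float('-inf'))
print(pd.DataFrame(attention_score[0][0].detach().cpu().numpy()))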
0 1 2 3 4 5 6 \
0 0.021498 -inf -inf -inf -inf -inf -inf
1 -0.186354 0.423521 -inf -inf -inf -inf -inf
2 -0.116987 -0.521399 -0.102028 -inf -inf -inf -inf
3 0.677378 0.094782 0.187630 1.087047 -inf -inf -inf
4 0.631352 -0.280073 -0.932875 -0.581271 -0.206257 -inf -inf
5 -0.327146 0.068143 0.148191 -0.391794 0.132223 0.017728 -inf
6 -0.444339 0.258554 0.455449 -0.077464 0.134794 -0.053327 0.101352
7 -0.245799 0.580802 0.932249 0.135226 0.357541 -0.006203 0.136920
8 -0.303093 0.018890 0.015307 0.294602 0.369652 -0.219553 -0.192504
9 -0.136668 0.443592 0.470723 -0.184153 0.221327 -0.058326 -0.335395
10 0.694354 0.946938 0.422884 0.789917 0.990797 -0.172340 0.704143
11 0.352159 0.977470 1.040248 1.317057 1.359986 0.120361 0.799884
12 -0.238808 0.104332 0.052890 -0.593303 -0.794684 -0.462218 0.311895
13 -0.517919 0.067904 0.636224 0.144879 0.018605 0.213238 0.280446
14 -0.085574 -0.059928 0.476232 -0.008881 -0.143431 0.318674 0.256556
15 -0.184219 -0.247370 0.608031 0.304992 0.352951 0.453829 0.285533
7 8 9 10 11 12 13 \
0 -inf -inf -inf -inf -inf -inf -inf
1 -inf -inf -inf -inf -inf -inf -inf
2 -inf -inf -inf -inf -inf -inf -inf
3 -inf -inf -inf -inf -inf -inf -inf
4 -inf -inf -inf -inf -inf -inf -inf
5 -inf -inf -inf -inf -inf -inf -inf
6 -inf -inf -inf -inf -inf -inf -inf
7 0.361665 -inf -inf -inf -inf -inf -inf
8 -0.606256 -0.548382 -inf -inf -inf -inf -inf
9 0.530910 0.210424 0.113577 -inf -inf -inf -inf
10 -0.240142 -0.371990 -0.299596 -0.832197 -inf -inf -inf
11 0.680050 0.324915 -0.482536 -0.499427 0.151246 -inf -inf
12 0.081620 -0.338989 -0.082385 -0.711630 -0.425134 0.018855 -inf
13 0.551979 0.049975 -0.170539 0.008342 0.121066 0.532295 0.402487
14 0.872133 0.463599 0.522785 0.704850 0.512636 0.430275 0.164187
15 0.616212 0.059369 0.419577 0.663161 0.387546 0.333471 0.249560
14 15
0 -inf -inf
1 -inf -inf
2 -inf -inf
3 -inf -inf
4 -inf -inf
5 -inf -inf
6 -inf -inf
7 -inf -inf
8 -inf -inf
9 -inf -inf
10 -inf -inf
11 -inf -inf
12 -inf -inf
13 -inf -inf
14 0.546852 -inf
15 0.546221 0.853163
In [18]:
import pandas as pd
import torch
import math
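The softmax step is not reproduced; a short sketch of turning the masked scores into attention weights, which is what the table below shows (each row sums to 1 and the masked positions become 0):
# Softmax over the key dimension converts masked scores into attention weights
attention_weights = torch.softmax(attention_score, dim=-1)
print(pd.DataFrame(attention_weights[0][0].detach().cpu().numpy()))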
0 1 2 3 4 5 6 \
0 1.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
1 0.352088 0.647912 0.000000 0.000000 0.000000 0.000000 0.000000
2 0.372795 0.248792 0.378413 0.000000 0.000000 0.000000 0.000000
3 0.271920 0.151853 0.166628 0.409599 0.000000 0.000000 0.000000
4 0.427102 0.171674 0.089371 0.127027 0.184826 0.000000 0.000000
5 0.124600 0.185008 0.200427 0.116800 0.197252 0.175912 0.000000
6 0.083958 0.169561 0.206461 0.121170 0.149823 0.124131 0.144895
7 0.069452 0.158736 0.225584 0.101663 0.126974 0.088256 0.101836
8 0.088831 0.122574 0.122136 0.161488 0.174074 0.096571 0.099218
9 0.073767 0.131784 0.135409 0.070346 0.105520 0.079778 0.060472
10 0.121218 0.156050 0.092400 0.133374 0.163047 0.050953 0.122411
11 0.060166 0.112440 0.119725 0.157907 0.164834 0.047718 0.094145
12 0.072694 0.102453 0.097315 0.050997 0.041695 0.058140 0.126086
13 0.034521 0.062015 0.109476 0.066978 0.059032 0.071716 0.076702
14 0.042108 0.043202 0.073850 0.045464 0.039741 0.063085 0.059286
15 0.035042 0.032898 0.077386 0.057155 0.059963 0.066327 0.056053
7 8 9 10 11 12 13 \
0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
1 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
2 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
3 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
4 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
5 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
6 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
7 0.127499 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
8 0.065600 0.069508 0.000000 0.000000 0.000000 0.000000 0.000000
9 0.143809 0.104376 0.094741 0.000000 0.000000 0.000000 0.000000
10 0.047613 0.041731 0.044864 0.026339 0.000000 0.000000 0.000000
11 0.083513 0.058549 0.026113 0.025675 0.049215 0.000000 0.000000
12 0.100152 0.065764 0.085003 0.045306 0.060336 0.094059 0.000000
13 0.100631 0.060913 0.048859 0.058429 0.065401 0.098669 0.086658
14 0.109721 0.072923 0.077370 0.092820 0.076588 0.070533 0.054055
15 0.078021 0.044707 0.064094 0.081772 0.062073 0.058806 0.054073
14 15
0 0.000000 0.000000
1 0.000000 0.000000
2 0.000000 0.000000
3 0.000000 0.000000
4 0.000000 0.000000
5 0.000000 0.000000
6 0.000000 0.000000
7 0.000000 0.000000
8 0.000000 0.000000
9 0.000000 0.000000
10 0.000000 0.000000
11 0.000000 0.000000
12 0.000000 0.000000
13 0.000000 0.000000
14 0.079254 0.000000
15 0.072747 0.098883
In [20]:
In [22]:
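The cells that apply the attention weights to V, merge the heads, and project back to d_model are missing; a sketch of those steps follows (the output-projection name Wo is an assumption), producing the output tensor whose shape is printed below:
# Weighted sum of the value vectors: [batch_size, num_heads, context_length, head_dim]
A = torch.matmul(attention_weights, V)

# Merge the heads back into a single d_model-wide representation
A = A.transpose(1, 2).contiguous().view(batch_size, context_length, d_model)

# Output projection of the multi-head attention block
Wo = nn.Linear(d_model, d_model)
output = Wo(A)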
print(output.shape)
In [24]:
In [25]:
In [26]:
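The remaining cells of the transformer block (residual connection, layer normalization, and the position-wise feed-forward network) are not reproduced; a minimal sketch assuming the standard post-attention layout, with the names layer_norm, feed_forward and ffn_output chosen for illustration:
# Residual connection around the attention output, followed by layer normalization
layer_norm = nn.LayerNorm(d_model)
attention_residual = layer_norm(output + X)

# Position-wise feed-forward network (d_model -> 4*d_model -> d_model) with dropout
feed_forward = nn.Sequential(
    nn.Linear(d_model, d_model * 4),
    nn.ReLU(),
    nn.Linear(d_model * 4, d_model),
    nn.Dropout(dropout),
)
ffn_output = layer_norm(feed_forward(attention_residual) + attention_residual)
print(pd.DataFrame(ffn_output[0].detach().cpu().numpy()))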
0 1 2 3 4 5 6 \
0 0.667483 -0.389633 -0.851936 -0.394342 -0.998247 0.262954 0.480094
1 -0.038112 0.322328 -0.031820 -0.584922 -0.112201 -0.738884 0.440067
2 -0.628515 0.205482 -0.180367 -0.370944 0.412005 0.792747 -0.194784
3 -1.048590 0.793600 -0.683289 -0.619398 -0.317157 0.628633 -0.024611
4 -0.716813 -0.497194 -0.200344 -1.548010 -0.219922 0.368706 -0.773210
5 -0.378771 1.055804 -0.537877 0.158270 -0.003828 -0.635562 0.394238
6 -0.310792 -0.171071 0.622006 -0.609642 0.586624 0.604495 0.867905
7 -0.922160 0.500379 -0.034140 0.611128 -0.062309 0.604757 0.336614
8 0.318265 0.177395 0.507006 -0.731018 -0.162960 0.023586 0.724885
9 1.179621 0.450486 -0.439893 0.052057 0.397160 -1.076628 0.009521
10 0.329482 -0.249196 -0.487763 0.146532 0.203778 0.126987 -0.013209
11 0.491254 0.254098 -0.805063 0.847444 -0.868649 -0.204639 0.187366
12 0.908632 -0.443110 0.164215 -0.860427 -0.831706 -0.385907 -0.045611
13 -0.267631 -0.724435 -0.408752 0.460729 -0.677400 0.245174 0.588925
14 -0.226249 0.275237 -0.203830 0.269060 -0.362902 0.613476 0.069186
15 0.445209 0.237637 0.394769 0.191370 0.002953 0.345110 0.509143
13 -0.478098 -0.365851 -1.100504 ... 0.009546 -0.100925 0.359181 -0.584053
14 -0.556882 -0.456754 0.220161 ... 0.158783 0.210422 0.217538 -0.163716
15 -0.298419 -0.693817 0.393190 ... -0.114032 0.024332 0.364460 -0.628126