Multimodal RAG
1 Setup
[ ]: %pip install -U "unstructured[all-docs]" pillow lxml
%pip install -U chromadb tiktoken
%pip install -U langchain langchain-community langchain-openai langchain-groq
%pip install -U python-dotenv
%pip install -U langchain-ollama
%pip install -U transformers
%pip install -qU "langchain-chroma>=0.1.2"
%pip install PyMuPDF
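The input of the cell that produced the True output below was not captured in this export; given the python-dotenv install above and the API keys needed later, it presumably loads environment variables along these lines (a minimal sketch):
[ ]: # Sketch of the missing cell: load API keys (e.g. OPENAI_API_KEY) from a .env file
from dotenv import load_dotenv

load_dotenv()  # returns True when a .env file is found and loaded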
[1]: True
[ ]: from unstructured.partition.pdf import partition_pdf

output_path = "./pdf/"
file_path = output_path + 'attention_is_all_you_need.pdf'

# Reference: https://docs.unstructured.io/open-source/core-functionality/chunking
chunks = partition_pdf(
    filename=file_path,
    infer_table_structure=True,            # extract tables
    strategy="hi_res",                     # mandatory to infer tables
    extract_image_block_types=["Image"],   # which element types to extract as images
    extract_image_block_to_payload=True,   # if true, will extract base64 for API usage
    chunking_strategy="by_title",          # or 'basic'
    max_characters=10000,                  # defaults to 500
    combine_text_under_n_chars=2000,       # defaults to 0
    new_after_n_chars=6000,
    # extract_images_in_pdf=True,          # deprecated
)
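partition_pdf returns a list of unstructured elements; with chunking enabled these are mostly CompositeElement and Table objects. A quick, illustrative way to see which element types were produced:
[ ]: # Illustrative: count the element types produced by the chunking step
from collections import Counter

Counter(type(chunk).__name__ for chunk in chunks)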
[3]: len(chunks)
[3]: 17
[ ]: chunks[0].metadata.orig_elements
[7]: chunks[0].metadata.orig_elements[0].to_dict()
[ ]: chunks[1].to_dict()
if "CompositeElement" in str(type((chunk))):
texts.append(chunk)
[ ]: tables[0].to_dict()
[13]: tables[0].metadata.text_as_html
…O(n)</td><td>O(n)</td></tr><tr><td>Convolutional</td><td>O(k·n·d²)</td><td>O(1)</td><td>O(log_k(n))</td></tr><tr><td>Self-Attention (restricted)</td><td>O(r·n·d)</td><td>O(1)</td><td>O(n/r)</td></tr></table>'
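The helper get_images_base64 used in the next cell is not shown in this export. A minimal sketch consistent with how it is used here, assuming images were extracted to the element payload as image_base64 (as configured in partition_pdf above):
[ ]: # Minimal sketch (reconstruction): collect base64-encoded images from the chunks
def get_images_base64(chunks):
    images_b64 = []
    for chunk in chunks:
        if "CompositeElement" in str(type(chunk)):
            for el in chunk.metadata.orig_elements:
                if "Image" in str(type(el)):
                    images_b64.append(el.metadata.image_base64)
    return images_b64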
[ ]: import base64
from IPython.display import Image, display

images = get_images_base64(chunks)

def display_base64_image(base64_code):
    # Decode the base64 string to binary
    image_data = base64.b64decode(base64_code)
    # Display the image
    display(Image(data=image_data))

display_base64_image(images[0])
3 Summary of the Data
3.1 Text and Table Summaries
[16]: from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_ollama.llms import OllamaLLM

# Prompt (the {element} placeholder receives the text or table to summarize)
prompt_text = """
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.

Table or text chunk: {element}
"""
prompt = ChatPromptTemplate.from_template(prompt_text)

# Summary chain
model = OllamaLLM(temperature=0.5, model="llama3.1:8b")
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

# Summarize texts
text_summaries = summarize_chain.batch(texts, {"max_concurrency": 3})

# Summarize tables (use the HTML rendering so table structure is preserved)
tables_html = [table.metadata.text_as_html for table in tables]
table_summaries = summarize_chain.batch(tables_html, {"max_concurrency": 3})
[ ]: text_summaries
[ ]: table_summaries
[ ]: from langchain_openai import ChatOpenAI

prompt_template = """Describe the image in detail. For context,
the image is part of a research paper explaining the transformers
architecture. Be specific about graphs, such as bar plots."""

messages = [
    (
        "user",
        [
            {"type": "text", "text": prompt_template},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,{image}"},
            },
        ],
    )
]

prompt = ChatPromptTemplate.from_messages(messages)

# NOTE: the model choice here is an assumption; any multimodal chat model can be used
chain = prompt | ChatOpenAI(model="gpt-4o-mini") | StrOutputParser()

image_summaries = chain.batch(images)
[ ]: image_summaries
[ ]: from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_chroma import Chroma
from langchain_openai import OpenAIEmbeddings

# NOTE: the vector store and docstore setup below is reconstructed from how
# `retriever`, `store`, and `id_key` are used in the next cell; the collection
# name and embedding model are assumptions.
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=OpenAIEmbeddings(),
)
store = InMemoryStore()  # the storage layer for the original elements
id_key = "doc_id"

# Index summaries in the vector store, keep the raw elements in the docstore
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)
3.5 Load the summaries and link them to the original data
[23]: import uuid
from langchain_core.documents import Document

# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [
    Document(page_content=summary, metadata={id_key: doc_ids[i]})
    for i, summary in enumerate(text_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=summary, metadata={id_key: table_ids[i]})
    for i, summary in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))
# Add images (same pattern as above, keyed by the image summaries)
img_ids = [str(uuid.uuid4()) for _ in images]
summary_img = [
    Document(page_content=summary, metadata={id_key: img_ids[i]})
    for i, summary in enumerate(image_summaries)
]
retriever.vectorstore.add_documents(summary_img)
retriever.docstore.mset(list(zip(img_ids, images)))
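With the summaries indexed and the raw elements in the docstore, the retriever matches a query against summaries and returns the original elements they point to. An illustrative call (the query string is hypothetical):
[ ]: # Illustrative retrieval; the query text is hypothetical
docs = retriever.invoke("How is multi-head attention computed?")
len(docs)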
[ ]: chunks
[ ]: display_base64_image(chunks[1])
[ ]: for i, chunk in enumerate(chunks):
    print(f"Chunk {i}")
    for el in chunk.metadata.orig_elements:
        print(el.category, el.metadata.page_number)
Chunk 0
Title 4
NarrativeText 4
NarrativeText 4
UncategorizedText 4
NarrativeText 5
NarrativeText 5
Formula 5
NarrativeText 5
NarrativeText 5
Title 5
NarrativeText 5
ListItem 5
ListItem 5
ListItem 5
Chunk 2
ListItem 12
ListItem 12
ListItem 12
ListItem 12
ListItem 12
Footer 12
Title 13
Image 13
FigureCaption 13
Header 13
Image 14
NarrativeText 14
UncategorizedText 14
Image 15
Image 15
FigureCaption 15
Header 15
Chunk 3
Title 3
NarrativeText 3
Footer 3
Image 4
Image 4
NarrativeText 4
NarrativeText 4
Title 4
NarrativeText 4
NarrativeText 4
Formula 4
NarrativeText 4
NarrativeText 4
category_to_color = {
    'Title': 'orchid',
    'Image': 'forestgreen',
    'Table': 'tomato',
}

# Legend
legend_handles = [patches.Patch(color='deepskyblue', label='Text')]
for category in ['Title', 'Image', 'Table']:
    if category in categories:
        legend_handles.append(
            patches.Patch(color=category_to_color[category], label=category)
        )

ax.axis('off')
ax.legend(handles=legend_handles, loc='upper right')
plt.tight_layout()
plt.show()
]
segments = [doc.metadata for doc in page_docs]
plot_pdf_with_boxes(pdf_page=pdf_page, segments=segments)
if print_text:
    for doc in page_docs:
        print(f'{doc.page_content}\n')
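display_chunk_pages is called in the next cell but its definition is not visible in this extract. A plausible sketch, assuming it renders every PDF page spanned by the chunk with PyMuPDF and re-uses the plot_pdf_with_boxes helper above (the original may differ):
[ ]: # Hypothetical reconstruction of display_chunk_pages
import fitz  # PyMuPDF

def display_chunk_pages(chunk):
    pdf = fitz.open(file_path)
    page_numbers = {el.metadata.page_number for el in chunk.metadata.orig_elements}
    for page_number in sorted(page_numbers):
        pdf_page = pdf.load_page(page_number - 1)  # fitz pages are 0-indexed
        segments = [
            {**el.metadata.to_dict(), "category": el.category}
            for el in chunk.metadata.orig_elements
            if el.metadata.page_number == page_number
        ]
        plot_pdf_with_boxes(pdf_page=pdf_page, segments=segments)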
[43]: from langchain_core.documents import Document

def extract_page_numbers_from_chunk(chunk):
    elements = chunk.metadata.orig_elements
    page_numbers = set()
    for element in elements:
        page_numbers.add(element.metadata.page_number)
    return page_numbers

extract_page_numbers_from_chunk(chunks[3])
display_chunk_pages(chunks[3])
4 RAG Pipeline
[44]: from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_openai import ChatOpenAI
from base64 import b64decode
def parse_docs(docs):
    """Split base64-encoded images and texts"""
    b64 = []
    text = []
    for doc in docs:
        try:
            b64decode(doc)
            b64.append(doc)
        except Exception:
            text.append(doc)
    return {"images": b64, "texts": text}
def build_prompt(kwargs):
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    context_text = ""
    if len(docs_by_type["texts"]) > 0:
        for text_element in docs_by_type["texts"]:
            context_text += text_element.text

    # Text part of the multimodal prompt (instruction wording reconstructed)
    prompt_template = f"""
    Answer the question based only on the following context, which can include text, tables, and images.
    Context: {context_text}
    Question: {user_question}
    """

    prompt_content = [{"type": "text", "text": prompt_template}]

    if len(docs_by_type["images"]) > 0:
        for image in docs_by_type["images"]:
            prompt_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                }
            )

    return ChatPromptTemplate.from_messages(
        [
            HumanMessage(content=prompt_content),
        ]
    )
chain = (
    {
        "context": retriever | RunnableLambda(parse_docs),
        "question": RunnablePassthrough(),
    }
    | RunnableLambda(build_prompt)
    | ChatOpenAI(model="gpt-4o-mini")
    | StrOutputParser()
)

chain_with_sources = {
    "context": retriever | RunnableLambda(parse_docs),
    "question": RunnablePassthrough(),
} | RunnablePassthrough().assign(
    response=(
        RunnableLambda(build_prompt)
        | ChatOpenAI(model="gpt-4o-mini")
        | StrOutputParser()
    )
)
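The cell that produced the response printed below is not shown; invoking the chain looks like this (the question string is illustrative, not the original one). chain_with_sources is invoked the same way and returns a dict with context, question, and response keys, which a later cell indexes into:
[ ]: # Illustrative invocation; the original question text was not captured
response = chain.invoke("How does scaled dot-product attention work?")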
print(response)
2. **Dot Products**: The dot products of the queries and keys are computed,
scaled by the square root of the dimension of the keys (√dk).
3. **Softmax**: A softmax function is applied to the scaled dot products to
obtain the attention weights, which indicate the importance of each value based
on its corresponding key.
4. **Weighted Sum**: Finally, these weights are used to compute a weighted sum
of the values (V), resulting in the output of the attention mechanism.
This attention mechanism allows the model to focus on relevant parts of the
input sequence for each output element, enabling it to capture relationships and
dependencies effectively.
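To make the steps in this answer concrete, here is a minimal NumPy sketch of scaled dot-product attention, Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V (illustrative; not part of the original notebook):
[ ]: import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V"""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                           # scaled dot products
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights = weights / weights.sum(axis=-1, keepdims=True)   # softmax over keys
    return weights @ V                                        # weighted sum of the values

# Toy example: 3 queries/keys/values of dimension 4
rng = np.random.default_rng(0)
Q, K, V = (rng.normal(size=(3, 4)) for _ in range(3))
print(scaled_dot_product_attention(Q, K, V).shape)  # (3, 4)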
print("Response:", response['response'])
#Context
# print("\n\nContext:")
# for text in response['context']['texts']:
# print(text.text)
# print("Page number: ", text.metadata.page_number)
# print("\n" + "-"*50 + "\n")
The outputs from all the attention heads are concatenated and linearly
transformed to produce the final output. This approach enables the model to
attend to information from different representation subspaces at various
positions, enhancing its ability to learn complex patterns within the data.
5 THANK YOU!