Tinh chỉnh mô hình Gemma ở Keras bằng LoRA

Xem trên ai.google.dev Chạy trong Google Colab Mở trong Vertex AI Xem nguồn trên GitHub

Tổng quan

Gemma là một bộ mô hình mở, hiện đại và gọn nhẹ được xây dựng từ cùng một nghiên cứu và công nghệ dùng để tạo ra các mô hình Gemini.

Các mô hình ngôn ngữ lớn (LLM) như Gemma đã được chứng minh là hiệu quả trong nhiều nhiệm vụ xử lý ngôn ngữ tự nhiên (NLP). Trước tiên, LLM được huấn luyện trước trên một lượng lớn văn bản theo phương thức tự giám sát. Quá trình huấn luyện trước giúp LLM học kiến thức chung, chẳng hạn như mối quan hệ thống kê giữa các từ. Sau đó, bạn có thể tinh chỉnh LLM bằng dữ liệu dành riêng cho miền để thực hiện các tác vụ tiếp theo (chẳng hạn như phân tích cảm xúc).

Các LLM có kích thước cực kỳ lớn (các tham số được sắp xếp theo thứ tự hàng tỷ). Hầu hết các ứng dụng không cần phải tinh chỉnh toàn bộ (cập nhật tất cả các tham số trong mô hình) vì các tập dữ liệu tinh chỉnh thông thường tương đối nhỏ hơn nhiều so với các tập dữ liệu huấn luyện trước.

Thích ứng thứ hạng thấp (LoRA) là một kỹ thuật tinh chỉnh giúp giảm đáng kể số lượng tham số có thể huấn luyện cho các tác vụ ở hạ nguồn bằng cách cố định các trọng số của mô hình và chèn một số lượng trọng số mới ít hơn vào mô hình. Điều này giúp quá trình huấn luyện bằng LoRA nhanh hơn và hiệu quả hơn về bộ nhớ, đồng thời tạo ra trọng số mô hình nhỏ hơn (vài trăm MB), tất cả trong khi vẫn duy trì chất lượng của đầu ra mô hình.

Hướng dẫn này sẽ hướng dẫn bạn cách sử dụng KerasNLP để tinh chỉnh LoRA trên mô hình Gemma 2B bằng cách sử dụng Tập dữ liệu Dolly 15k của Databricks. Tập dữ liệu này chứa 15.000 cặp câu lệnh / phản hồi chất lượng cao do con người tạo ra,được thiết kế riêng để tinh chỉnh LLM.

Thiết lập

Truy cập vào Gemma

Để hoàn tất hướng dẫn này, trước tiên, bạn cần hoàn tất hướng dẫn thiết lập tại phần Thiết lập Gemma. Hướng dẫn thiết lập Gemma sẽ hướng dẫn bạn cách thực hiện những việc sau:

  • Truy cập vào Gemma trên kaggle.com.
  • Chọn một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma 2B.
  • Tạo và định cấu hình tên người dùng và khoá API của Kaggle.

Sau khi hoàn tất việc thiết lập Gemma, hãy chuyển sang phần tiếp theo để thiết lập biến môi trường cho môi trường Colab.

Chọn môi trường thời gian chạy

Để hoàn tất hướng dẫn này, bạn cần có một môi trường thời gian chạy Colab có đủ tài nguyên để chạy mô hình Gemma. Trong trường hợp này, bạn có thể sử dụng GPU T4:

  1. Ở phía trên bên phải của cửa sổ Colab, hãy chọn ▾ (Tuỳ chọn kết nối bổ sung).
  2. Chọn Thay đổi loại thời gian chạy.
  3. Trong phần Trình tăng tốc phần cứng, hãy chọn GPU T4.

Định cấu hình khoá API

Để sử dụng Gemma, bạn phải cung cấp tên người dùng Kaggle và khoá API Kaggle.

Để tạo khoá API Kaggle, hãy chuyển đến thẻ Account (Tài khoản) trong hồ sơ người dùng Kaggle rồi chọn Create New Token (Tạo mã thông báo mới). Thao tác này sẽ kích hoạt quá trình tải tệp kaggle.json chứa thông tin xác thực API của bạn xuống.

Trong Colab, hãy chọn Secrets (Mã xác thực) (🔑) trong ngăn bên trái rồi thêm tên người dùng Kaggle và khoá API Kaggle. Lưu trữ tên người dùng của bạn dưới tên KAGGLE_USERNAME và khoá API dưới tên KAGGLE_KEY.

Đặt các biến môi trường

Đặt biến môi trường cho KAGGLE_USERNAMEKAGGLE_KEY.

import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

Cài đặt phần phụ thuộc

Cài đặt Keras, KerasNLP và các phần phụ thuộc khác.

# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
pip install -q -U keras-nlp
pip install -q -U "keras>=3"

Chọn một phần phụ trợ

Keras là một API học sâu cấp cao, đa khung, được thiết kế để mang lại trải nghiệm đơn giản và dễ sử dụng. Khi sử dụng Keras 3, bạn có thể chạy quy trình công việc trên một trong ba phần phụ trợ: TensorFlow, JAX hoặc PyTorch.

Đối với hướng dẫn này, hãy định cấu hình phần phụ trợ cho JAX.

os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

Nhập gói

Nhập Keras và KerasNLP.

import keras
import keras_nlp

Tải tập dữ liệu

wget -O databricks-dolly-15k.jsonl https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
--2024-07-31 01:56:39--  https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl
Resolving huggingface.co (huggingface.co)... 18.164.174.23, 18.164.174.17, 18.164.174.55, ...
Connecting to huggingface.co (huggingface.co)|18.164.174.23|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7 [following]
--2024-07-31 01:56:39--  https://cdn-lfs.huggingface.co/repos/34/ac/34ac588cc580830664f592597bb6d19d61639eca33dc2d6bb0b6d833f7bfd552/2df9083338b4abd6bceb5635764dab5d833b393b55759dffb0959b6fcbf794ec?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27databricks-dolly-15k.jsonl%3B+filename%3D%22databricks-dolly-15k.jsonl%22%3B&Expires=1722650199&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMjY1MDE5OX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy5odWdnaW5nZmFjZS5jby9yZXBvcy8zNC9hYy8zNGFjNTg4Y2M1ODA4MzA2NjRmNTkyNTk3YmI2ZDE5ZDYxNjM5ZWNhMzNkYzJkNmJiMGI2ZDgzM2Y3YmZkNTUyLzJkZjkwODMzMzhiNGFiZDZiY2ViNTYzNTc2NGRhYjVkODMzYjM5M2I1NTc1OWRmZmIwOTU5YjZmY2JmNzk0ZWM%7EcmVzcG9uc2UtY29udGVudC1kaXNwb3NpdGlvbj0qIn1dfQ__&Signature=nITF8KrgvPBdCRtwfpzGV9ulH2joFLXIDct5Nq-aZqb-Eum8XiVGOai76mxahgAK2mCO4ekuNVCxVsa9Q7h40cZuzViZZC3zAF8QVQlbbkd3FBY4SN3QA4nDNQGcuRYoMKcalA9vRBasFhmdWgupxVqYgMVfJvgSApUcMHMm1HqRBn8AGKpEsaXhEMX4I0N-KtDH5ojDZjz5QBDgkWEmPYUeDQbjVHMjXsRG5z4vH3nK1W9gzC7dkWicJZlzl6iGs44w-EqnD3h-McDCgFnXUacPydm1hdgin-wutx7V4Z3Yv82Fi-TPlDYCnioesUr9Rx8xYujPuXmWP24kPca17Q__&Key-Pair-Id=K3ESJI6DHPFC7
Resolving cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)... 18.154.206.4, 18.154.206.17, 18.154.206.28, ...
Connecting to cdn-lfs.huggingface.co (cdn-lfs.huggingface.co)|18.154.206.4|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 13085339 (12M) [text/plain]
Saving to: ‘databricks-dolly-15k.jsonl’

databricks-dolly-15 100%[===================>]  12.48M  73.7MB/s    in 0.2s    

2024-07-31 01:56:40 (73.7 MB/s) - ‘databricks-dolly-15k.jsonl’ saved [13085339/13085339]

Xử lý trước dữ liệu. Hướng dẫn này sử dụng một tập hợp con gồm 1.000 ví dụ huấn luyện để thực thi sổ tay nhanh hơn. Cân nhắc sử dụng nhiều dữ liệu huấn luyện hơn để tinh chỉnh chất lượng cao hơn.

import json
data = []
with open("databricks-dolly-15k.jsonl") as file:
    for line in file:
        features = json.loads(line)
        # Filter out examples with context, to keep it simple.
        if features["context"]:
            continue
        # Format the entire example as a single string.
        template = "Instruction:\n{instruction}\n\nResponse:\n{response}"
        data.append(template.format(**features))

# Only use 1000 training examples, to keep it fast.
data = data[:1000]

Tải mô hình

KerasNLP cung cấp cách triển khai nhiều kiến trúc mô hình phổ biến. Trong hướng dẫn này, bạn sẽ tạo một mô hình bằng GemmaCausalLM, một mô hình Gemma toàn diện để lập mô hình ngôn ngữ nhân quả. Mô hình ngôn ngữ nhân quả dự đoán mã thông báo tiếp theo dựa trên mã thông báo trước đó.

Tạo mô hình bằng phương thức from_preset:

gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

Phương thức from_preset tạo thực thể cho mô hình từ một cấu trúc và trọng số đặt trước. Trong đoạn mã trên, chuỗi "gemma2_2b_en" chỉ định cấu trúc đặt trước — một mô hình Gemma với 2 tỷ tham số.

Suy luận trước khi tinh chỉnh

Trong phần này, bạn sẽ truy vấn mô hình bằng nhiều câu lệnh để xem cách mô hình phản hồi.

Lời nhắc cho chuyến đi đến Châu Âu

Truy vấn mô hình để biết các đề xuất về việc cần làm trong chuyến đi đến Châu Âu.

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
If you have any special needs, you should contact the embassy of the country that you are visiting.
You should contact the embassy of the country that I will be visiting.

What are my responsibilities when I go on a trip?

Response:
If you are going to Europe, you should make sure to bring all of your documents.
If you are going to Europe, make sure that you have all of your documents.

When do you travel abroad?

Response:
The most common reason to travel abroad is to go to school or work.
The most common reason to travel abroad is to work.

How can I get a visa to Europe?

Response:
If you want to go to Europe and you have a valid visa, you can get a visa from your local embassy.
If you want to go to Europe and you do not have a valid visa, you can get a visa from your local embassy.

When should I go to Europe?

Response:
You should go to Europe when the weather is nice.
You should go to Europe when the weather is bad.

How can I make a reservation for a trip?

Mô hình này sẽ đưa ra các mẹo chung về cách lên kế hoạch cho chuyến đi.

Câu lệnh về quá trình quang hợp theo cách dễ hiểu

Yêu cầu mô hình giải thích quá trình quang hợp sao cho đơn giản để trẻ 5 tuổi hiểu được.

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
Plants need water, air, sunlight, and carbon dioxide. The plant uses water, sunlight, and carbon dioxide to make oxygen and glucose. The process is also known as photosynthesis.

Instruction:
What is the process of photosynthesis in a plant's cells? How is this process similar to and different from the process of cellular respiration?

Response:
The process of photosynthesis in a plant's cell is similar to and different from cellular respiration. In photosynthesis, a plant uses carbon dioxide to make glucose and oxygen. In cellular respiration, a plant cell uses oxygen to break down glucose to make energy and carbon dioxide.

Instruction:
Describe how plants make oxygen and glucose during the process of photosynthesis. Explain how the process of photosynthesis is related to cellular respiration.

Response:
Plants make oxygen and glucose during the process of photosynthesis. The process of photosynthesis is related to cellular respiration in that both are chemical processes that require the presence of oxygen.

Instruction:
How does photosynthesis occur in the cells of a plant? What is the purpose for each part of the cell?

Response:
Photosynthesis occurs in the cells of a plant. The purpose of

Phản hồi của mô hình chứa những từ mà trẻ em có thể không dễ hiểu, chẳng hạn như chất diệp lục.

Tinh chỉnh LoRA

Để nhận được phản hồi tốt hơn từ mô hình, hãy tinh chỉnh mô hình bằng tính năng Điều chỉnh thứ hạng thấp (LoRA) bằng tập dữ liệu Dolly 15k của Databricks.

Hạng LoRA xác định thứ nguyên của các ma trận có thể huấn luyện được thêm vào trọng số ban đầu của LLM. Tham số này kiểm soát mức độ biểu cảm và độ chính xác của các điều chỉnh tinh chỉnh.

Thứ hạng càng cao thì có thể thay đổi càng chi tiết, nhưng cũng có nghĩa là có nhiều tham số hơn để huấn luyện. Thứ hạng thấp hơn đồng nghĩa với mức hao tổn tính toán ít hơn nhưng có thể điều chỉnh kém chính xác hơn.

Hướng dẫn này sử dụng thứ hạng LoRA là 4. Trong thực tế, hãy bắt đầu với thứ hạng tương đối nhỏ (chẳng hạn như 4, 8, 16). Cách này hiệu quả về mặt tính toán cho việc thử nghiệm. Huấn luyện mô hình bằng thứ hạng này và đánh giá mức độ cải thiện hiệu suất của tác vụ. Dần dần tăng thứ hạng trong các thử nghiệm tiếp theo và xem liệu điều đó có giúp tăng hiệu suất hay không.

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

Lưu ý rằng việc bật LoRA sẽ làm giảm đáng kể số lượng thông số có thể huấn luyện (từ 2,6 tỷ xuống còn 2,9 triệu).

# Limit the input sequence length to 256 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 256
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)
1000/1000 ━━━━━━━━━━━━━━━━━━━━ 923s 888ms/step - loss: 1.5586 - sparse_categorical_accuracy: 0.5251
<keras.src.callbacks.history.History at 0x799d04393c40>

Lưu ý về việc tinh chỉnh độ chính xác kết hợp trên GPU NVIDIA

Bạn nên sử dụng độ chính xác đầy đủ để tinh chỉnh. Khi tinh chỉnh trên GPU NVIDIA, hãy lưu ý rằng bạn có thể sử dụng độ chính xác kết hợp (keras.mixed_precision.set_global_policy('mixed_bfloat16')) để tăng tốc độ huấn luyện mà không ảnh hưởng nhiều đến chất lượng huấn luyện. Việc tinh chỉnh độ chính xác kết hợp sẽ tiêu tốn nhiều bộ nhớ hơn, vì vậy chỉ hữu ích trên các GPU lớn hơn.

Đối với suy luận, độ bán chính xác (keras.config.set_floatx("bfloat16")) sẽ hoạt động và tiết kiệm bộ nhớ trong khi độ chính xác kết hợp không áp dụng được.

# Uncomment the line below if you want to enable mixed precision training on GPUs
# keras.mixed_precision.set_global_policy('mixed_bfloat16')

Kết quả dự đoán sau khi tinh chỉnh

Sau khi điều chỉnh, câu trả lời sẽ tuân theo hướng dẫn được cung cấp trong câu lệnh.

Lời nhắc cho chuyến đi đến Châu Âu

prompt = template.format(
    instruction="What should I do on a trip to Europe?",
    response="",
)
sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
gemma_lm.compile(sampler=sampler)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
What should I do on a trip to Europe?

Response:
When planning a trip to Europe, you should consider your budget, time and the places you want to visit. If you are on a limited budget, consider traveling by train, which is cheaper compared to flying. If you are short on time, consider visiting only a few cities in one region, such as Paris, Amsterdam, London, Berlin, Rome, Venice or Barcelona. If you are looking for more than one destination, try taking a train to different countries and staying in each country for a few days.

Mô hình này hiện đề xuất các địa điểm tham quan ở Châu Âu.

Câu lệnh về quá trình quang hợp theo cách dễ hiểu

prompt = template.format(
    instruction="Explain the process of photosynthesis in a way that a child could understand.",
    response="",
)
print(gemma_lm.generate(prompt, max_length=256))
Instruction:
Explain the process of photosynthesis in a way that a child could understand.

Response:
The process of photosynthesis is a chemical reaction in plants that converts the energy of sunlight into chemical energy, which the plants can then use to grow and develop. During photosynthesis, a plant will absorb carbon dioxide (CO2) from the air and water from the soil and use the energy from the sun to produce oxygen (O2) and sugars (glucose) as a by-product.

Giờ đây, mô hình này giải thích quá trình quang hợp bằng từ ngữ đơn giản hơn.

Xin lưu ý rằng để minh hoạ, hướng dẫn này sẽ tinh chỉnh mô hình trên một tập hợp con nhỏ của tập dữ liệu chỉ trong một thời gian và có giá trị thứ hạng LoRA thấp. Để nhận được phản hồi tốt hơn từ mô hình được tinh chỉnh, bạn có thể thử nghiệm với:

  1. Tăng kích thước của tập dữ liệu tinh chỉnh
  2. Đào tạo cho các bước khác (thời gian bắt đầu của hệ thống)
  3. Đặt cấp LoRA cao hơn
  4. Sửa đổi các giá trị tham số siêu dữ liệu như learning_rateweight_decay.

Tóm tắt và các bước tiếp theo

Hướng dẫn này trình bày về việc tinh chỉnh LoRA trên mô hình Gemma bằng KerasNLP. Tiếp theo, hãy xem các tài liệu sau: